Compare commits

...

99 Commits

Author SHA1 Message Date
Luis Pater
f3c59165d7 Merge branch 'pr-454'
# Conflicts:
#	cmd/server/main.go
#	internal/translator/claude/openai/chat-completions/claude_openai_response.go
2026-03-22 22:52:46 +08:00
Luis Pater
f81acd0760 Merge pull request #2243 from router-for-me/oauth
Improve OAuth callback handling with async prompts
2026-03-20 12:35:44 +08:00
hkfires
636da4c932 refactor(auth): replace manual input handling with AsyncPrompt for callback URLs 2026-03-20 12:24:27 +08:00
hkfires
cccb77b552 fix(auth): avoid blocking oauth callback wait on prompt 2026-03-20 11:48:30 +08:00
Luis Pater
2bd646ad70 refactor: replace sjson.Set usage with sjson.SetBytes to optimize mutable JSON transformations 2026-03-19 17:58:54 +08:00
Luis Pater
56073ded69 Merge pull request #2200 from sususu98/feat/local-model-flag
feat: add -local-model flag to skip remote model catalog fetching
2026-03-18 10:58:07 +08:00
sususu98
9738a53f49 feat: add -local-model flag to skip remote model catalog fetching
When enabled, the server uses only the embedded models.json loaded at
init() and skips registry.StartModelsUpdater(), disabling the initial
remote fetch and 3-hour periodic refresh. The management panel
auto-updater (managementasset.StartAutoUpdater) is unaffected.
2026-03-18 10:48:03 +08:00
Luis Pater
be3f8dbf7e Merge pull request #2187 from Darley-Wey/fix/claude-disable-parallel-tool-calls
fix(claude): honor disable_parallel_tool_use
2026-03-17 21:06:08 +08:00
Darley
9c6c3612a8 fix(claude): read disable_parallel_tool_use from tool_choice 2026-03-17 19:35:41 +08:00
Darley
19e1a4447a fix(claude): honor disable_parallel_tool_use 2026-03-17 19:17:41 +08:00
Luis Pater
7c2ad4cda2 Merge branch 'router-for-me:main' into main 2026-03-17 00:09:43 +08:00
Luis Pater
fb95813fbf Merge pull request #2142 from Muran-prog/fix/strip-uniqueItems-gemini-2123
fix: strip uniqueItems from Gemini function_declarations (#2123)
2026-03-16 20:34:28 +08:00
Luis Pater
db63f9b5d6 Merge pull request #2162 from enieuwy/fix/responses-api-json-valid-check
fix: validate JSON before raw-embedding function call outputs in Responses API
2026-03-16 18:42:31 +08:00
Luis Pater
25f6c4a250 Merge pull request #2158 from sususu98/fix/antigravity-functionresponse-name
fix(antigravity): resolve empty functionResponse.name for toolu_* tool_use_id format
2026-03-16 18:39:40 +08:00
enieuwy
b24ae74216 fix: validate JSON before raw-embedding function call outputs in Responses API
gjson.Parse() marks any string starting with { or [ as gjson.JSON type,
even when the content is not valid JSON (e.g. macOS plist format, truncated
tool results). This caused sjson.SetRaw to embed non-JSON content directly
into the Gemini API request payload, producing 400 errors.

Add json.Valid() check before using SetRaw to ensure only actually valid
JSON is embedded raw. Non-JSON content now falls through to sjson.Set
which properly escapes it as a JSON string.

Fixes #2161
2026-03-16 15:29:18 +08:00
Luis Pater
59ad8f40dc Merge pull request #2124 from RGBadmin/feat/auth-list-priority-note
feat(api): expose priority and note in GET /auth-files response
2026-03-16 12:31:11 +08:00
sususu98
ff03dc6a2c fix(antigravity): resolve empty functionResponse.name for toolu_* tool_use_id format
The Claude-to-Gemini translator derived function names by splitting
tool_use_id on "-", which produced empty strings for IDs with exactly
2 segments (e.g. toolu_tool-<uuid>). Replace the string-splitting
heuristic with a lookup map built from tool_use blocks during the
main processing loop, with fallback to the raw ID on miss.
2026-03-16 11:18:29 +08:00
Luis Pater
dc7187ca5b fix(websocket): pin only websocket-capable auth IDs and add corresponding test 2026-03-16 09:57:38 +08:00
Luis Pater
b1dcff778c Merge pull request #2141 from Muran-prog/fix/tool-calling-translation-2132
fix: skip empty assistant message in tool call translation (#2132)
2026-03-16 01:42:27 +08:00
Luis Pater
cef2aeeb08 Merge pull request #448 from router-for-me/plus
v6.8.54
2026-03-16 00:37:06 +08:00
Luis Pater
bcd1e8cc34 Merge branch 'main' into plus 2026-03-16 00:36:19 +08:00
Luis Pater
198b3f4a40 chore(ci): update build metadata to use GITHUB_REF_NAME in workflows 2026-03-16 00:30:44 +08:00
Luis Pater
9fee7f488e chore(ci): update GoReleaser config and release workflow to skip validation step 2026-03-16 00:16:25 +08:00
Luis Pater
1b46d39b8b Merge branch 'router-for-me:main' into main 2026-03-15 23:57:47 +08:00
RGBadmin
c1241a98e2 fix(api): restrict fallback note to string-typed JSON values
Only emit note in listAuthFilesFromDisk when the JSON value is actually
a string (gjson.String), matching the synthesizer/buildAuthFileEntry
behavior. Non-string values like numbers or booleans are now ignored
instead of being coerced via gjson.String().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 23:00:17 +08:00
RGBadmin
8d8f5970ee fix(api): fallback to Metadata for priority/note on uploaded auths
buildAuthFileEntry now falls back to reading priority/note from
auth.Metadata when Attributes lacks them. This covers auths registered
via UploadAuthFile which bypass the synthesizer and only populate
Metadata from the raw JSON.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 17:36:11 +08:00
RGBadmin
f90120f846 fix(api): propagate note to Gemini virtual auths and align priority parsing
- Read note from Attributes (consistent with priority) in buildAuthFileEntry,
  fixing missing note on Gemini multi-project virtual auth cards.
- Propagate note from primary to virtual auths in SynthesizeGeminiVirtualAuths,
  mirroring existing priority propagation.
- Sync note/priority writes to both Metadata and Attributes in PatchAuthFileFields,
  with refactored nil-check to reduce duplication (review feedback).
- Validate priority type in fallback disk-read path instead of coercing all values
  to 0 via gjson.Int(), aligning with the auth-manager code path.
- Add regression tests for note synthesis, virtual-auth note propagation, and
  end-to-end multi-project Gemini note inheritance.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 16:47:01 +08:00
Muran-prog
0b94d36c4a test: use exact match for tool name assertion
Address review feedback - drop function.name fallback and
strings.Contains in favor of direct == comparison.
2026-03-14 21:45:28 +02:00
Muran-prog
152c310bb7 test: add uniqueItems stripping test
Covers the fix from the previous commit — verifies uniqueItems is
removed from the schema and moved to the description hint.
2026-03-14 21:22:14 +02:00
Muran-prog
f6bbca35ab fix: strip uniqueItems from Gemini function_declarations (#2123)
Gemini API rejects uniqueItems in tool schemas with 400. Add it to
unsupportedConstraints alongside minItems/maxItems where it belongs.

Same class of fix as #1424 and #1531.
2026-03-14 21:18:06 +02:00
Muran-prog
c8cee6a209 fix: skip empty assistant message in tool call translation (#2132)
When assistant has tool_calls but no text content, the translator
emitted an empty message into the Responses API input array before
function_call items. The API then couldn't match function_call_output
to its function_call by call_id, returning:

  No tool output found for function call ...

Only emit assistant messages that have content parts. Tool-call-only
messages now produce function_call items directly.

Added 9 tests for tool calling translation covering single/parallel
calls, multi-turn conversations, name shortening, empty content
edge cases, and call_id integrity.
2026-03-14 21:01:01 +02:00
Luis Pater
b5701f416b Fixed: #2102
fix(auth): ensure unique auth index for shared API keys across providers and credential identities
2026-03-15 02:48:54 +08:00
Luis Pater
4b1a404fcb Fixed: #1936
feat(translator): add image type handling in ConvertClaudeRequestToGemini
2026-03-15 02:18:28 +08:00
Luis Pater
b93cce5412 Merge pull request #444 from router-for-me/plus
v6.8.53
2026-03-15 01:50:42 +08:00
Luis Pater
c6cb24039d Merge branch 'main' into plus 2026-03-15 01:50:32 +08:00
Luis Pater
5382408489 Merge pull request #441 from GrothKeiran/fix/copilot-token-metadata
fix: persist copilot token metadata
2026-03-15 01:47:41 +08:00
Luis Pater
67669196ed Merge pull request #2131 from HEUDavid/docs/add-who-is-with-us
docs: Add Shadow AI to 'Who is with us?' section
2026-03-15 01:44:46 +08:00
hkfires
58fd9bf964 fix(codex): add 'go' plan_type in registerModelsForAuth 2026-03-14 22:09:14 +08:00
HEUDavid
7b3dfc67bc docs: Add Shadow AI to 'Who is with us?' section 2026-03-14 21:01:07 +08:00
HEUDavid
cdd24052d3 docs: Add Shadow AI to 'Who is with us?' section 2026-03-14 20:53:43 +08:00
Luis Pater
733fd8edab Merge pull request #2111 from qzydustin/main
Fix missing streaming usage tracking for OpenAI-compatible providers
2026-03-14 18:17:08 +08:00
Luis Pater
af27f2b8bc Merge pull request #2110 from router-for-me/codex
feat(service): extend model registration for team and business types
2026-03-14 18:10:01 +08:00
Luis Pater
2e1925d762 Merge pull request #2108 from sususu98/fix/gemini-cli-tool-schema-and-empty-parts
fix(gemini-cli): sanitize tool schemas and filter empty parts
2026-03-14 18:02:52 +08:00
Luis Pater
77254bd074 Merge pull request #2116 from router-for-me/vertex
fix(config): allow vertex keys without base-url
2026-03-14 17:27:48 +08:00
RGBadmin
5b6342e6ac feat(api): expose priority and note fields in GET /auth-files list response
The list endpoint previously omitted priority and note, which are stored
inside each auth file's JSON content. This adds them to both the normal
(auth-manager) and fallback (disk-read) code paths, and extends
PATCH /auth-files/fields to support writing the note field.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-14 14:47:31 +08:00
GrothKeiran
3960c93d51 refactor: derive copilot metadata from storage 2026-03-13 13:10:01 +00:00
GrothKeiran
339a81b650 fix: persist copilot token metadata 2026-03-13 12:54:43 +00:00
hkfires
560c020477 fix(config): allow vertex keys without base-url 2026-03-13 19:09:26 +08:00
Zhenyu Qi
aec65e3be3 fix(openai_compat): add stream_options.include_usage for streaming usage tracking 2026-03-13 00:48:17 -07:00
hkfires
f44f0702f8 feat(service): extend model registration for team and business types 2026-03-13 14:12:19 +08:00
sususu98
b76b79068f fix(gemini-cli): sanitize tool schemas and filter empty parts
1. Claude translator: add CleanJSONSchemaForGemini() to sanitize tool
   input schemas (removes $schema, anyOf, const, format, etc.) and
   delete eager_input_streaming from tool declarations. Remove fragile
   bytes.Replace for format:"uri" now covered by schema cleaner.

2. Gemini native translator: filter out content entries with empty or
   missing parts arrays to prevent Gemini API 400 error "required
   oneof field 'data' must have one initialized field".

Both fixes align gemini-cli with protections already present in the
antigravity translator.
2026-03-13 12:37:37 +08:00
Luis Pater
34c8ccb961 Fixed: #437
feat(runtime): strip `service_tier` in GitHub Copilot response normalization
2026-03-13 11:50:21 +08:00
Luis Pater
d08e164af3 chore(runtime): remove unused FetchAntigravityModels function from executor 2026-03-13 11:38:44 +08:00
Luis Pater
8178efaeda Merge pull request #439 from router-for-me/plus
v6.8.52
2026-03-13 11:30:25 +08:00
Luis Pater
86d5db472a Merge branch 'main' into plus 2026-03-13 11:28:52 +08:00
Luis Pater
020d36f6e8 Merge pull request #433 from LuxVTZ/feat/gitlab-duo-auth-plus
Add GitLab Duo management OAuth and PAT endpoints
2026-03-13 11:25:19 +08:00
Luis Pater
1db23979e8 Merge pull request #2106 from router-for-me/model
feat(model_registry): enhance model registration and refresh mechanisms
2026-03-13 11:18:51 +08:00
hkfires
c3d5dbe96f feat(model_registry): enhance model registration and refresh mechanisms 2026-03-13 10:56:39 +08:00
Luis Pater
5484489406 chore(ci): update model catalog fetch method in workflows 2026-03-12 11:19:24 +08:00
Luis Pater
0ac52da460 chore(ci): update model catalog fetch method in release workflow 2026-03-12 10:50:46 +08:00
Luis Pater
817cebb321 Merge pull request #2082 from router-for-me/antigravity
Refactor Antigravity model handling and improve logging
2026-03-12 10:39:13 +08:00
Luis Pater
683f3709d6 Merge pull request #2076 from aikins01/fix/backfill-empty-function-response-names
fix: backfill empty functionResponse.name from preceding functionCall
2026-03-12 10:35:44 +08:00
hkfires
dbd42a42b2 fix(model_updater): clarify log message for model refresh failure 2026-03-12 10:32:04 +08:00
hkfires
ec24baf757 feat(fetch_antigravity_models): add command to fetch and save Antigravity model list 2026-03-12 10:21:09 +08:00
hkfires
dea3e74d35 feat(antigravity): refactor model handling and remove unused code 2026-03-12 09:24:45 +08:00
Aikins Laryea
a6c3042e34 refactor: remove redundant bounds checks per code review 2026-03-12 00:12:43 +00:00
Aikins Laryea
861537c9bd fix: backfill empty functionResponse.name from preceding functionCall
when Amp or Claude Code sends functionResponse with an empty name in Gemini
conversation history, the Gemini API rejects the request with 400
"Name cannot be empty". this fix backfills empty names from the
corresponding preceding functionCall parts using positional matching.

covers all three Gemini translator paths:
- gemini/gemini (direct API key)
- antigravity/gemini (OAuth)
- gemini-cli/gemini (Gemini CLI)

also switches fixCLIToolResponse pending group matching from LIFO to
FIFO to correctly handle multiple sequential tool call groups.

fixes #1903
2026-03-12 00:00:38 +00:00
Luis Pater
8c92cb0883 Merge pull request #2056 from lang-911/codex/custom-useragent-request
feat(config/codex): Add Codex header defaults (`user-agent`: override; `beta-features`: default)
2026-03-11 22:56:36 +08:00
Luis Pater
89d7be9525 Merge branch 'dev' into codex/custom-useragent-request 2026-03-11 22:55:50 +08:00
lang-911
2b79d7f22f fix: restore double quotes style in config.example.yaml for consistency and readability 2026-03-11 06:59:26 -07:00
LuxVTZ
2bb686f594 Add GitLab Duo management OAuth and PAT endpoints 2026-03-11 17:58:34 +04:00
lang-911
163fe287ce fix: codex header defaults example 2026-03-11 06:55:03 -07:00
lang-911
70988d387b Add Codex websocket header defaults 2026-03-11 00:34:57 -07:00
Luis Pater
52058a1659 docs: remove GitLab Duo sections from README and README_CN 2026-03-11 11:51:17 +08:00
Luis Pater
df5595a0c9 Merge pull request #428 from LuxVTZ/feat/gitlab-duo-auth-plus
Add GitLab Duo provider support
2026-03-11 11:50:02 +08:00
Luis Pater
ddaa9d2436 Fixed: #2034
feat(proxy): centralize proxy handling with `proxyutil` package and enhance test coverage

- Added `proxyutil` package to simplify proxy handling across the codebase.
- Refactored various components (`executor`, `cliproxy`, `auth`, etc.) to use `proxyutil` for consistent and reusable proxy logic.
- Introduced support for "direct" proxy mode to explicitly bypass all proxies.
- Updated tests to validate proxy behavior (e.g., `direct`, HTTP/HTTPS, and SOCKS5).
- Enhanced YAML configuration documentation for proxy options.
2026-03-11 11:08:02 +08:00
Luis Pater
7b7b258c38 Fixed: #2022
test(translator): add tests for handling Claude system messages as string and array
2026-03-11 10:47:33 +08:00
LuxVTZ
a00f774f5a Add GitLab Duo usage docs 2026-03-10 22:20:40 +04:00
LuxVTZ
9daf1ba8b5 test(gitlab): add duo openai handler smoke 2026-03-10 22:19:36 +04:00
LuxVTZ
76f2359637 test(gitlab): add duo claude handler smoke 2026-03-10 22:19:36 +04:00
LuxVTZ
dcb1c9be8a feat(gitlab): route duo openai via gateway 2026-03-10 22:19:36 +04:00
LuxVTZ
a24f4ace78 feat(gitlab): route duo anthropic via gateway 2026-03-10 22:19:36 +04:00
LuxVTZ
c631df8c3b feat(gitlab): add duo streaming transport 2026-03-10 22:19:36 +04:00
LuxVTZ
54c3eb1b1e Add GitLab Duo auth and executor support 2026-03-10 22:19:36 +04:00
LuxVTZ
bb28cd26ad Add GitLab Duo OAuth and PAT support 2026-03-10 22:18:54 +04:00
Luis Pater
046865461e Merge PR #424 from router-for-me/main 2026-03-10 19:19:29 +08:00
Luis Pater
cf74ed2f0c Merge pull request #2013 from router-for-me/model
Fetch model catalog from network
2026-03-10 19:07:23 +08:00
hkfires
e333fbea3d feat(updater): update StartModelsUpdater to block until models refresh completes 2026-03-10 14:41:58 +08:00
hkfires
efbe36d1d4 feat(updater): change models refresh to one-time fetch on startup 2026-03-10 14:18:54 +08:00
hkfires
8553cfa40e feat(workflows): refresh models catalog in workflows 2026-03-10 14:03:31 +08:00
hkfires
30d5c95b26 feat(registry): refresh model catalog from network 2026-03-10 14:02:54 +08:00
hkfires
d1e3195e6f feat(codex): register models by plan tier 2026-03-10 11:20:37 +08:00
Luis Pater
05a35662ae Merge branch 'router-for-me:main' into main 2026-03-09 23:05:51 +08:00
Luis Pater
ce53d3a287 Fixed: #1997
test(auth-scheduler): add benchmarks and priority-based scheduling improvements

- Added `BenchmarkManagerPickNextMixedPriority500` for mixed-priority performance assessment.
- Updated `pickNextMixed` to prioritize highest ready priority tiers.
- Introduced `highestReadyPriorityLocked` and `pickReadyAtPriorityLocked` for better scheduling logic.
- Added unit test to validate selection of highest priority tiers in mixed provider scenarios.
2026-03-09 22:27:15 +08:00
Luis Pater
4cc99e7449 Merge pull request #1992 from dcrdev/main
System prompt silently dropped when sent as a string
2026-03-09 21:03:15 +08:00
Luis Pater
71773fe032 Merge pull request #1996 from router-for-me/codex/fix-unbounded-websocket-log-buffering
fix: cap websocket body log growth in responses handler
2026-03-09 20:50:38 +08:00
Dominic Robinson
a1e0fa0f39 test(executor): cover string system prompt handling in checkSystemInstructionsWithMode 2026-03-09 12:40:27 +00:00
Supra4E8C
fc2f0b6983 fix: cap websocket body log growth 2026-03-09 17:48:30 +08:00
Dominic Robinson
5c9997cdac fix: Preserve system prompt when sent as a string instead of content block array 2026-03-09 07:38:11 +00:00
155 changed files with 14695 additions and 5307 deletions

View File

@@ -16,6 +16,10 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Refresh models catalog
run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
@@ -25,7 +29,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Generate Build Metadata
run: |
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- name: Build and push (amd64)
@@ -47,6 +51,10 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Refresh models catalog
run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
@@ -56,7 +64,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Generate Build Metadata
run: |
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- name: Build and push (arm64)
@@ -90,7 +98,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Generate Build Metadata
run: |
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- name: Create and push multi-arch manifests

View File

@@ -12,6 +12,10 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Refresh models catalog
run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- name: Set up Go
uses: actions/setup-go@v5
with:

View File

@@ -16,6 +16,10 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Refresh models catalog
run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- run: git fetch --force --tags
- uses: actions/setup-go@v4
with:
@@ -23,15 +27,14 @@ jobs:
cache: true
- name: Generate Build Metadata
run: |
VERSION=$(git describe --tags --always --dirty)
echo "VERSION=${VERSION}" >> $GITHUB_ENV
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- uses: goreleaser/goreleaser-action@v4
with:
distribution: goreleaser
version: latest
args: release --clean
args: release --clean --skip=validate
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
VERSION: ${{ env.VERSION }}

View File

@@ -1,3 +1,5 @@
version: 2
builds:
- id: "cli-proxy-api-plus"
env:

View File

@@ -6,8 +6,6 @@
所有的第三方供应商支持都由第三方社区维护者提供CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。
该 Plus 版本的主线功能与主线功能强制同步。
## 贡献
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
@@ -16,4 +14,4 @@
## 许可证
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。

View File

@@ -0,0 +1,275 @@
// Command fetch_antigravity_models connects to the Antigravity API using the
// stored auth credentials and saves the dynamically fetched model list to a
// JSON file for inspection or offline use.
//
// Usage:
//
// go run ./cmd/fetch_antigravity_models [flags]
//
// Flags:
//
// --auths-dir <path> Directory containing auth JSON files (default: "auths")
// --output <path> Output JSON file path (default: "antigravity_models.json")
// --pretty Pretty-print the output JSON (default: true)
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
const (
antigravityBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com"
antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
antigravityModelsPath = "/v1internal:fetchAvailableModels"
)
func init() {
logging.SetupBaseLogger()
log.SetLevel(log.InfoLevel)
}
// modelOutput wraps the fetched model list with fetch metadata.
type modelOutput struct {
Models []modelEntry `json:"models"`
}
// modelEntry contains only the fields we want to keep for static model definitions.
type modelEntry struct {
ID string `json:"id"`
Object string `json:"object"`
OwnedBy string `json:"owned_by"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
Name string `json:"name"`
Description string `json:"description"`
ContextLength int `json:"context_length,omitempty"`
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
}
func main() {
var authsDir string
var outputPath string
var pretty bool
flag.StringVar(&authsDir, "auths-dir", "auths", "Directory containing auth JSON files")
flag.StringVar(&outputPath, "output", "antigravity_models.json", "Output JSON file path")
flag.BoolVar(&pretty, "pretty", true, "Pretty-print the output JSON")
flag.Parse()
// Resolve relative paths against the working directory.
wd, err := os.Getwd()
if err != nil {
fmt.Fprintf(os.Stderr, "error: cannot get working directory: %v\n", err)
os.Exit(1)
}
if !filepath.IsAbs(authsDir) {
authsDir = filepath.Join(wd, authsDir)
}
if !filepath.IsAbs(outputPath) {
outputPath = filepath.Join(wd, outputPath)
}
fmt.Printf("Scanning auth files in: %s\n", authsDir)
// Load all auth records from the directory.
fileStore := sdkauth.NewFileTokenStore()
fileStore.SetBaseDir(authsDir)
ctx := context.Background()
auths, err := fileStore.List(ctx)
if err != nil {
fmt.Fprintf(os.Stderr, "error: failed to list auth files: %v\n", err)
os.Exit(1)
}
if len(auths) == 0 {
fmt.Fprintf(os.Stderr, "error: no auth files found in %s\n", authsDir)
os.Exit(1)
}
// Find the first enabled antigravity auth.
var chosen *coreauth.Auth
for _, a := range auths {
if a == nil || a.Disabled {
continue
}
if strings.EqualFold(strings.TrimSpace(a.Provider), "antigravity") {
chosen = a
break
}
}
if chosen == nil {
fmt.Fprintf(os.Stderr, "error: no enabled antigravity auth found in %s\n", authsDir)
os.Exit(1)
}
fmt.Printf("Using auth: id=%s label=%s\n", chosen.ID, chosen.Label)
// Fetch models from the upstream Antigravity API.
fmt.Println("Fetching Antigravity model list from upstream...")
fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
models := fetchModels(fetchCtx, chosen)
if len(models) == 0 {
fmt.Fprintln(os.Stderr, "warning: no models returned (API may be unavailable or token expired)")
} else {
fmt.Printf("Fetched %d models.\n", len(models))
}
// Build the output payload.
out := modelOutput{
Models: models,
}
// Marshal to JSON.
var raw []byte
if pretty {
raw, err = json.MarshalIndent(out, "", " ")
} else {
raw, err = json.Marshal(out)
}
if err != nil {
fmt.Fprintf(os.Stderr, "error: failed to marshal JSON: %v\n", err)
os.Exit(1)
}
if err = os.WriteFile(outputPath, raw, 0o644); err != nil {
fmt.Fprintf(os.Stderr, "error: failed to write output file %s: %v\n", outputPath, err)
os.Exit(1)
}
fmt.Printf("Model list saved to: %s\n", outputPath)
}
func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry {
accessToken := metaStringValue(auth.Metadata, "access_token")
if accessToken == "" {
fmt.Fprintln(os.Stderr, "error: no access token found in auth")
return nil
}
baseURLs := []string{antigravityBaseURLProd, antigravityBaseURLDaily, antigravitySandboxBaseURLDaily}
for _, baseURL := range baseURLs {
modelsURL := baseURL + antigravityModelsPath
var payload []byte
if auth != nil && auth.Metadata != nil {
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
}
}
if len(payload) == 0 {
payload = []byte(`{}`)
}
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, strings.NewReader(string(payload)))
if errReq != nil {
continue
}
httpReq.Close = true
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
httpReq.Header.Set("User-Agent", "antigravity/1.19.6 darwin/arm64")
httpClient := &http.Client{Timeout: 30 * time.Second}
if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil {
httpClient.Transport = transport
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
continue
}
bodyBytes, errRead := io.ReadAll(httpResp.Body)
httpResp.Body.Close()
if errRead != nil {
continue
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
continue
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
continue
}
var models []modelEntry
for originalName, modelData := range result.Map() {
modelID := strings.TrimSpace(originalName)
if modelID == "" {
continue
}
// Skip internal/experimental models
switch modelID {
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
continue
}
displayName := modelData.Get("displayName").String()
if displayName == "" {
displayName = modelID
}
entry := modelEntry{
ID: modelID,
Object: "model",
OwnedBy: "antigravity",
Type: "antigravity",
DisplayName: displayName,
Name: modelID,
Description: displayName,
}
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
entry.ContextLength = int(maxTok)
}
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
entry.MaxCompletionTokens = int(maxOut)
}
models = append(models, entry)
}
return models
}
return nil
}
func metaStringValue(m map[string]interface{}, key string) string {
if m == nil {
return ""
}
v, ok := m[key]
if !ok {
return ""
}
switch val := v.(type) {
case string:
return val
default:
return ""
}
}

View File

@@ -25,6 +25,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
@@ -78,6 +79,8 @@ func main() {
var kiloLogin bool
var iflowLogin bool
var iflowCookie bool
var gitlabLogin bool
var gitlabTokenLogin bool
var noBrowser bool
var oauthCallbackPort int
var antigravityLogin bool
@@ -100,6 +103,7 @@ func main() {
var standalone bool
var noIncognito bool
var useIncognito bool
var localModel bool
// Define command-line flags for different operation modes.
flag.BoolVar(&login, "login", false, "Login Google Account")
@@ -110,6 +114,8 @@ func main() {
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
flag.BoolVar(&gitlabLogin, "gitlab-login", false, "Login to GitLab Duo using OAuth")
flag.BoolVar(&gitlabTokenLogin, "gitlab-token-login", false, "Login to GitLab Duo using a personal access token")
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
@@ -132,6 +138,7 @@ func main() {
flag.StringVar(&password, "password", "", "")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
flag.CommandLine.Usage = func() {
out := flag.CommandLine.Output()
@@ -526,6 +533,10 @@ func main() {
cmd.DoIFlowLogin(cfg, options)
} else if iflowCookie {
cmd.DoIFlowCookieAuth(cfg, options)
} else if gitlabLogin {
cmd.DoGitLabLogin(cfg, options)
} else if gitlabTokenLogin {
cmd.DoGitLabTokenLogin(cfg, options)
} else if kimiLogin {
cmd.DoKimiLogin(cfg, options)
} else if kiroLogin {
@@ -569,10 +580,16 @@ func main() {
cmd.WaitForCloudDeploy()
return
}
if localModel && (!tuiMode || standalone) {
log.Info("Local model mode: using embedded model catalog, remote model updates disabled")
}
if tuiMode {
if standalone {
// Standalone mode: start an embedded local server and connect TUI client to it.
managementasset.StartAutoUpdater(context.Background(), configFilePath)
if !localModel {
registry.StartModelsUpdater(context.Background())
}
hook := tui.NewLogHook(2000)
hook.SetFormatter(&logging.LogFormatter{})
log.AddHook(hook)
@@ -643,15 +660,18 @@ func main() {
}
}
} else {
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
if !localModel {
registry.StartModelsUpdater(context.Background())
}
if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg)
defer kiro.StopGlobalRefreshManager()
}
if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg)
defer kiro.StopGlobalRefreshManager()
}
cmd.StartService(cfg, configFilePath, password)
cmd.StartService(cfg, configFilePath, password)
}
}
}

View File

@@ -68,7 +68,8 @@ error-logs-max-files: 10
usage-statistics-enabled: false
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
proxy-url: ''
# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly.
proxy-url: ""
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
force-model-prefix: false
@@ -115,6 +116,7 @@ nonstream-keepalive-interval: 0
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080"
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models:
# - name: "gemini-2.5-flash" # upstream model name
# alias: "gemini-flash" # client alias mapped to the upstream model
@@ -133,6 +135,7 @@ nonstream-keepalive-interval: 0
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models:
# - name: "gpt-5-codex" # upstream model name
# alias: "codex-latest" # client alias mapped to the upstream model
@@ -151,6 +154,7 @@ nonstream-keepalive-interval: 0
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models:
# - name: "claude-3-5-sonnet-20241022" # upstream model name
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
@@ -178,6 +182,14 @@ nonstream-keepalive-interval: 0
# runtime-version: "v24.3.0"
# timeout: "600"
# Default headers for Codex OAuth model requests.
# These are used only for file-backed/OAuth Codex requests when the client
# does not send the header. `user-agent` applies to HTTP and websocket requests;
# `beta-features` only applies to websocket requests. They do not apply to codex-api-key entries.
# codex-header-defaults:
# user-agent: "codex_cli_rs/0.114.0 (Mac OS 14.2.0; x86_64) vscode/1.111.0"
# beta-features: "multi_agent"
# Kiro (AWS CodeWhisperer) configuration
# Note: Kiro API currently only operates in us-east-1 region
#kiro:
@@ -215,6 +227,7 @@ nonstream-keepalive-interval: 0
# api-key-entries:
# - api-key: "sk-or-v1-...b780"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# - api-key: "sk-or-v1-...b781" # without proxy-url
# models: # The models supported by the provider.
# - name: "moonshotai/kimi-k2:free" # The actual model name.
@@ -231,12 +244,13 @@ nonstream-keepalive-interval: 0
# - name: "kimi-k2.5"
# alias: "claude-opus-4.66"
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
# Vertex API keys (Vertex-compatible endpoints, base-url is optional)
# vertex-api-key:
# - api-key: "vk-123..." # x-goog-api-key header
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
# base-url: "https://example.com/api" # optional, e.g. https://zenmux.ai/api; falls back to Google Vertex when omitted
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# headers:
# X-Custom-Header: "custom-value"
# models: # optional: map aliases to upstream model names

115
docs/gitlab-duo.md Normal file
View File

@@ -0,0 +1,115 @@
# GitLab Duo guide
CLIProxyAPI can now use GitLab Duo as a first-class provider instead of treating it as a plain text wrapper.
It supports:
- OAuth login
- personal access token login
- automatic refresh of GitLab `direct_access` metadata
- dynamic model discovery from GitLab metadata
- native GitLab AI gateway routing for Anthropic and OpenAI/Codex managed models
- Claude-compatible and OpenAI-compatible downstream APIs
## What this means
If GitLab Duo returns an Anthropic-managed model, CLIProxyAPI routes requests through the GitLab AI gateway Anthropic proxy and uses the existing Claude executor path.
If GitLab Duo returns an OpenAI-managed model, CLIProxyAPI routes requests through the GitLab AI gateway OpenAI proxy and uses the existing Codex/OpenAI executor path.
That gives GitLab Duo much closer runtime behavior to the built-in `codex` provider:
- Claude-compatible clients can use GitLab Duo models through `/v1/messages`
- OpenAI-compatible clients can use GitLab Duo models through `/v1/chat/completions`
- OpenAI Responses clients can use GitLab Duo models through `/v1/responses`
The model list is not hardcoded. CLIProxyAPI reads the current model metadata from GitLab `direct_access` and registers:
- a stable alias: `gitlab-duo`
- any discovered managed model names, such as `claude-sonnet-4-5` or `gpt-5-codex`
## Login
OAuth login:
```bash
./CLIProxyAPI -gitlab-login
```
PAT login:
```bash
./CLIProxyAPI -gitlab-token-login
```
You can also provide inputs through environment variables:
```bash
export GITLAB_BASE_URL=https://gitlab.com
export GITLAB_OAUTH_CLIENT_ID=your-client-id
export GITLAB_OAUTH_CLIENT_SECRET=your-client-secret
export GITLAB_PERSONAL_ACCESS_TOKEN=glpat-...
```
Notes:
- OAuth requires a GitLab OAuth application.
- PAT login requires a personal access token that can call the GitLab APIs used by Duo. In practice, `api` scope is the safe baseline.
- Self-managed GitLab instances are supported through `GITLAB_BASE_URL`.
## Using the models
After login, start CLIProxyAPI normally and point your client at the local proxy.
You can select:
- `gitlab-duo` to use the current Duo-managed model for that account
- the discovered provider model name if you want to pin it explicitly
Examples:
```bash
curl http://127.0.0.1:8080/v1/models
```
```bash
curl http://127.0.0.1:8080/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model": "gitlab-duo",
"messages": [
{"role": "user", "content": "Write a Go HTTP middleware for request IDs."}
]
}'
```
If the GitLab account is currently mapped to an Anthropic model, Claude-compatible clients can use the same account through the Claude handler path. If the account is currently mapped to an OpenAI/Codex model, OpenAI-compatible clients can use `/v1/chat/completions` or `/v1/responses`.
## How model freshness works
CLIProxyAPI does not ship a fixed GitLab Duo model catalog.
Instead, it refreshes GitLab `direct_access` metadata and uses the returned `model_details` and any discovered model list entries to keep the local registry aligned with the current GitLab-managed model assignment.
This matches GitLab's current public contract better than hardcoding model names.
## Current scope
The GitLab Duo provider now has:
- OAuth and PAT auth flows
- runtime refresh of Duo gateway credentials
- native Anthropic gateway routing
- native OpenAI/Codex gateway routing
- handler-level smoke tests for Claude-compatible and OpenAI-compatible paths
Still out of scope today:
- websocket or session-specific parity beyond the current HTTP APIs
- GitLab-specific IDE features that are not exposed through the public gateway contract
## References
- GitLab Code Suggestions API: https://docs.gitlab.com/api/code_suggestions/
- GitLab Agent Assistant and managed credentials: https://docs.gitlab.com/user/duo_agent_platform/agent_assistant/
- GitLab Duo model selection: https://docs.gitlab.com/user/gitlab_duo/model_selection/

115
docs/gitlab-duo_CN.md Normal file
View File

@@ -0,0 +1,115 @@
# GitLab Duo 使用说明
CLIProxyAPI 现在可以把 GitLab Duo 当作一等 Provider 来使用,而不是仅仅把它当成简单的文本补全封装。
当前支持:
- OAuth 登录
- personal access token 登录
- 自动刷新 GitLab `direct_access` 元数据
- 根据 GitLab 返回的元数据动态发现模型
- 针对 Anthropic 和 OpenAI/Codex 托管模型的 GitLab AI gateway 原生路由
- Claude 兼容与 OpenAI 兼容下游 API
## 这意味着什么
如果 GitLab Duo 返回的是 Anthropic 托管模型CLIProxyAPI 会通过 GitLab AI gateway 的 Anthropic 代理转发,并复用现有的 Claude executor 路径。
如果 GitLab Duo 返回的是 OpenAI 托管模型CLIProxyAPI 会通过 GitLab AI gateway 的 OpenAI 代理转发,并复用现有的 Codex/OpenAI executor 路径。
这让 GitLab Duo 的运行时行为更接近内置的 `codex` Provider
- Claude 兼容客户端可以通过 `/v1/messages` 使用 GitLab Duo 模型
- OpenAI 兼容客户端可以通过 `/v1/chat/completions` 使用 GitLab Duo 模型
- OpenAI Responses 客户端可以通过 `/v1/responses` 使用 GitLab Duo 模型
模型列表不是硬编码的。CLIProxyAPI 会从 GitLab `direct_access` 中读取当前模型元数据,并注册:
- 一个稳定别名:`gitlab-duo`
- GitLab 当前发现到的托管模型名,例如 `claude-sonnet-4-5``gpt-5-codex`
## 登录
OAuth 登录:
```bash
./CLIProxyAPI -gitlab-login
```
PAT 登录:
```bash
./CLIProxyAPI -gitlab-token-login
```
也可以通过环境变量提供输入:
```bash
export GITLAB_BASE_URL=https://gitlab.com
export GITLAB_OAUTH_CLIENT_ID=your-client-id
export GITLAB_OAUTH_CLIENT_SECRET=your-client-secret
export GITLAB_PERSONAL_ACCESS_TOKEN=glpat-...
```
说明:
- OAuth 方式需要一个 GitLab OAuth application。
- PAT 登录需要一个能够调用 GitLab Duo 相关 API 的 personal access token。实践上`api` scope 是最稳妥的基线。
- 自建 GitLab 实例可以通过 `GITLAB_BASE_URL` 接入。
## 如何使用模型
登录完成后,正常启动 CLIProxyAPI并让客户端连接到本地代理。
你可以选择:
- `gitlab-duo`,始终使用该账号当前的 Duo 托管模型
- GitLab 当前发现到的 provider 模型名,如果你想显式固定模型
示例:
```bash
curl http://127.0.0.1:8080/v1/models
```
```bash
curl http://127.0.0.1:8080/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model": "gitlab-duo",
"messages": [
{"role": "user", "content": "Write a Go HTTP middleware for request IDs."}
]
}'
```
如果该 GitLab 账号当前绑定的是 Anthropic 模型Claude 兼容客户端可以通过 Claude handler 路径直接使用它。如果当前绑定的是 OpenAI/Codex 模型OpenAI 兼容客户端可以通过 `/v1/chat/completions``/v1/responses` 使用它。
## 模型如何保持最新
CLIProxyAPI 不内置固定的 GitLab Duo 模型清单。
它会刷新 GitLab `direct_access` 元数据,并使用返回的 `model_details` 以及可能存在的模型列表字段,让本地 registry 尽量与 GitLab 当前分配的托管模型保持一致。
这比硬编码模型名更符合 GitLab 当前公开 API 的实际契约。
## 当前覆盖范围
GitLab Duo Provider 目前已经具备:
- OAuth 和 PAT 登录流程
- Duo gateway 凭据的运行时刷新
- Anthropic gateway 原生路由
- OpenAI/Codex gateway 原生路由
- Claude 兼容和 OpenAI 兼容路径的 handler 级 smoke 测试
当前仍未覆盖:
- websocket 或 session 级别的完全对齐
- GitLab 公开 gateway 契约之外的 IDE 专有能力
## 参考资料
- GitLab Code Suggestions API: https://docs.gitlab.com/api/code_suggestions/
- GitLab Agent Assistant 与 managed credentials: https://docs.gitlab.com/user/duo_agent_platform/agent_assistant/
- GitLab Duo 模型选择: https://docs.gitlab.com/user/gitlab_duo/model_selection/

View File

@@ -52,11 +52,11 @@ func init() {
sdktr.Register(fOpenAI, fMyProv,
func(model string, raw []byte, stream bool) []byte { return raw },
sdktr.ResponseTransform{
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []string {
return []string{string(raw)}
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) [][]byte {
return [][]byte{raw}
},
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) string {
return string(raw)
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []byte {
return raw
},
},
)

View File

@@ -0,0 +1,278 @@
# Plan: GitLab Duo Codex Parity
**Generated**: 2026-03-10
**Estimated Complexity**: High
## Overview
Bring GitLab Duo support from the current "auth + basic executor" stage to the same practical level as `codex` inside `CLIProxyAPI`: a user logs in once, points external clients such as Claude Code at `CLIProxyAPI`, selects GitLab Duo-backed models, and gets stable streaming, multi-turn behavior, tool calling compatibility, and predictable model routing without manual provider-specific workarounds.
The core architectural shift is to stop treating GitLab Duo as only two REST wrappers (`/api/v4/chat/completions` and `/api/v4/code_suggestions/completions`) and instead use GitLab's `direct_access` contract as the primary runtime entrypoint wherever possible. Official GitLab docs confirm that `direct_access` returns AI gateway connection details, headers, token, and expiry; that contract is the closest path to codex-like provider behavior.
## Prerequisites
- Official GitLab Duo API references confirmed during implementation:
- `POST /api/v4/code_suggestions/direct_access`
- `POST /api/v4/code_suggestions/completions`
- `POST /api/v4/chat/completions`
- Access to at least one real GitLab Duo account for manual verification.
- One downstream client target for acceptance testing:
- Claude Code against Claude-compatible endpoint
- OpenAI-compatible client against `/v1/chat/completions` and `/v1/responses`
- Existing PR branch as starting point:
- `feat/gitlab-duo-auth`
- PR [#2028](https://github.com/router-for-me/CLIProxyAPI/pull/2028)
## Definition Of Done
- GitLab Duo models can be used via `CLIProxyAPI` from the same client surfaces that already work for `codex`.
- Upstream streaming is real passthrough or faithful chunked forwarding, not synthetic whole-response replay.
- Tool/function calling survives translation layers without dropping fields or corrupting names.
- Multi-turn and session semantics are stable across `chat/completions`, `responses`, and Claude-compatible routes.
- Model exposure stays current from GitLab metadata or gateway discovery without hardcoded stale model tables.
- `go test ./...` stays green and at least one real manual end-to-end client flow is documented.
## Sprint 1: Contract And Gap Closure
**Goal**: Replace assumptions with a hard compatibility contract between current `codex` behavior and what GitLab Duo can actually support.
**Demo/Validation**:
- Written matrix showing `codex` features vs current GitLab Duo behavior.
- One checked-in developer note or test fixture for real GitLab Duo payload examples.
### Task 1.1: Freeze Codex Parity Checklist
- **Location**: [internal/runtime/executor/codex_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/codex_executor.go), [internal/runtime/executor/codex_websockets_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/codex_websockets_executor.go), [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go), [sdk/api/handlers/openai/openai_responses_websocket.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_websocket.go)
- **Description**: Produce a concrete feature matrix for `codex`: HTTP execute, SSE execute, `/v1/responses`, websocket downstream path, tool calling, request IDs, session close semantics, and model registration behavior.
- **Dependencies**: None
- **Acceptance Criteria**:
- A checklist exists in repo docs or issue notes.
- Each capability is marked `required`, `optional`, or `not possible` for GitLab Duo.
- **Validation**:
- Review against current `codex` code paths.
### Task 1.2: Lock GitLab Duo Runtime Contract
- **Location**: [internal/auth/gitlab/gitlab.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/auth/gitlab/gitlab.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
- **Description**: Validate the exact upstream contract we can rely on:
- `direct_access` fields and refresh cadence
- whether AI gateway path is usable directly
- when `chat/completions` is available vs when fallback is required
- what streaming shape is returned by `code_suggestions/completions?stream=true`
- **Dependencies**: Task 1.1
- **Acceptance Criteria**:
- GitLab transport decision is explicit: `gateway-first`, `REST-first`, or `hybrid`.
- Unknown areas are isolated behind feature flags, not spread across executor logic.
- **Validation**:
- Official docs + captured real responses from a Duo account.
### Task 1.3: Define Client-Facing Compatibility Targets
- **Location**: [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md), [gitlab-duo-codex-parity-plan.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/gitlab-duo-codex-parity-plan.md)
- **Description**: Define exactly which external flows must work to call GitLab Duo support "like codex".
- **Dependencies**: Task 1.2
- **Acceptance Criteria**:
- Required surfaces are listed:
- Claude-compatible route
- OpenAI `chat/completions`
- OpenAI `responses`
- optional downstream websocket path
- Non-goals are explicit if GitLab upstream cannot support them.
- **Validation**:
- Maintainer review of stated scope.
## Sprint 2: Primary Transport Parity
**Goal**: Move GitLab Duo execution onto a transport that supports codex-like runtime behavior.
**Demo/Validation**:
- A GitLab Duo model works over real streaming through `/v1/chat/completions`.
- No synthetic "collect full body then fake stream" path remains on the primary flow.
### Task 2.1: Refactor GitLab Executor Into Strategy Layers
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
- **Description**: Split current executor into explicit strategies:
- auth refresh/direct access refresh
- gateway transport
- GitLab REST fallback transport
- downstream translation helpers
- **Dependencies**: Sprint 1
- **Acceptance Criteria**:
- Executor no longer mixes discovery, refresh, fallback selection, and response synthesis in one path.
- Transport choice is testable in isolation.
- **Validation**:
- Unit tests for strategy selection and fallback boundaries.
### Task 2.2: Implement Real Streaming Path
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/runtime/executor/gitlab_executor_test.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor_test.go)
- **Description**: Replace synthetic streaming with true upstream incremental forwarding:
- use gateway stream if available
- otherwise consume GitLab Code Suggestions streaming response and map chunks incrementally
- **Dependencies**: Task 2.1
- **Acceptance Criteria**:
- `ExecuteStream` emits chunks before upstream completion.
- error handling preserves status and early failure semantics.
- **Validation**:
- tests with chunked upstream server
- manual curl check against `/v1/chat/completions` with `stream=true`
### Task 2.3: Preserve Upstream Auth And Headers Correctly
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/auth/gitlab/gitlab.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/auth/gitlab/gitlab.go)
- **Description**: Use `direct_access` connection details as first-class transport state:
- gateway token
- expiry
- mandatory forwarded headers
- model metadata
- **Dependencies**: Task 2.1
- **Acceptance Criteria**:
- executor stops ignoring gateway headers/token when transport requires them
- refresh logic never over-fetches `direct_access`
- **Validation**:
- tests verifying propagated headers and refresh interval behavior
## Sprint 3: Request/Response Semantics Parity
**Goal**: Make GitLab Duo behave correctly under the same request shapes that current `codex` consumers send.
**Demo/Validation**:
- OpenAI and Claude-compatible clients can do non-streaming and streaming conversations without losing structure.
### Task 3.1: Normalize Multi-Turn Message Mapping
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/translator](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/translator)
- **Description**: Replace the current "flatten prompt into one instruction" behavior with stable multi-turn mapping:
- preserve system context
- preserve user/assistant ordering
- maintain bounded context truncation
- **Dependencies**: Sprint 2
- **Acceptance Criteria**:
- multi-turn requests are not collapsed into a lossy single string unless fallback mode explicitly requires it
- truncation policy is deterministic and tested
- **Validation**:
- golden tests for request mapping
### Task 3.2: Tool Calling Compatibility Layer
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go)
- **Description**: Decide and implement one of two paths:
- native pass-through if GitLab gateway supports tool/function structures
- strict downgrade path with explicit unsupported errors instead of silent field loss
- **Dependencies**: Task 3.1
- **Acceptance Criteria**:
- tool-related fields are either preserved correctly or rejected explicitly
- no silent corruption of tool names, tool calls, or tool results
- **Validation**:
- table-driven tests for tool payloads
- one manual client scenario using tools
### Task 3.3: Token Counting And Usage Reporting Fidelity
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [internal/runtime/executor/usage_helpers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/usage_helpers.go)
- **Description**: Improve token/usage reporting so GitLab models behave like first-class providers in logs and scheduling.
- **Dependencies**: Sprint 2
- **Acceptance Criteria**:
- `CountTokens` uses the closest supported estimation path
- usage logging distinguishes prompt vs completion when possible
- **Validation**:
- unit tests for token estimation outputs
## Sprint 4: Responses And Session Parity
**Goal**: Reach codex-level support for OpenAI Responses clients and long-lived sessions where GitLab upstream permits it.
**Demo/Validation**:
- `/v1/responses` works with GitLab Duo in a realistic client flow.
- If websocket parity is not possible, the code explicitly declines it and keeps HTTP paths stable.
### Task 4.1: Make GitLab Compatible With `/v1/responses`
- **Location**: [sdk/api/handlers/openai/openai_responses_handlers.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_handlers.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
- **Description**: Ensure GitLab transport can safely back the Responses API path, including compact responses if applicable.
- **Dependencies**: Sprint 3
- **Acceptance Criteria**:
- GitLab Duo can be selected behind `/v1/responses`
- response IDs and follow-up semantics are defined
- **Validation**:
- handler tests analogous to codex/openai responses tests
### Task 4.2: Evaluate Downstream Websocket Parity
- **Location**: [sdk/api/handlers/openai/openai_responses_websocket.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai/openai_responses_websocket.go), [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go)
- **Description**: Decide whether GitLab Duo can support downstream websocket sessions like codex:
- if yes, add session-aware execution path
- if no, mark GitLab auth as websocket-ineligible and keep HTTP routes first-class
- **Dependencies**: Task 4.1
- **Acceptance Criteria**:
- websocket behavior is explicit, not accidental
- no route claims websocket support when the upstream cannot honor it
- **Validation**:
- websocket handler tests or explicit capability tests
### Task 4.3: Add Session Cleanup And Failure Recovery Semantics
- **Location**: [internal/runtime/executor/gitlab_executor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor.go), [sdk/cliproxy/auth/conductor.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/cliproxy/auth/conductor.go)
- **Description**: Add codex-like session cleanup, retry boundaries, and model suspension/resume behavior for GitLab failures and quota events.
- **Dependencies**: Sprint 2
- **Acceptance Criteria**:
- auth/model cooldown behavior is predictable on GitLab 4xx/5xx/quota responses
- executor cleans up per-session resources if any are introduced
- **Validation**:
- tests for quota and retry behavior
## Sprint 5: Client UX, Model UX, And Manual E2E
**Goal**: Make GitLab Duo feel like a normal built-in provider to operators and downstream clients.
**Demo/Validation**:
- A documented setup exists for "login once, point Claude Code at CLIProxyAPI, use GitLab Duo-backed model".
### Task 5.1: Model Alias And Provider UX Cleanup
- **Location**: [sdk/cliproxy/service.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/cliproxy/service.go), [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md)
- **Description**: Normalize what users see:
- stable alias such as `gitlab-duo`
- discovered upstream model names
- optional prefix behavior
- account labels that clearly distinguish OAuth vs PAT
- **Dependencies**: Sprint 3
- **Acceptance Criteria**:
- users can select a stable GitLab alias even when upstream model changes
- dynamic model discovery does not cause confusing model churn
- **Validation**:
- registry tests and manual `/v1/models` inspection
### Task 5.2: Add Real End-To-End Acceptance Tests
- **Location**: [internal/runtime/executor/gitlab_executor_test.go](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/internal/runtime/executor/gitlab_executor_test.go), [sdk/api/handlers/openai](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/sdk/api/handlers/openai)
- **Description**: Add higher-level tests covering the actual proxy surfaces:
- OpenAI `chat/completions`
- OpenAI `responses`
- Claude-compatible request path if GitLab is routed there
- **Dependencies**: Sprint 4
- **Acceptance Criteria**:
- tests fail if streaming regresses into synthetic buffering again
- tests cover at least one tool-related request and one multi-turn request
- **Validation**:
- `go test ./...`
### Task 5.3: Publish Operator Documentation
- **Location**: [README.md](/home/luxvtz/projects/cliproxyapi/CLIProxyAPI/README.md)
- **Description**: Document:
- OAuth setup requirements
- PAT requirements
- current capability matrix
- known limitations if websocket/tool parity is partial
- **Dependencies**: Sprint 5.1
- **Acceptance Criteria**:
- setup instructions are enough for a new user to reproduce the GitLab Duo flow
- limitations are explicit
- **Validation**:
- dry-run docs review from a clean environment
## Testing Strategy
- Keep `go test ./...` green after every committable task.
- Add table-driven tests first for request mapping, refresh behavior, and dynamic model registration.
- Add transport tests with `httptest.Server` for:
- real chunked streaming
- header propagation from `direct_access`
- upstream fallback rules
- Add at least one manual acceptance checklist:
- login via OAuth
- login via PAT
- list models
- run one streaming prompt via OpenAI route
- run one prompt from the target downstream client
## Potential Risks & Gotchas
- GitLab public docs expose `direct_access`, but do not fully document every possible AI gateway path. We should isolate any empirically discovered gateway assumptions behind one transport layer and feature flags.
- `chat/completions` availability differs by GitLab offering and version. The executor must not assume it always exists.
- Code Suggestions is completion-oriented; lossy mapping from rich chat/tool payloads will make GitLab Duo feel worse than codex unless explicitly handled.
- Synthetic streaming is not good enough for codex parity and will cause regressions in interactive clients.
- Dynamic model discovery can create unstable UX if the stable alias and discovered model IDs are not separated cleanly.
- PAT auth may validate successfully while still lacking effective Duo permissions. Error reporting must surface this explicitly.
## Rollback Plan
- Keep the current basic GitLab executor behind a fallback mode until the new transport path is stable.
- If parity work destabilizes existing providers, revert only GitLab-specific executor changes and leave auth support intact.
- Preserve the stable `gitlab-duo` alias so rollback does not break client configuration.

View File

@@ -6,7 +6,6 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
@@ -14,13 +13,12 @@ import (
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const defaultAPICallTimeout = 60 * time.Second
@@ -725,47 +723,12 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
}
func buildProxyTransport(proxyStr string) *http.Transport {
proxyStr = strings.TrimSpace(proxyStr)
if proxyStr == "" {
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
if errBuild != nil {
log.WithError(errBuild).Debug("build proxy transport failed")
return nil
}
proxyURL, errParse := url.Parse(proxyStr)
if errParse != nil {
log.WithError(errParse).Debug("parse proxy URL failed")
return nil
}
if proxyURL.Scheme == "" || proxyURL.Host == "" {
log.Debug("proxy URL missing scheme/host")
return nil
}
if proxyURL.Scheme == "socks5" {
var proxyAuth *proxy.Auth
if proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
return nil
}
return &http.Transport{
Proxy: nil,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
}
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil
return transport
}
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).

View File

@@ -2,172 +2,112 @@ package management
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
type memoryAuthStore struct {
mu sync.Mutex
items map[string]*coreauth.Auth
}
func TestAPICallTransportDirectBypassesGlobalProxy(t *testing.T) {
t.Parallel()
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
_ = ctx
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*coreauth.Auth, 0, len(s.items))
for _, a := range s.items {
out = append(out, a.Clone())
}
return out, nil
}
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
_ = ctx
if auth == nil {
return "", nil
}
s.mu.Lock()
if s.items == nil {
s.items = make(map[string]*coreauth.Auth)
}
s.items[auth.ID] = auth.Clone()
s.mu.Unlock()
return auth.ID, nil
}
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
_ = ctx
s.mu.Lock()
delete(s.items, id)
s.mu.Unlock()
return nil
}
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
t.Fatalf("unexpected content-type: %s", ct)
}
bodyBytes, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
values, err := url.ParseQuery(string(bodyBytes))
if err != nil {
t.Fatalf("parse form: %v", err)
}
if values.Get("grant_type") != "refresh_token" {
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
}
if values.Get("refresh_token") != "rt" {
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
}
if values.Get("client_id") != antigravityOAuthClientID {
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
}
if values.Get("client_secret") != antigravityOAuthClientSecret {
t.Fatalf("unexpected client_secret")
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "new-token",
"refresh_token": "rt2",
"expires_in": int64(3600),
"token_type": "Bearer",
})
}))
t.Cleanup(srv.Close)
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
store := &memoryAuthStore{}
manager := coreauth.NewManager(store, nil, nil)
auth := &coreauth.Auth{
ID: "antigravity-test.json",
FileName: "antigravity-test.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "old-token",
"refresh_token": "rt",
"expires_in": int64(3600),
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
h := &Handler{
cfg: &config.Config{
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
},
}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("register auth: %v", err)
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "direct"})
httpTransport, ok := transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", transport)
}
if httpTransport.Proxy != nil {
t.Fatal("expected direct transport to disable proxy function")
}
}
func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
t.Parallel()
h := &Handler{
cfg: &config.Config{
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
},
}
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "bad-value"})
httpTransport, ok := transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", transport)
}
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
if errRequest != nil {
t.Fatalf("http.NewRequest returned error: %v", errRequest)
}
proxyURL, errProxy := httpTransport.Proxy(req)
if errProxy != nil {
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
}
if proxyURL == nil || proxyURL.String() != "http://global-proxy.example.com:8080" {
t.Fatalf("proxy URL = %v, want http://global-proxy.example.com:8080", proxyURL)
}
}
func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) {
t.Parallel()
manager := coreauth.NewManager(nil, nil, nil)
geminiAuth := &coreauth.Auth{
ID: "gemini:apikey:123",
Provider: "gemini",
Attributes: map[string]string{
"api_key": "shared-key",
},
}
compatAuth := &coreauth.Auth{
ID: "openai-compatibility:bohe:456",
Provider: "bohe",
Label: "bohe",
Attributes: map[string]string{
"api_key": "shared-key",
"compat_name": "bohe",
"provider_key": "bohe",
},
}
if _, errRegister := manager.Register(context.Background(), geminiAuth); errRegister != nil {
t.Fatalf("register gemini auth: %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), compatAuth); errRegister != nil {
t.Fatalf("register compat auth: %v", errRegister)
}
geminiIndex := geminiAuth.EnsureIndex()
compatIndex := compatAuth.EnsureIndex()
if geminiIndex == compatIndex {
t.Fatalf("shared api key produced duplicate auth_index %q", geminiIndex)
}
h := &Handler{authManager: manager}
token, err := h.resolveTokenForAuth(context.Background(), auth)
if err != nil {
t.Fatalf("resolveTokenForAuth: %v", err)
gotGemini := h.authByIndex(geminiIndex)
if gotGemini == nil {
t.Fatal("expected gemini auth by index")
}
if token != "new-token" {
t.Fatalf("expected refreshed token, got %q", token)
}
if callCount != 1 {
t.Fatalf("expected 1 refresh call, got %d", callCount)
if gotGemini.ID != geminiAuth.ID {
t.Fatalf("authByIndex(gemini) returned %q, want %q", gotGemini.ID, geminiAuth.ID)
}
updated, ok := manager.GetByID(auth.ID)
if !ok || updated == nil {
t.Fatalf("expected auth in manager after update")
gotCompat := h.authByIndex(compatIndex)
if gotCompat == nil {
t.Fatal("expected compat auth by index")
}
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
t.Fatalf("expected manager metadata updated, got %q", got)
}
}
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusInternalServerError)
}))
t.Cleanup(srv.Close)
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
auth := &coreauth.Auth{
ID: "antigravity-valid.json",
FileName: "antigravity-valid.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "ok-token",
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
h := &Handler{}
token, err := h.resolveTokenForAuth(context.Background(), auth)
if err != nil {
t.Fatalf("resolveTokenForAuth: %v", err)
}
if token != "ok-token" {
t.Fatalf("expected existing token, got %q", token)
}
if callCount != 0 {
t.Fatalf("expected no refresh calls, got %d", callCount)
if gotCompat.ID != compatAuth.ID {
t.Fatalf("authByIndex(compat) returned %q, want %q", gotCompat.ID, compatAuth.ID)
}
}

View File

@@ -29,6 +29,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
@@ -54,6 +55,8 @@ const (
codexCallbackPort = 1455
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
geminiCLIVersion = "v1internal"
gitLabLoginModeOAuth = "oauth"
gitLabLoginModePAT = "pat"
)
type callbackForwarder struct {
@@ -338,6 +341,21 @@ func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
emailValue := gjson.GetBytes(data, "email").String()
fileData["type"] = typeValue
fileData["email"] = emailValue
if pv := gjson.GetBytes(data, "priority"); pv.Exists() {
switch pv.Type {
case gjson.Number:
fileData["priority"] = int(pv.Int())
case gjson.String:
if parsed, errAtoi := strconv.Atoi(strings.TrimSpace(pv.String())); errAtoi == nil {
fileData["priority"] = parsed
}
}
}
if nv := gjson.GetBytes(data, "note"); nv.Exists() && nv.Type == gjson.String {
if trimmed := strings.TrimSpace(nv.String()); trimmed != "" {
fileData["note"] = trimmed
}
}
}
files = append(files, fileData)
@@ -421,6 +439,37 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
if claims := extractCodexIDTokenClaims(auth); claims != nil {
entry["id_token"] = claims
}
// Expose priority from Attributes (set by synthesizer from JSON "priority" field).
// Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer).
if p := strings.TrimSpace(authAttribute(auth, "priority")); p != "" {
if parsed, err := strconv.Atoi(p); err == nil {
entry["priority"] = parsed
}
} else if auth.Metadata != nil {
if rawPriority, ok := auth.Metadata["priority"]; ok {
switch v := rawPriority.(type) {
case float64:
entry["priority"] = int(v)
case int:
entry["priority"] = v
case string:
if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
entry["priority"] = parsed
}
}
}
}
// Expose note from Attributes (set by synthesizer from JSON "note" field).
// Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer).
if note := strings.TrimSpace(authAttribute(auth, "note")); note != "" {
entry["note"] = note
} else if auth.Metadata != nil {
if rawNote, ok := auth.Metadata["note"].(string); ok {
if trimmed := strings.TrimSpace(rawNote); trimmed != "" {
entry["note"] = trimmed
}
}
}
return entry
}
@@ -845,7 +894,7 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
}
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority, note) of an auth file.
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
@@ -857,6 +906,7 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
Prefix *string `json:"prefix"`
ProxyURL *string `json:"proxy_url"`
Priority *int `json:"priority"`
Note *string `json:"note"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
@@ -899,14 +949,32 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
targetAuth.ProxyURL = *req.ProxyURL
changed = true
}
if req.Priority != nil {
if req.Priority != nil || req.Note != nil {
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if *req.Priority == 0 {
delete(targetAuth.Metadata, "priority")
} else {
targetAuth.Metadata["priority"] = *req.Priority
if targetAuth.Attributes == nil {
targetAuth.Attributes = make(map[string]string)
}
if req.Priority != nil {
if *req.Priority == 0 {
delete(targetAuth.Metadata, "priority")
delete(targetAuth.Attributes, "priority")
} else {
targetAuth.Metadata["priority"] = *req.Priority
targetAuth.Attributes["priority"] = strconv.Itoa(*req.Priority)
}
}
if req.Note != nil {
trimmedNote := strings.TrimSpace(*req.Note)
if trimmedNote == "" {
delete(targetAuth.Metadata, "note")
delete(targetAuth.Attributes, "note")
} else {
targetAuth.Metadata["note"] = trimmedNote
targetAuth.Attributes["note"] = trimmedNote
}
}
changed = true
}
@@ -999,6 +1067,165 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
return store.Save(ctx, record)
}
func gitLabBaseURLFromRequest(c *gin.Context) string {
if c != nil {
if raw := strings.TrimSpace(c.Query("base_url")); raw != "" {
return gitlabauth.NormalizeBaseURL(raw)
}
}
if raw := strings.TrimSpace(os.Getenv("GITLAB_BASE_URL")); raw != "" {
return gitlabauth.NormalizeBaseURL(raw)
}
return gitlabauth.DefaultBaseURL
}
func buildGitLabAuthMetadata(baseURL, mode string, tokenResp *gitlabauth.TokenResponse, direct *gitlabauth.DirectAccessResponse) map[string]any {
metadata := map[string]any{
"type": "gitlab",
"auth_method": strings.TrimSpace(mode),
"base_url": gitlabauth.NormalizeBaseURL(baseURL),
"last_refresh": time.Now().UTC().Format(time.RFC3339),
"refresh_interval_seconds": 240,
}
if tokenResp != nil {
metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
if refreshToken := strings.TrimSpace(tokenResp.RefreshToken); refreshToken != "" {
metadata["refresh_token"] = refreshToken
}
if tokenType := strings.TrimSpace(tokenResp.TokenType); tokenType != "" {
metadata["token_type"] = tokenType
}
if scope := strings.TrimSpace(tokenResp.Scope); scope != "" {
metadata["scope"] = scope
}
if expiry := gitlabauth.TokenExpiry(time.Now(), tokenResp); !expiry.IsZero() {
metadata["oauth_expires_at"] = expiry.Format(time.RFC3339)
}
}
mergeGitLabDirectAccessMetadata(metadata, direct)
return metadata
}
func mergeGitLabDirectAccessMetadata(metadata map[string]any, direct *gitlabauth.DirectAccessResponse) {
if metadata == nil || direct == nil {
return
}
if base := strings.TrimSpace(direct.BaseURL); base != "" {
metadata["duo_gateway_base_url"] = base
}
if token := strings.TrimSpace(direct.Token); token != "" {
metadata["duo_gateway_token"] = token
}
if direct.ExpiresAt > 0 {
expiry := time.Unix(direct.ExpiresAt, 0).UTC()
metadata["duo_gateway_expires_at"] = expiry.Format(time.RFC3339)
now := time.Now().UTC()
if ttl := expiry.Sub(now); ttl > 0 {
interval := int(ttl.Seconds()) / 2
switch {
case interval < 60:
interval = 60
case interval > 240:
interval = 240
}
metadata["refresh_interval_seconds"] = interval
}
}
if len(direct.Headers) > 0 {
headers := make(map[string]string, len(direct.Headers))
for key, value := range direct.Headers {
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key == "" || value == "" {
continue
}
headers[key] = value
}
if len(headers) > 0 {
metadata["duo_gateway_headers"] = headers
}
}
if direct.ModelDetails != nil {
modelDetails := map[string]any{}
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
modelDetails["model_provider"] = provider
metadata["model_provider"] = provider
}
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
modelDetails["model_name"] = model
metadata["model_name"] = model
}
if len(modelDetails) > 0 {
metadata["model_details"] = modelDetails
}
}
}
func primaryGitLabEmail(user *gitlabauth.User) string {
if user == nil {
return ""
}
if value := strings.TrimSpace(user.Email); value != "" {
return value
}
return strings.TrimSpace(user.PublicEmail)
}
func gitLabAccountIdentifier(user *gitlabauth.User) string {
if user == nil {
return "user"
}
for _, value := range []string{user.Username, primaryGitLabEmail(user), user.Name} {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return "user"
}
func sanitizeGitLabFileName(value string) string {
value = strings.TrimSpace(strings.ToLower(value))
if value == "" {
return "user"
}
var builder strings.Builder
lastDash := false
for _, r := range value {
switch {
case r >= 'a' && r <= 'z':
builder.WriteRune(r)
lastDash = false
case r >= '0' && r <= '9':
builder.WriteRune(r)
lastDash = false
case r == '-' || r == '_' || r == '.':
builder.WriteRune(r)
lastDash = false
default:
if !lastDash {
builder.WriteRune('-')
lastDash = true
}
}
}
result := strings.Trim(builder.String(), "-")
if result == "" {
return "user"
}
return result
}
func maskGitLabToken(token string) string {
trimmed := strings.TrimSpace(token)
if trimmed == "" {
return ""
}
if len(trimmed) <= 8 {
return trimmed
}
return trimmed[:4] + "..." + trimmed[len(trimmed)-4:]
}
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
@@ -1549,6 +1776,263 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestGitLabToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing GitLab Duo authentication...")
baseURL := gitLabBaseURLFromRequest(c)
clientID := strings.TrimSpace(c.Query("client_id"))
clientSecret := strings.TrimSpace(c.Query("client_secret"))
if clientID == "" {
clientID = strings.TrimSpace(os.Getenv("GITLAB_OAUTH_CLIENT_ID"))
}
if clientSecret == "" {
clientSecret = strings.TrimSpace(os.Getenv("GITLAB_OAUTH_CLIENT_SECRET"))
}
if clientID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "gitlab client_id is required"})
return
}
pkceCodes, err := gitlabauth.GeneratePKCECodes()
if err != nil {
log.Errorf("Failed to generate GitLab PKCE codes: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
return
}
state, err := misc.GenerateRandomState()
if err != nil {
log.Errorf("Failed to generate GitLab state parameter: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
return
}
redirectURI := gitlabauth.RedirectURL(gitlabauth.DefaultCallbackPort)
authClient := gitlabauth.NewAuthClient(h.cfg)
authURL, err := authClient.GenerateAuthURL(baseURL, clientID, redirectURI, state, pkceCodes)
if err != nil {
log.Errorf("Failed to generate GitLab authorization URL: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
return
}
RegisterOAuthSession(state, "gitlab")
isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/gitlab/callback")
if errTarget != nil {
log.WithError(errTarget).Error("failed to compute gitlab callback target")
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
var errStart error
if forwarder, errStart = startCallbackForwarder(gitlabauth.DefaultCallbackPort, "gitlab", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start gitlab callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
}
}
go func() {
if isWebUI {
defer stopCallbackForwarderInstance(gitlabauth.DefaultCallbackPort, forwarder)
}
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gitlab-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute)
var code string
for {
if !IsOAuthSessionPending(state, "gitlab") {
return
}
if time.Now().After(deadline) {
log.Error("gitlab oauth flow timed out")
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
return
}
if data, errRead := os.ReadFile(waitFile); errRead == nil {
var payload map[string]string
_ = json.Unmarshal(data, &payload)
_ = os.Remove(waitFile)
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
SetOAuthSessionError(state, errStr)
return
}
if payloadState := strings.TrimSpace(payload["state"]); payloadState != state {
SetOAuthSessionError(state, "State code error")
return
}
code = strings.TrimSpace(payload["code"])
if code == "" {
SetOAuthSessionError(state, "Authorization code missing")
return
}
break
}
time.Sleep(500 * time.Millisecond)
}
tokenResp, errExchange := authClient.ExchangeCodeForTokens(ctx, baseURL, clientID, clientSecret, redirectURI, code, pkceCodes.CodeVerifier)
if errExchange != nil {
log.Errorf("Failed to exchange GitLab authorization code: %v", errExchange)
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
return
}
user, errUser := authClient.GetCurrentUser(ctx, baseURL, tokenResp.AccessToken)
if errUser != nil {
log.Errorf("Failed to fetch GitLab user profile: %v", errUser)
SetOAuthSessionError(state, "Failed to fetch account profile")
return
}
direct, errDirect := authClient.FetchDirectAccess(ctx, baseURL, tokenResp.AccessToken)
if errDirect != nil {
log.Errorf("Failed to fetch GitLab direct access metadata: %v", errDirect)
SetOAuthSessionError(state, "Failed to fetch GitLab Duo access")
return
}
identifier := gitLabAccountIdentifier(user)
fileName := fmt.Sprintf("gitlab-%s.json", sanitizeGitLabFileName(identifier))
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
metadata["auth_kind"] = "oauth"
metadata["oauth_client_id"] = clientID
if clientSecret != "" {
metadata["oauth_client_secret"] = clientSecret
}
metadata["username"] = strings.TrimSpace(user.Username)
if email := primaryGitLabEmail(user); email != "" {
metadata["email"] = email
}
metadata["name"] = strings.TrimSpace(user.Name)
record := &coreauth.Auth{
ID: fileName,
Provider: "gitlab",
FileName: fileName,
Label: identifier,
Metadata: metadata,
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save GitLab auth record: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("GitLab Duo authentication successful. Token saved to %s\n", savedPath)
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("gitlab")
}()
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestGitLabPATToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
var payload struct {
BaseURL string `json:"base_url"`
PersonalAccessToken string `json:"personal_access_token"`
Token string `json:"token"`
}
if err := c.ShouldBindJSON(&payload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
return
}
baseURL := gitlabauth.NormalizeBaseURL(strings.TrimSpace(payload.BaseURL))
if baseURL == "" {
baseURL = gitLabBaseURLFromRequest(nil)
}
pat := strings.TrimSpace(payload.PersonalAccessToken)
if pat == "" {
pat = strings.TrimSpace(payload.Token)
}
if pat == "" {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "personal_access_token is required"})
return
}
authClient := gitlabauth.NewAuthClient(h.cfg)
user, err := authClient.GetCurrentUser(ctx, baseURL, pat)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
return
}
patSelf, err := authClient.GetPersonalAccessTokenSelf(ctx, baseURL, pat)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
return
}
direct, err := authClient.FetchDirectAccess(ctx, baseURL, pat)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
return
}
identifier := gitLabAccountIdentifier(user)
fileName := fmt.Sprintf("gitlab-%s-pat.json", sanitizeGitLabFileName(identifier))
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModePAT, nil, direct)
metadata["auth_kind"] = "personal_access_token"
metadata["personal_access_token"] = pat
metadata["token_preview"] = maskGitLabToken(pat)
metadata["username"] = strings.TrimSpace(user.Username)
if email := primaryGitLabEmail(user); email != "" {
metadata["email"] = email
}
metadata["name"] = strings.TrimSpace(user.Name)
if patSelf != nil {
if name := strings.TrimSpace(patSelf.Name); name != "" {
metadata["pat_name"] = name
}
if len(patSelf.Scopes) > 0 {
metadata["pat_scopes"] = append([]string(nil), patSelf.Scopes...)
}
}
record := &coreauth.Auth{
ID: fileName,
Provider: "gitlab",
FileName: fileName,
Label: identifier + " (PAT)",
Metadata: metadata,
}
savedPath, err := h.saveTokenRecord(ctx, record)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to save authentication tokens"})
return
}
response := gin.H{
"status": "ok",
"saved_path": savedPath,
"username": strings.TrimSpace(user.Username),
"email": primaryGitLabEmail(user),
"token_label": identifier,
}
if direct != nil && direct.ModelDetails != nil {
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
response["model_provider"] = provider
}
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
response["model_name"] = model
}
}
fmt.Printf("GitLab Duo PAT authentication successful. Token saved to %s\n", savedPath)
c.JSON(http.StatusOK, response)
}
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
@@ -2019,17 +2503,20 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
if label == "" {
label = username
}
metadata, errMeta := copilotTokenMetadata(tokenStorage)
if errMeta != nil {
log.Errorf("Failed to build token metadata: %v", errMeta)
SetOAuthSessionError(state, "Failed to build token metadata")
return
}
record := &coreauth.Auth{
ID: fileName,
Provider: "github-copilot",
Label: label,
FileName: fileName,
Storage: tokenStorage,
Metadata: map[string]any{
"email": userInfo.Email,
"username": username,
"name": userInfo.Name,
},
Metadata: metadata,
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
@@ -2054,6 +2541,21 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
})
}
func copilotTokenMetadata(storage *copilot.CopilotTokenStorage) (map[string]any, error) {
if storage == nil {
return nil, fmt.Errorf("token storage is nil")
}
payload, errMarshal := json.Marshal(storage)
if errMarshal != nil {
return nil, fmt.Errorf("marshal token storage: %w", errMarshal)
}
metadata := make(map[string]any)
if errUnmarshal := json.Unmarshal(payload, &metadata); errUnmarshal != nil {
return nil, fmt.Errorf("unmarshal token storage: %w", errUnmarshal)
}
return metadata, nil
}
func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
ctx := context.Background()

View File

@@ -0,0 +1,164 @@
package management
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestRequestGitLabPATToken_SavesAuthRecord(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer glpat-test-token" {
t.Fatalf("authorization header = %q, want Bearer glpat-test-token", got)
}
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/api/v4/user":
_ = json.NewEncoder(w).Encode(map[string]any{
"id": 42,
"username": "gitlab-user",
"name": "GitLab User",
"email": "gitlab@example.com",
})
case "/api/v4/personal_access_tokens/self":
_ = json.NewEncoder(w).Encode(map[string]any{
"id": 7,
"name": "management-center",
"scopes": []string{"api", "read_user"},
"user_id": 42,
})
case "/api/v4/code_suggestions/direct_access":
_ = json.NewEncoder(w).Encode(map[string]any{
"base_url": "https://cloud.gitlab.example.com",
"token": "gateway-token",
"expires_at": 1893456000,
"headers": map[string]string{
"X-Gitlab-Realm": "saas",
},
"model_details": map[string]any{
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
})
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
store := &memoryAuthStore{}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, coreauth.NewManager(nil, nil, nil))
h.tokenStore = store
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/gitlab-auth-url", strings.NewReader(`{"base_url":"`+upstream.URL+`","personal_access_token":"glpat-test-token"}`))
ctx.Request.Header.Set("Content-Type", "application/json")
h.RequestGitLabPATToken(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
var resp map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("decode response: %v", err)
}
if got := resp["status"]; got != "ok" {
t.Fatalf("status = %#v, want ok", got)
}
if got := resp["model_provider"]; got != "anthropic" {
t.Fatalf("model_provider = %#v, want anthropic", got)
}
if got := resp["model_name"]; got != "claude-sonnet-4-5" {
t.Fatalf("model_name = %#v, want claude-sonnet-4-5", got)
}
store.mu.Lock()
defer store.mu.Unlock()
if len(store.items) != 1 {
t.Fatalf("expected 1 saved auth record, got %d", len(store.items))
}
var saved *coreauth.Auth
for _, item := range store.items {
saved = item
}
if saved == nil {
t.Fatal("expected saved auth record")
}
if saved.Provider != "gitlab" {
t.Fatalf("provider = %q, want gitlab", saved.Provider)
}
if got := saved.Metadata["auth_kind"]; got != "personal_access_token" {
t.Fatalf("auth_kind = %#v, want personal_access_token", got)
}
if got := saved.Metadata["model_provider"]; got != "anthropic" {
t.Fatalf("saved model_provider = %#v, want anthropic", got)
}
if got := saved.Metadata["duo_gateway_token"]; got != "gateway-token" {
t.Fatalf("saved duo_gateway_token = %#v, want gateway-token", got)
}
}
func TestPostOAuthCallback_GitLabWritesPendingCallbackFile(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
state := "gitlab-state-123"
RegisterOAuthSession(state, "gitlab")
t.Cleanup(func() { CompleteOAuthSession(state) })
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, coreauth.NewManager(nil, nil, nil))
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/oauth-callback", strings.NewReader(`{"provider":"gitlab","redirect_url":"http://localhost:17171/auth/callback?code=test-code&state=`+state+`"}`))
ctx.Request.Header.Set("Content-Type", "application/json")
h.PostOAuthCallback(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
filePath := filepath.Join(authDir, ".oauth-gitlab-"+state+".oauth")
data, err := os.ReadFile(filePath)
if err != nil {
t.Fatalf("read callback file: %v", err)
}
var payload map[string]string
if err := json.Unmarshal(data, &payload); err != nil {
t.Fatalf("decode callback payload: %v", err)
}
if got := payload["code"]; got != "test-code" {
t.Fatalf("callback code = %q, want test-code", got)
}
if got := payload["state"]; got != state {
t.Fatalf("callback state = %q, want %q", got, state)
}
}
func TestNormalizeOAuthProvider_GitLab(t *testing.T) {
provider, err := NormalizeOAuthProvider("gitlab")
if err != nil {
t.Fatalf("NormalizeOAuthProvider returned error: %v", err)
}
if provider != "gitlab" {
t.Fatalf("provider = %q, want gitlab", provider)
}
}

View File

@@ -509,8 +509,12 @@ func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
}
for i := range arr {
normalizeVertexCompatKey(&arr[i])
if arr[i].APIKey == "" {
c.JSON(400, gin.H{"error": fmt.Sprintf("vertex-api-key[%d].api-key is required", i)})
return
}
}
h.cfg.VertexCompatAPIKey = arr
h.cfg.VertexCompatAPIKey = append([]config.VertexCompatKey(nil), arr...)
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
}

View File

@@ -228,6 +228,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
return "anthropic", nil
case "codex", "openai":
return "codex", nil
case "gitlab":
return "gitlab", nil
case "gemini", "google":
return "gemini", nil
case "iflow", "i-flow":

View File

@@ -0,0 +1,49 @@
package management
import (
"context"
"sync"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
type memoryAuthStore struct {
mu sync.Mutex
items map[string]*coreauth.Auth
}
func (s *memoryAuthStore) List(_ context.Context) ([]*coreauth.Auth, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*coreauth.Auth, 0, len(s.items))
for _, item := range s.items {
out = append(out, item)
}
return out, nil
}
func (s *memoryAuthStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) {
if auth == nil {
return "", nil
}
s.mu.Lock()
defer s.mu.Unlock()
if s.items == nil {
s.items = make(map[string]*coreauth.Auth)
}
s.items[auth.ID] = auth
return auth.ID, nil
}
func (s *memoryAuthStore) Delete(_ context.Context, id string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.items, id)
return nil
}
func (s *memoryAuthStore) SetBaseDir(string) {}

View File

@@ -403,6 +403,20 @@ func (s *Server) setupRoutes() {
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
s.engine.GET("/gitlab/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
errStr := c.Query("error")
if errStr == "" {
errStr = c.Query("error_description")
}
if state != "" {
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gitlab", state, code, errStr)
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
s.engine.GET("/google/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
@@ -658,6 +672,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
mgmt.GET("/gitlab-auth-url", s.mgmt.RequestGitLabToken)
mgmt.POST("/gitlab-auth-url", s.mgmt.RequestGitLabPATToken)
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)

View File

@@ -4,12 +4,12 @@ package claude
import (
"net/http"
"net/url"
"strings"
"sync"
tls "github.com/refraction-networking/utls"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/proxy"
@@ -31,17 +31,12 @@ type utlsRoundTripper struct {
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
var dialer proxy.Dialer = proxy.Direct
if cfg != nil && cfg.ProxyURL != "" {
proxyURL, err := url.Parse(cfg.ProxyURL)
if err != nil {
log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err)
} else {
pDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
if err != nil {
log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err)
} else {
dialer = pDialer
}
if cfg != nil {
proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
if errBuild != nil {
log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
dialer = proxyDialer
}
}

View File

@@ -10,9 +10,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
@@ -20,9 +18,9 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
@@ -80,36 +78,16 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
}
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
// Configure proxy settings for the HTTP client if a proxy URL is provided.
proxyURL, err := url.Parse(cfg.ProxyURL)
if err == nil {
var transport *http.Transport
if proxyURL.Scheme == "socks5" {
// Handle SOCKS5 proxy.
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
auth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Handle HTTP/HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
if transport != nil {
proxyClient := &http.Client{Transport: transport}
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
}
transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
if errBuild != nil {
log.Errorf("%v", errBuild)
} else if transport != nil {
proxyClient := &http.Client{Transport: transport}
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
}
var err error
// Configure the OAuth2 client.
conf := &oauth2.Config{
ClientID: ClientID,
@@ -327,6 +305,9 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
defer manualPromptTimer.Stop()
}
var manualInputCh <-chan string
var manualInputErrCh <-chan error
waitForCallback:
for {
select {
@@ -348,13 +329,14 @@ waitForCallback:
return nil, err
default:
}
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
if err != nil {
return nil, err
}
parsed, err := misc.ParseOAuthCallback(input)
if err != nil {
return nil, err
manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the Gemini callback URL (or press Enter to keep waiting): ")
continue
case input := <-manualInputCh:
manualInputCh = nil
manualInputErrCh = nil
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
@@ -367,6 +349,8 @@ waitForCallback:
}
authCode = parsed.Code
break waitForCallback
case errManual := <-manualInputErrCh:
return nil, errManual
case <-timeoutTimer.C:
return nil, fmt.Errorf("oauth flow timed out")
}

View File

@@ -0,0 +1,492 @@
package gitlab
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
DefaultBaseURL = "https://gitlab.com"
DefaultCallbackPort = 17171
defaultOAuthScope = "api read_user"
)
type PKCECodes struct {
CodeVerifier string
CodeChallenge string
}
type OAuthResult struct {
Code string
State string
Error string
}
type OAuthServer struct {
server *http.Server
port int
resultChan chan *OAuthResult
errorChan chan error
mu sync.Mutex
running bool
}
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
CreatedAt int64 `json:"created_at"`
ExpiresIn int `json:"expires_in"`
}
type User struct {
ID int64 `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Email string `json:"email"`
PublicEmail string `json:"public_email"`
}
type PersonalAccessTokenSelf struct {
ID int64 `json:"id"`
Name string `json:"name"`
Scopes []string `json:"scopes"`
UserID int64 `json:"user_id"`
}
type ModelDetails struct {
ModelProvider string `json:"model_provider"`
ModelName string `json:"model_name"`
}
type DirectAccessResponse struct {
BaseURL string `json:"base_url"`
Token string `json:"token"`
ExpiresAt int64 `json:"expires_at"`
Headers map[string]string `json:"headers"`
ModelDetails *ModelDetails `json:"model_details,omitempty"`
}
type DiscoveredModel struct {
ModelProvider string
ModelName string
}
type AuthClient struct {
httpClient *http.Client
}
func NewAuthClient(cfg *config.Config) *AuthClient {
client := &http.Client{}
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
return &AuthClient{httpClient: client}
}
func NormalizeBaseURL(raw string) string {
value := strings.TrimSpace(raw)
if value == "" {
return DefaultBaseURL
}
if !strings.Contains(value, "://") {
value = "https://" + value
}
value = strings.TrimRight(value, "/")
return value
}
func TokenExpiry(now time.Time, token *TokenResponse) time.Time {
if token == nil {
return time.Time{}
}
if token.CreatedAt > 0 && token.ExpiresIn > 0 {
return time.Unix(token.CreatedAt+int64(token.ExpiresIn), 0).UTC()
}
if token.ExpiresIn > 0 {
return now.UTC().Add(time.Duration(token.ExpiresIn) * time.Second)
}
return time.Time{}
}
func GeneratePKCECodes() (*PKCECodes, error) {
verifierBytes := make([]byte, 32)
if _, err := rand.Read(verifierBytes); err != nil {
return nil, fmt.Errorf("gitlab pkce generation failed: %w", err)
}
verifier := base64.RawURLEncoding.EncodeToString(verifierBytes)
sum := sha256.Sum256([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(sum[:])
return &PKCECodes{
CodeVerifier: verifier,
CodeChallenge: challenge,
}, nil
}
func NewOAuthServer(port int) *OAuthServer {
return &OAuthServer{
port: port,
resultChan: make(chan *OAuthResult, 1),
errorChan: make(chan error, 1),
}
}
func (s *OAuthServer) Start() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.running {
return fmt.Errorf("gitlab oauth server already running")
}
if !s.isPortAvailable() {
return fmt.Errorf("port %d is already in use", s.port)
}
mux := http.NewServeMux()
mux.HandleFunc("/auth/callback", s.handleCallback)
s.server = &http.Server{
Addr: fmt.Sprintf(":%d", s.port),
Handler: mux,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
}
s.running = true
go func() {
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
s.errorChan <- err
}
}()
time.Sleep(100 * time.Millisecond)
return nil
}
func (s *OAuthServer) Stop(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.running || s.server == nil {
return nil
}
defer func() {
s.running = false
s.server = nil
}()
return s.server.Shutdown(ctx)
}
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
select {
case result := <-s.resultChan:
return result, nil
case err := <-s.errorChan:
return nil, err
case <-time.After(timeout):
return nil, fmt.Errorf("timeout waiting for OAuth callback")
}
}
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
query := r.URL.Query()
if errParam := strings.TrimSpace(query.Get("error")); errParam != "" {
s.sendResult(&OAuthResult{Error: errParam})
http.Error(w, errParam, http.StatusBadRequest)
return
}
code := strings.TrimSpace(query.Get("code"))
state := strings.TrimSpace(query.Get("state"))
if code == "" || state == "" {
s.sendResult(&OAuthResult{Error: "missing_code_or_state"})
http.Error(w, "missing code or state", http.StatusBadRequest)
return
}
s.sendResult(&OAuthResult{Code: code, State: state})
_, _ = w.Write([]byte("GitLab authentication received. You can close this tab."))
}
func (s *OAuthServer) sendResult(result *OAuthResult) {
select {
case s.resultChan <- result:
default:
log.Debug("gitlab oauth result channel full, dropping callback result")
}
}
func (s *OAuthServer) isPortAvailable() bool {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
if err != nil {
return false
}
_ = listener.Close()
return true
}
func RedirectURL(port int) string {
return fmt.Sprintf("http://localhost:%d/auth/callback", port)
}
func (c *AuthClient) GenerateAuthURL(baseURL, clientID, redirectURI, state string, pkce *PKCECodes) (string, error) {
if pkce == nil {
return "", fmt.Errorf("gitlab auth URL generation failed: PKCE codes are required")
}
if strings.TrimSpace(clientID) == "" {
return "", fmt.Errorf("gitlab auth URL generation failed: client ID is required")
}
baseURL = NormalizeBaseURL(baseURL)
params := url.Values{
"client_id": {strings.TrimSpace(clientID)},
"response_type": {"code"},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"scope": {defaultOAuthScope},
"state": {strings.TrimSpace(state)},
"code_challenge": {pkce.CodeChallenge},
"code_challenge_method": {"S256"},
}
return fmt.Sprintf("%s/oauth/authorize?%s", baseURL, params.Encode()), nil
}
func (c *AuthClient) ExchangeCodeForTokens(ctx context.Context, baseURL, clientID, clientSecret, redirectURI, code, codeVerifier string) (*TokenResponse, error) {
form := url.Values{
"grant_type": {"authorization_code"},
"client_id": {strings.TrimSpace(clientID)},
"code": {strings.TrimSpace(code)},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"code_verifier": {strings.TrimSpace(codeVerifier)},
}
if secret := strings.TrimSpace(clientSecret); secret != "" {
form.Set("client_secret", secret)
}
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
}
func (c *AuthClient) RefreshTokens(ctx context.Context, baseURL, clientID, clientSecret, refreshToken string) (*TokenResponse, error) {
form := url.Values{
"grant_type": {"refresh_token"},
"refresh_token": {strings.TrimSpace(refreshToken)},
}
if clientID = strings.TrimSpace(clientID); clientID != "" {
form.Set("client_id", clientID)
}
if secret := strings.TrimSpace(clientSecret); secret != "" {
form.Set("client_secret", secret)
}
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
}
func (c *AuthClient) postToken(ctx context.Context, tokenURL string, form url.Values) (*TokenResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
if err != nil {
return nil, fmt.Errorf("gitlab token request failed: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("gitlab token request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("gitlab token response read failed: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("gitlab token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var token TokenResponse
if err := json.Unmarshal(body, &token); err != nil {
return nil, fmt.Errorf("gitlab token response decode failed: %w", err)
}
return &token, nil
}
func (c *AuthClient) GetCurrentUser(ctx context.Context, baseURL, token string) (*User, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/user", nil)
if err != nil {
return nil, fmt.Errorf("gitlab user request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("gitlab user request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("gitlab user response read failed: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("gitlab user request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var user User
if err := json.Unmarshal(body, &user); err != nil {
return nil, fmt.Errorf("gitlab user response decode failed: %w", err)
}
return &user, nil
}
func (c *AuthClient) GetPersonalAccessTokenSelf(ctx context.Context, baseURL, token string) (*PersonalAccessTokenSelf, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/personal_access_tokens/self", nil)
if err != nil {
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("gitlab PAT self response read failed: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("gitlab PAT self request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var pat PersonalAccessTokenSelf
if err := json.Unmarshal(body, &pat); err != nil {
return nil, fmt.Errorf("gitlab PAT self response decode failed: %w", err)
}
return &pat, nil
}
func (c *AuthClient) FetchDirectAccess(ctx context.Context, baseURL, token string) (*DirectAccessResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, NormalizeBaseURL(baseURL)+"/api/v4/code_suggestions/direct_access", nil)
if err != nil {
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("gitlab direct access response read failed: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("gitlab direct access request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var direct DirectAccessResponse
if err := json.Unmarshal(body, &direct); err != nil {
return nil, fmt.Errorf("gitlab direct access response decode failed: %w", err)
}
if direct.Headers == nil {
direct.Headers = make(map[string]string)
}
return &direct, nil
}
func ExtractDiscoveredModels(metadata map[string]any) []DiscoveredModel {
if len(metadata) == 0 {
return nil
}
models := make([]DiscoveredModel, 0, 4)
seen := make(map[string]struct{})
appendModel := func(provider, name string) {
provider = strings.TrimSpace(provider)
name = strings.TrimSpace(name)
if name == "" {
return
}
key := strings.ToLower(name)
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
models = append(models, DiscoveredModel{
ModelProvider: provider,
ModelName: name,
})
}
if raw, ok := metadata["model_details"]; ok {
appendDiscoveredModels(raw, appendModel)
}
appendModel(stringValue(metadata["model_provider"]), stringValue(metadata["model_name"]))
for _, key := range []string{"models", "supported_models", "discovered_models"} {
if raw, ok := metadata[key]; ok {
appendDiscoveredModels(raw, appendModel)
}
}
return models
}
func appendDiscoveredModels(raw any, appendModel func(provider, name string)) {
switch typed := raw.(type) {
case map[string]any:
appendModel(stringValue(typed["model_provider"]), stringValue(typed["model_name"]))
appendModel(stringValue(typed["provider"]), stringValue(typed["name"]))
if nested, ok := typed["models"]; ok {
appendDiscoveredModels(nested, appendModel)
}
case []any:
for _, item := range typed {
appendDiscoveredModels(item, appendModel)
}
case []string:
for _, item := range typed {
appendModel("", item)
}
case string:
appendModel("", typed)
}
}
func stringValue(raw any) string {
switch typed := raw.(type) {
case string:
return strings.TrimSpace(typed)
case fmt.Stringer:
return strings.TrimSpace(typed.String())
case json.Number:
return typed.String()
case int:
return strconv.Itoa(typed)
case int64:
return strconv.FormatInt(typed, 10)
case float64:
return strconv.FormatInt(int64(typed), 10)
default:
return ""
}
}

View File

@@ -0,0 +1,138 @@
package gitlab
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestAuthClientGenerateAuthURLIncludesPKCE(t *testing.T) {
client := NewAuthClient(nil)
pkce, err := GeneratePKCECodes()
if err != nil {
t.Fatalf("GeneratePKCECodes() error = %v", err)
}
rawURL, err := client.GenerateAuthURL("https://gitlab.example.com", "client-id", RedirectURL(17171), "state-123", pkce)
if err != nil {
t.Fatalf("GenerateAuthURL() error = %v", err)
}
parsed, err := url.Parse(rawURL)
if err != nil {
t.Fatalf("Parse(authURL) error = %v", err)
}
if got := parsed.Path; got != "/oauth/authorize" {
t.Fatalf("expected /oauth/authorize path, got %q", got)
}
query := parsed.Query()
if got := query.Get("client_id"); got != "client-id" {
t.Fatalf("expected client_id, got %q", got)
}
if got := query.Get("scope"); got != defaultOAuthScope {
t.Fatalf("expected scope %q, got %q", defaultOAuthScope, got)
}
if got := query.Get("code_challenge_method"); got != "S256" {
t.Fatalf("expected PKCE method S256, got %q", got)
}
if got := query.Get("code_challenge"); got == "" {
t.Fatal("expected non-empty code_challenge")
}
}
func TestAuthClientExchangeCodeForTokens(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/token" {
t.Fatalf("unexpected path %q", r.URL.Path)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("ParseForm() error = %v", err)
}
if got := r.Form.Get("grant_type"); got != "authorization_code" {
t.Fatalf("expected authorization_code grant, got %q", got)
}
if got := r.Form.Get("code_verifier"); got != "verifier-123" {
t.Fatalf("expected code_verifier, got %q", got)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "oauth-access",
"refresh_token": "oauth-refresh",
"token_type": "Bearer",
"scope": "api read_user",
"created_at": 1710000000,
"expires_in": 3600,
})
}))
defer srv.Close()
client := NewAuthClient(nil)
token, err := client.ExchangeCodeForTokens(context.Background(), srv.URL, "client-id", "client-secret", RedirectURL(17171), "auth-code", "verifier-123")
if err != nil {
t.Fatalf("ExchangeCodeForTokens() error = %v", err)
}
if token.AccessToken != "oauth-access" {
t.Fatalf("expected access token, got %q", token.AccessToken)
}
if token.RefreshToken != "oauth-refresh" {
t.Fatalf("expected refresh token, got %q", token.RefreshToken)
}
}
func TestExtractDiscoveredModels(t *testing.T) {
models := ExtractDiscoveredModels(map[string]any{
"model_details": map[string]any{
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
"supported_models": []any{
map[string]any{"model_provider": "openai", "model_name": "gpt-4.1"},
"claude-sonnet-4-5",
},
})
if len(models) != 2 {
t.Fatalf("expected 2 unique models, got %d", len(models))
}
if models[0].ModelName != "claude-sonnet-4-5" {
t.Fatalf("unexpected first model %q", models[0].ModelName)
}
if models[1].ModelName != "gpt-4.1" {
t.Fatalf("unexpected second model %q", models[1].ModelName)
}
}
func TestFetchDirectAccessDecodesModelDetails(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v4/code_suggestions/direct_access" {
t.Fatalf("unexpected path %q", r.URL.Path)
}
if got := r.Header.Get("Authorization"); !strings.Contains(got, "token-123") {
t.Fatalf("expected bearer token, got %q", got)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"base_url": "https://cloud.gitlab.example.com",
"token": "gateway-token",
"expires_at": 1710003600,
"headers": map[string]string{
"X-Gitlab-Realm": "saas",
},
"model_details": map[string]any{
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
})
}))
defer srv.Close()
client := NewAuthClient(nil)
direct, err := client.FetchDirectAccess(context.Background(), srv.URL, "token-123")
if err != nil {
t.Fatalf("FetchDirectAccess() error = %v", err)
}
if direct.ModelDetails == nil || direct.ModelDetails.ModelName != "claude-sonnet-4-5" {
t.Fatalf("expected model details, got %+v", direct.ModelDetails)
}
}

View File

@@ -23,6 +23,7 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewKiroAuthenticator(),
sdkAuth.NewGitHubCopilotAuthenticator(),
sdkAuth.NewKiloAuthenticator(),
sdkAuth.NewGitLabAuthenticator(),
)
return manager
}

View File

@@ -0,0 +1,69 @@
package cmd
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
)
func DoGitLabLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{
"login_mode": "oauth",
},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
if err != nil {
fmt.Printf("GitLab Duo authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("GitLab Duo authentication successful!")
}
func DoGitLabTokenLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
Metadata: map[string]string{
"login_mode": "pat",
},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts)
if err != nil {
fmt.Printf("GitLab Duo PAT authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("GitLab Duo PAT authentication successful!")
}

View File

@@ -0,0 +1,32 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadConfigOptional_CodexHeaderDefaults(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.yaml")
configYAML := []byte(`
codex-header-defaults:
user-agent: " my-codex-client/1.0 "
beta-features: " feature-a,feature-b "
`)
if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
t.Fatalf("failed to write config: %v", err)
}
cfg, err := LoadConfigOptional(configPath, false)
if err != nil {
t.Fatalf("LoadConfigOptional() error = %v", err)
}
if got := cfg.CodexHeaderDefaults.UserAgent; got != "my-codex-client/1.0" {
t.Fatalf("UserAgent = %q, want %q", got, "my-codex-client/1.0")
}
if got := cfg.CodexHeaderDefaults.BetaFeatures; got != "feature-a,feature-b" {
t.Fatalf("BetaFeatures = %q, want %q", got, "feature-a,feature-b")
}
}

View File

@@ -101,6 +101,10 @@ type Config struct {
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
// CodexHeaderDefaults configures fallback headers for Codex OAuth model requests.
// These are used only when the client does not send its own headers.
CodexHeaderDefaults CodexHeaderDefaults `yaml:"codex-header-defaults" json:"codex-header-defaults"`
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
@@ -150,6 +154,14 @@ type ClaudeHeaderDefaults struct {
Timeout string `yaml:"timeout" json:"timeout"`
}
// CodexHeaderDefaults configures fallback header values injected into Codex
// model requests for OAuth/file-backed auth when the client omits them.
// UserAgent applies to HTTP and websocket requests; BetaFeatures only applies to websockets.
type CodexHeaderDefaults struct {
UserAgent string `yaml:"user-agent" json:"user-agent"`
BetaFeatures string `yaml:"beta-features" json:"beta-features"`
}
// TLSConfig holds HTTPS server settings.
type TLSConfig struct {
// Enable toggles HTTPS server mode.
@@ -673,12 +685,15 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Sanitize Gemini API key configuration and migrate legacy entries.
cfg.SanitizeGeminiKeys()
// Sanitize Vertex-compatible API keys: drop entries without base-url
// Sanitize Vertex-compatible API keys.
cfg.SanitizeVertexCompatKeys()
// Sanitize Codex keys: drop entries without base-url
cfg.SanitizeCodexKeys()
// Sanitize Codex header defaults.
cfg.SanitizeCodexHeaderDefaults()
// Sanitize Claude key headers
cfg.SanitizeClaudeKeys()
@@ -771,6 +786,16 @@ func payloadRawString(value any) ([]byte, bool) {
}
}
// SanitizeCodexHeaderDefaults trims surrounding whitespace from the
// configured Codex header fallback values.
func (cfg *Config) SanitizeCodexHeaderDefaults() {
if cfg == nil {
return
}
cfg.CodexHeaderDefaults.UserAgent = strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent)
cfg.CodexHeaderDefaults.BetaFeatures = strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
}
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.

View File

@@ -20,9 +20,9 @@ type VertexCompatKey struct {
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
// BaseURL is the base URL for the Vertex-compatible API endpoint.
// BaseURL optionally overrides the Vertex-compatible API endpoint.
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
// When empty, requests fall back to the default Vertex API base URL.
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
// ProxyURL optionally overrides the global proxy for this API key.
@@ -71,10 +71,6 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
}
entry.Prefix = normalizeModelPrefix(entry.Prefix)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
if entry.BaseURL == "" {
// BaseURL is required for Vertex API key entries
continue
}
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = NormalizeHeaders(entry.Headers)
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)

View File

@@ -30,6 +30,23 @@ type OAuthCallback struct {
ErrorDescription string
}
// AsyncPrompt runs a prompt function in a goroutine and returns channels for
// the result. The returned channels are buffered (size 1) so the goroutine can
// complete even if the caller abandons the channels.
func AsyncPrompt(promptFn func(string) (string, error), message string) (<-chan string, <-chan error) {
inputCh := make(chan string, 1)
errCh := make(chan error, 1)
go func() {
input, err := promptFn(message)
if err != nil {
errCh <- err
return
}
inputCh <- input
}()
return inputCh, errCh
}
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
// It returns nil when the input is empty.
func ParseOAuthCallback(input string) (*OAuthCallback, error) {

View File

@@ -1,12 +1,105 @@
// Package registry provides model definitions and lookup helpers for various AI providers.
// Static model metadata is stored in model_definitions_static_data.go.
// Static model metadata is loaded from the embedded models.json file and can be refreshed from network.
package registry
import (
"sort"
"strings"
)
// staticModelsJSON mirrors the top-level structure of models.json.
type staticModelsJSON struct {
Claude []*ModelInfo `json:"claude"`
Gemini []*ModelInfo `json:"gemini"`
Vertex []*ModelInfo `json:"vertex"`
GeminiCLI []*ModelInfo `json:"gemini-cli"`
AIStudio []*ModelInfo `json:"aistudio"`
CodexFree []*ModelInfo `json:"codex-free"`
CodexTeam []*ModelInfo `json:"codex-team"`
CodexPlus []*ModelInfo `json:"codex-plus"`
CodexPro []*ModelInfo `json:"codex-pro"`
Qwen []*ModelInfo `json:"qwen"`
IFlow []*ModelInfo `json:"iflow"`
Kimi []*ModelInfo `json:"kimi"`
Antigravity []*ModelInfo `json:"antigravity"`
}
// GetClaudeModels returns the standard Claude model definitions.
func GetClaudeModels() []*ModelInfo {
return cloneModelInfos(getModels().Claude)
}
// GetGeminiModels returns the standard Gemini model definitions.
func GetGeminiModels() []*ModelInfo {
return cloneModelInfos(getModels().Gemini)
}
// GetGeminiVertexModels returns Gemini model definitions for Vertex AI.
func GetGeminiVertexModels() []*ModelInfo {
return cloneModelInfos(getModels().Vertex)
}
// GetGeminiCLIModels returns Gemini model definitions for the Gemini CLI.
func GetGeminiCLIModels() []*ModelInfo {
return cloneModelInfos(getModels().GeminiCLI)
}
// GetAIStudioModels returns model definitions for AI Studio.
func GetAIStudioModels() []*ModelInfo {
return cloneModelInfos(getModels().AIStudio)
}
// GetCodexFreeModels returns model definitions for the Codex free plan tier.
func GetCodexFreeModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexFree)
}
// GetCodexTeamModels returns model definitions for the Codex team plan tier.
func GetCodexTeamModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexTeam)
}
// GetCodexPlusModels returns model definitions for the Codex plus plan tier.
func GetCodexPlusModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexPlus)
}
// GetCodexProModels returns model definitions for the Codex pro plan tier.
func GetCodexProModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexPro)
}
// GetQwenModels returns the standard Qwen model definitions.
func GetQwenModels() []*ModelInfo {
return cloneModelInfos(getModels().Qwen)
}
// GetIFlowModels returns the standard iFlow model definitions.
func GetIFlowModels() []*ModelInfo {
return cloneModelInfos(getModels().IFlow)
}
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions.
func GetKimiModels() []*ModelInfo {
return cloneModelInfos(getModels().Kimi)
}
// GetAntigravityModels returns the standard Antigravity model definitions.
func GetAntigravityModels() []*ModelInfo {
return cloneModelInfos(getModels().Antigravity)
}
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
if len(models) == 0 {
return nil
}
out := make([]*ModelInfo, len(models))
for i, m := range models {
out[i] = cloneModelInfo(m)
}
return out
}
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
// It returns nil when the channel is unknown.
//
@@ -20,7 +113,6 @@ import (
// - qwen
// - iflow
// - kimi
// - kiro
// - kilo
// - github-copilot
// - amazonq
@@ -39,7 +131,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
case "aistudio":
return GetAIStudioModels()
case "codex":
return GetOpenAIModels()
return GetCodexProModels()
case "qwen":
return GetQwenModels()
case "iflow":
@@ -55,28 +147,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
case "amazonq":
return GetAmazonQModels()
case "antigravity":
cfg := GetAntigravityModelConfig()
if len(cfg) == 0 {
return nil
}
models := make([]*ModelInfo, 0, len(cfg))
for modelID, entry := range cfg {
if modelID == "" || entry == nil {
continue
}
models = append(models, &ModelInfo{
ID: modelID,
Object: "model",
OwnedBy: "antigravity",
Type: "antigravity",
Thinking: entry.Thinking,
MaxCompletionTokens: entry.MaxCompletionTokens,
})
}
sort.Slice(models, func(i, j int) bool {
return strings.ToLower(models[i].ID) < strings.ToLower(models[j].ID)
})
return models
return GetAntigravityModels()
default:
return nil
}
@@ -89,16 +160,18 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
return nil
}
data := getModels()
allModels := [][]*ModelInfo{
GetClaudeModels(),
GetGeminiModels(),
GetGeminiVertexModels(),
GetGeminiCLIModels(),
GetAIStudioModels(),
GetOpenAIModels(),
GetQwenModels(),
GetIFlowModels(),
GetKimiModels(),
data.Claude,
data.Gemini,
data.Vertex,
data.GeminiCLI,
data.AIStudio,
data.CodexPro,
data.Qwen,
data.IFlow,
data.Kimi,
data.Antigravity,
GetGitHubCopilotModels(),
GetKiroModels(),
GetKiloModels(),
@@ -107,20 +180,11 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
for _, models := range allModels {
for _, m := range models {
if m != nil && m.ID == modelID {
return m
return cloneModelInfo(m)
}
}
}
// Check Antigravity static config
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
return &ModelInfo{
ID: modelID,
Thinking: cfg.Thinking,
MaxCompletionTokens: cfg.MaxCompletionTokens,
}
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -189,6 +189,7 @@ func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
}
const defaultModelRegistryHookTimeout = 5 * time.Second
const modelQuotaExceededWindow = 5 * time.Minute
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
hook := r.hook
@@ -390,6 +391,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
reg.InfoByProvider[provider] = cloneModelInfo(model)
}
reg.LastUpdated = now
// Re-registering an existing client/model binding starts a fresh registry
// snapshot for that binding. Cooldown and suspension are transient
// scheduling state and must not survive this reconciliation step.
if reg.QuotaExceededClients != nil {
delete(reg.QuotaExceededClients, clientID)
}
@@ -783,7 +787,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
models := make([]map[string]any, 0, len(r.models))
quotaExpiredDuration := 5 * time.Minute
var expiresAt time.Time
for _, registration := range r.models {
@@ -794,7 +797,7 @@ func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.
if quotaTime == nil {
continue
}
recoveryAt := quotaTime.Add(quotaExpiredDuration)
recoveryAt := quotaTime.Add(modelQuotaExceededWindow)
if now.Before(recoveryAt) {
expiredClients++
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
@@ -929,7 +932,6 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
return nil
}
quotaExpiredDuration := 5 * time.Minute
now := time.Now()
result := make([]*ModelInfo, 0, len(providerModels))
@@ -951,7 +953,7 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
continue
}
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
expiredClients++
}
}
@@ -1005,12 +1007,11 @@ func (r *ModelRegistry) GetModelCount(modelID string) int {
if registration, exists := r.models[modelID]; exists {
now := time.Now()
quotaExpiredDuration := 5 * time.Minute
// Count clients that have exceeded quota but haven't recovered yet
expiredClients := 0
for _, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
expiredClients++
}
}
@@ -1236,12 +1237,11 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
defer r.mutex.Unlock()
now := time.Now()
quotaExpiredDuration := 5 * time.Minute
invalidated := false
for modelID, registration := range r.models {
for clientID, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
if quotaTime != nil && now.Sub(*quotaTime) >= modelQuotaExceededWindow {
delete(registration.QuotaExceededClients, clientID)
invalidated = true
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)

View File

@@ -0,0 +1,372 @@
package registry
import (
"context"
_ "embed"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
modelsFetchTimeout = 30 * time.Second
modelsRefreshInterval = 3 * time.Hour
)
var modelsURLs = []string{
"https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json",
"https://models.router-for.me/models.json",
}
//go:embed models/models.json
var embeddedModelsJSON []byte
type modelStore struct {
mu sync.RWMutex
data *staticModelsJSON
}
var modelsCatalogStore = &modelStore{}
var updaterOnce sync.Once
// ModelRefreshCallback is invoked when startup or periodic model refresh detects changes.
// changedProviders contains the provider names whose model definitions changed.
type ModelRefreshCallback func(changedProviders []string)
var (
refreshCallbackMu sync.Mutex
refreshCallback ModelRefreshCallback
pendingRefreshChanges []string
)
// SetModelRefreshCallback registers a callback that is invoked when startup or
// periodic model refresh detects changes. Only one callback is supported;
// subsequent calls replace the previous callback.
func SetModelRefreshCallback(cb ModelRefreshCallback) {
refreshCallbackMu.Lock()
refreshCallback = cb
var pending []string
if cb != nil && len(pendingRefreshChanges) > 0 {
pending = append([]string(nil), pendingRefreshChanges...)
pendingRefreshChanges = nil
}
refreshCallbackMu.Unlock()
if cb != nil && len(pending) > 0 {
cb(pending)
}
}
func init() {
// Load embedded data as fallback on startup.
if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil {
panic(fmt.Sprintf("registry: failed to parse embedded models.json: %v", err))
}
}
// StartModelsUpdater starts a background updater that fetches models
// immediately on startup and then refreshes the model catalog every 3 hours.
// Safe to call multiple times; only one updater will run.
func StartModelsUpdater(ctx context.Context) {
updaterOnce.Do(func() {
go runModelsUpdater(ctx)
})
}
func runModelsUpdater(ctx context.Context) {
tryStartupRefresh(ctx)
periodicRefresh(ctx)
}
func periodicRefresh(ctx context.Context) {
ticker := time.NewTicker(modelsRefreshInterval)
defer ticker.Stop()
log.Infof("periodic model refresh started (interval=%s)", modelsRefreshInterval)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
tryPeriodicRefresh(ctx)
}
}
}
// tryPeriodicRefresh fetches models from remote, compares with the current
// catalog, and notifies the registered callback if any provider changed.
func tryPeriodicRefresh(ctx context.Context) {
tryRefreshModels(ctx, "periodic model refresh")
}
// tryStartupRefresh fetches models from remote in the background during
// process startup. It uses the same change detection as periodic refresh so
// existing auth registrations can be updated after the callback is registered.
func tryStartupRefresh(ctx context.Context) {
tryRefreshModels(ctx, "startup model refresh")
}
func tryRefreshModels(ctx context.Context, label string) {
oldData := getModels()
parsed, url := fetchModelsFromRemote(ctx)
if parsed == nil {
log.Warnf("%s: fetch failed from all URLs, keeping current data", label)
return
}
// Detect changes before updating store.
changed := detectChangedProviders(oldData, parsed)
// Update store with new data regardless.
modelsCatalogStore.mu.Lock()
modelsCatalogStore.data = parsed
modelsCatalogStore.mu.Unlock()
if len(changed) == 0 {
log.Infof("%s completed from %s, no changes detected", label, url)
return
}
log.Infof("%s completed from %s, changes detected for providers: %v", label, url, changed)
notifyModelRefresh(changed)
}
// fetchModelsFromRemote tries all remote URLs and returns the parsed model catalog
// along with the URL it was fetched from. Returns (nil, "") if all fetches fail.
func fetchModelsFromRemote(ctx context.Context) (*staticModelsJSON, string) {
client := &http.Client{Timeout: modelsFetchTimeout}
for _, url := range modelsURLs {
reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
req, err := http.NewRequestWithContext(reqCtx, "GET", url, nil)
if err != nil {
cancel()
log.Debugf("models fetch request creation failed for %s: %v", url, err)
continue
}
resp, err := client.Do(req)
if err != nil {
cancel()
log.Debugf("models fetch failed from %s: %v", url, err)
continue
}
if resp.StatusCode != 200 {
resp.Body.Close()
cancel()
log.Debugf("models fetch returned %d from %s", resp.StatusCode, url)
continue
}
data, err := io.ReadAll(resp.Body)
resp.Body.Close()
cancel()
if err != nil {
log.Debugf("models fetch read error from %s: %v", url, err)
continue
}
var parsed staticModelsJSON
if err := json.Unmarshal(data, &parsed); err != nil {
log.Warnf("models parse failed from %s: %v", url, err)
continue
}
if err := validateModelsCatalog(&parsed); err != nil {
log.Warnf("models validate failed from %s: %v", url, err)
continue
}
return &parsed, url
}
return nil, ""
}
// detectChangedProviders compares two model catalogs and returns provider names
// whose model definitions differ. Codex tiers (free/team/plus/pro) are grouped
// under a single "codex" provider.
func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
if oldData == nil || newData == nil {
return nil
}
type section struct {
provider string
oldList []*ModelInfo
newList []*ModelInfo
}
sections := []section{
{"claude", oldData.Claude, newData.Claude},
{"gemini", oldData.Gemini, newData.Gemini},
{"vertex", oldData.Vertex, newData.Vertex},
{"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI},
{"aistudio", oldData.AIStudio, newData.AIStudio},
{"codex", oldData.CodexFree, newData.CodexFree},
{"codex", oldData.CodexTeam, newData.CodexTeam},
{"codex", oldData.CodexPlus, newData.CodexPlus},
{"codex", oldData.CodexPro, newData.CodexPro},
{"qwen", oldData.Qwen, newData.Qwen},
{"iflow", oldData.IFlow, newData.IFlow},
{"kimi", oldData.Kimi, newData.Kimi},
{"antigravity", oldData.Antigravity, newData.Antigravity},
}
seen := make(map[string]bool, len(sections))
var changed []string
for _, s := range sections {
if seen[s.provider] {
continue
}
if modelSectionChanged(s.oldList, s.newList) {
changed = append(changed, s.provider)
seen[s.provider] = true
}
}
return changed
}
// modelSectionChanged reports whether two model slices differ.
func modelSectionChanged(a, b []*ModelInfo) bool {
if len(a) != len(b) {
return true
}
if len(a) == 0 {
return false
}
aj, err1 := json.Marshal(a)
bj, err2 := json.Marshal(b)
if err1 != nil || err2 != nil {
return true
}
return string(aj) != string(bj)
}
func notifyModelRefresh(changedProviders []string) {
if len(changedProviders) == 0 {
return
}
refreshCallbackMu.Lock()
cb := refreshCallback
if cb == nil {
pendingRefreshChanges = mergeProviderNames(pendingRefreshChanges, changedProviders)
refreshCallbackMu.Unlock()
return
}
refreshCallbackMu.Unlock()
cb(changedProviders)
}
func mergeProviderNames(existing, incoming []string) []string {
if len(incoming) == 0 {
return existing
}
seen := make(map[string]struct{}, len(existing)+len(incoming))
merged := make([]string, 0, len(existing)+len(incoming))
for _, provider := range existing {
name := strings.ToLower(strings.TrimSpace(provider))
if name == "" {
continue
}
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
merged = append(merged, name)
}
for _, provider := range incoming {
name := strings.ToLower(strings.TrimSpace(provider))
if name == "" {
continue
}
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
merged = append(merged, name)
}
return merged
}
func loadModelsFromBytes(data []byte, source string) error {
var parsed staticModelsJSON
if err := json.Unmarshal(data, &parsed); err != nil {
return fmt.Errorf("%s: decode models catalog: %w", source, err)
}
if err := validateModelsCatalog(&parsed); err != nil {
return fmt.Errorf("%s: validate models catalog: %w", source, err)
}
modelsCatalogStore.mu.Lock()
modelsCatalogStore.data = &parsed
modelsCatalogStore.mu.Unlock()
return nil
}
func getModels() *staticModelsJSON {
modelsCatalogStore.mu.RLock()
defer modelsCatalogStore.mu.RUnlock()
return modelsCatalogStore.data
}
func validateModelsCatalog(data *staticModelsJSON) error {
if data == nil {
return fmt.Errorf("catalog is nil")
}
requiredSections := []struct {
name string
models []*ModelInfo
}{
{name: "claude", models: data.Claude},
{name: "gemini", models: data.Gemini},
{name: "vertex", models: data.Vertex},
{name: "gemini-cli", models: data.GeminiCLI},
{name: "aistudio", models: data.AIStudio},
{name: "codex-free", models: data.CodexFree},
{name: "codex-team", models: data.CodexTeam},
{name: "codex-plus", models: data.CodexPlus},
{name: "codex-pro", models: data.CodexPro},
{name: "qwen", models: data.Qwen},
{name: "iflow", models: data.IFlow},
{name: "kimi", models: data.Kimi},
{name: "antigravity", models: data.Antigravity},
}
for _, section := range requiredSections {
if err := validateModelSection(section.name, section.models); err != nil {
return err
}
}
return nil
}
func validateModelSection(section string, models []*ModelInfo) error {
if len(models) == 0 {
return fmt.Errorf("%s section is empty", section)
}
seen := make(map[string]struct{}, len(models))
for i, model := range models {
if model == nil {
return fmt.Errorf("%s[%d] is null", section, i)
}
modelID := strings.TrimSpace(model.ID)
if modelID == "" {
return fmt.Errorf("%s[%d] has empty id", section, i)
}
if _, exists := seen[modelID]; exists {
return fmt.Errorf("%s contains duplicate model id %q", section, modelID)
}
seen[modelID] = struct{}{}
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -164,7 +164,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, &param)
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON(out), Headers: wsResp.Headers.Clone()}
return resp, nil
}
@@ -280,7 +280,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
}
break
}
@@ -296,7 +296,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
}
reporter.publish(ctx, parseGeminiUsage(event.Payload))
return false
@@ -373,7 +373,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
}
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, resp.Body)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: translated}, nil
}
// Refresh refreshes the authentication credentials (no-op for AI Studio).

View File

@@ -24,7 +24,6 @@ import (
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
@@ -43,7 +42,6 @@ const (
antigravityCountTokensPath = "/v1internal:countTokens"
antigravityStreamPath = "/v1internal:streamGenerateContent"
antigravityGeneratePath = "/v1internal:generateContent"
antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
@@ -55,78 +53,8 @@ const (
var (
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
// from any antigravity auth. Empty fetches never overwrite this cache.
antigravityPrimaryModelsCache struct {
mu sync.RWMutex
models []*registry.ModelInfo
}
)
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
if len(models) == 0 {
return nil
}
out := make([]*registry.ModelInfo, 0, len(models))
for _, model := range models {
if model == nil || strings.TrimSpace(model.ID) == "" {
continue
}
out = append(out, cloneAntigravityModelInfo(model))
}
if len(out) == 0 {
return nil
}
return out
}
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
if model == nil {
return nil
}
clone := *model
if len(model.SupportedGenerationMethods) > 0 {
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
}
if len(model.SupportedParameters) > 0 {
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
}
if model.Thinking != nil {
thinkingClone := *model.Thinking
if len(model.Thinking.Levels) > 0 {
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
}
clone.Thinking = &thinkingClone
}
return &clone
}
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
cloned := cloneAntigravityModels(models)
if len(cloned) == 0 {
return false
}
antigravityPrimaryModelsCache.mu.Lock()
antigravityPrimaryModelsCache.models = cloned
antigravityPrimaryModelsCache.mu.Unlock()
return true
}
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
antigravityPrimaryModelsCache.mu.RLock()
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
antigravityPrimaryModelsCache.mu.RUnlock()
return cloned
}
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
models := loadAntigravityPrimaryModels()
if len(models) > 0 {
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
}
return models
}
// AntigravityExecutor proxies requests to the antigravity upstream.
type AntigravityExecutor struct {
cfg *config.Config
@@ -380,7 +308,7 @@ attemptLoop:
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
}
@@ -584,7 +512,7 @@ attemptLoop:
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
@@ -763,31 +691,42 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
}
partsJSON, _ := json.Marshal(parts)
responseTemplate, _ = sjson.SetRaw(responseTemplate, "candidates.0.content.parts", string(partsJSON))
updatedTemplate, _ := sjson.SetRawBytes([]byte(responseTemplate), "candidates.0.content.parts", partsJSON)
responseTemplate = string(updatedTemplate)
if role != "" {
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.content.role", role)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "candidates.0.content.role", role)
responseTemplate = string(updatedTemplate)
}
if finishReason != "" {
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.finishReason", finishReason)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "candidates.0.finishReason", finishReason)
responseTemplate = string(updatedTemplate)
}
if modelVersion != "" {
responseTemplate, _ = sjson.Set(responseTemplate, "modelVersion", modelVersion)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "modelVersion", modelVersion)
responseTemplate = string(updatedTemplate)
}
if responseID != "" {
responseTemplate, _ = sjson.Set(responseTemplate, "responseId", responseID)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "responseId", responseID)
responseTemplate = string(updatedTemplate)
}
if usageRaw != "" {
responseTemplate, _ = sjson.SetRaw(responseTemplate, "usageMetadata", usageRaw)
updatedTemplate, _ = sjson.SetRawBytes([]byte(responseTemplate), "usageMetadata", []byte(usageRaw))
responseTemplate = string(updatedTemplate)
} else if !gjson.Get(responseTemplate, "usageMetadata").Exists() {
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.promptTokenCount", 0)
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.candidatesTokenCount", 0)
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.totalTokenCount", 0)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.promptTokenCount", 0)
responseTemplate = string(updatedTemplate)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.candidatesTokenCount", 0)
responseTemplate = string(updatedTemplate)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.totalTokenCount", 0)
responseTemplate = string(updatedTemplate)
}
output := `{"response":{},"traceId":""}`
output, _ = sjson.SetRaw(output, "response", responseTemplate)
updatedOutput, _ := sjson.SetRawBytes([]byte(output), "response", []byte(responseTemplate))
output = string(updatedOutput)
if traceID != "" {
output, _ = sjson.Set(output, "traceId", traceID)
updatedOutput, _ = sjson.SetBytes([]byte(output), "traceId", traceID)
output = string(updatedOutput)
}
return []byte(output)
}
@@ -952,12 +891,12 @@ attemptLoop:
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), &param)
for i := range tail {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
@@ -1115,7 +1054,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
return cliproxyexecutor.Response{Payload: translated, Headers: httpResp.Header.Clone()}, nil
}
lastStatus = httpResp.StatusCode
@@ -1150,168 +1089,6 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
}
}
// FetchAntigravityModels retrieves available models using the supplied auth.
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
exec := &AntigravityExecutor{cfg: cfg}
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
if errToken != nil || token == "" {
return fallbackAntigravityPrimaryModels()
}
if updatedAuth != nil {
auth = updatedAuth
}
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, cfg, auth, 0)
for idx, baseURL := range baseURLs {
modelsURL := baseURL + antigravityModelsPath
var payload []byte
if auth != nil && auth.Metadata != nil {
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
}
}
if len(payload) == 0 {
payload = []byte(`{}`)
}
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload))
if errReq != nil {
return fallbackAntigravityPrimaryModels()
}
httpReq.Close = true
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+token)
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
if host := resolveHost(baseURL); host != "" {
httpReq.Host = host
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return fallbackAntigravityPrimaryModels()
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
now := time.Now().Unix()
modelConfig := registry.GetAntigravityModelConfig()
models := make([]*registry.ModelInfo, 0, len(result.Map()))
for originalName, modelData := range result.Map() {
modelID := strings.TrimSpace(originalName)
if modelID == "" {
continue
}
switch modelID {
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
continue
}
modelCfg := modelConfig[modelID]
// Extract displayName from upstream response, fallback to modelID
displayName := modelData.Get("displayName").String()
if displayName == "" {
displayName = modelID
}
modelInfo := &registry.ModelInfo{
ID: modelID,
Name: modelID,
Description: displayName,
DisplayName: displayName,
Version: modelID,
Object: "model",
Created: now,
OwnedBy: antigravityAuthType,
Type: antigravityAuthType,
}
// Build input modalities from upstream capability flags.
inputModalities := []string{"TEXT"}
if modelData.Get("supportsImages").Bool() {
inputModalities = append(inputModalities, "IMAGE")
}
if modelData.Get("supportsVideo").Bool() {
inputModalities = append(inputModalities, "VIDEO")
}
modelInfo.SupportedInputModalities = inputModalities
modelInfo.SupportedOutputModalities = []string{"TEXT"}
// Token limits from upstream.
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
modelInfo.InputTokenLimit = int(maxTok)
}
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
modelInfo.OutputTokenLimit = int(maxOut)
}
// Supported generation methods (Gemini v1beta convention).
modelInfo.SupportedGenerationMethods = []string{"generateContent", "countTokens"}
// Look up Thinking support from static config using upstream model name.
if modelCfg != nil {
if modelCfg.Thinking != nil {
modelInfo.Thinking = modelCfg.Thinking
}
if modelCfg.MaxCompletionTokens > 0 {
modelInfo.MaxCompletionTokens = modelCfg.MaxCompletionTokens
}
}
models = append(models, modelInfo)
}
if len(models) == 0 {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
return fallbackAntigravityPrimaryModels()
}
storeAntigravityPrimaryModels(models)
return models
}
return fallbackAntigravityPrimaryModels()
}
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
if auth == nil {
return "", nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
@@ -1499,19 +1276,20 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
// if useAntigravitySchema {
// systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.role", "user")
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.parts.0.text", systemInstruction)
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
// if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
// for _, partResult := range systemInstructionPartsResult.Array() {
// payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
// payloadStr, _ = sjson.SetRawBytes([]byte(payloadStr), "request.systemInstruction.parts.-1", []byte(partResult.Raw))
// }
// }
// }
if strings.Contains(modelName, "claude") {
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
updated, _ := sjson.SetBytes([]byte(payloadStr), "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
payloadStr = string(updated)
} else {
payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens")
}
@@ -1733,8 +1511,9 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
}
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
template, _ := sjson.Set(string(payload), "model", modelName)
template, _ = sjson.Set(template, "userAgent", "antigravity")
template := payload
template, _ = sjson.SetBytes(template, "model", modelName)
template, _ = sjson.SetBytes(template, "userAgent", "antigravity")
isImageModel := strings.Contains(modelName, "image")
@@ -1744,28 +1523,28 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
} else {
reqType = "agent"
}
template, _ = sjson.Set(template, "requestType", reqType)
template, _ = sjson.SetBytes(template, "requestType", reqType)
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
if projectID != "" {
template, _ = sjson.Set(template, "project", projectID)
template, _ = sjson.SetBytes(template, "project", projectID)
} else {
template, _ = sjson.Set(template, "project", generateProjectID())
template, _ = sjson.SetBytes(template, "project", generateProjectID())
}
if isImageModel {
template, _ = sjson.Set(template, "requestId", generateImageGenRequestID())
template, _ = sjson.SetBytes(template, "requestId", generateImageGenRequestID())
} else {
template, _ = sjson.Set(template, "requestId", generateRequestID())
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
template, _ = sjson.SetBytes(template, "requestId", generateRequestID())
template, _ = sjson.SetBytes(template, "request.sessionId", generateStableSessionID(payload))
}
template, _ = sjson.Delete(template, "request.safetySettings")
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() {
template, _ = sjson.SetRaw(template, "request.toolConfig", toolConfig.Raw)
template, _ = sjson.Delete(template, "toolConfig")
template, _ = sjson.DeleteBytes(template, "request.safetySettings")
if toolConfig := gjson.GetBytes(template, "toolConfig"); toolConfig.Exists() && !gjson.GetBytes(template, "request.toolConfig").Exists() {
template, _ = sjson.SetRawBytes(template, "request.toolConfig", []byte(toolConfig.Raw))
template, _ = sjson.DeleteBytes(template, "toolConfig")
}
return []byte(template)
return template
}
func generateRequestID() string {

View File

@@ -1,90 +0,0 @@
package executor
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
func resetAntigravityPrimaryModelsCacheForTest() {
antigravityPrimaryModelsCache.mu.Lock()
antigravityPrimaryModelsCache.models = nil
antigravityPrimaryModelsCache.mu.Unlock()
}
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
resetAntigravityPrimaryModelsCacheForTest()
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
seed := []*registry.ModelInfo{
{ID: "claude-sonnet-4-5"},
{ID: "gemini-2.5-pro"},
}
if updated := storeAntigravityPrimaryModels(seed); !updated {
t.Fatal("expected non-empty model list to update primary cache")
}
if updated := storeAntigravityPrimaryModels(nil); updated {
t.Fatal("expected nil model list not to overwrite primary cache")
}
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
t.Fatal("expected empty model list not to overwrite primary cache")
}
got := loadAntigravityPrimaryModels()
if len(got) != 2 {
t.Fatalf("expected cached model count 2, got %d", len(got))
}
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
}
}
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
resetAntigravityPrimaryModelsCacheForTest()
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
ID: "gpt-5",
DisplayName: "GPT-5",
SupportedGenerationMethods: []string{"generateContent"},
SupportedParameters: []string{"temperature"},
Thinking: &registry.ThinkingSupport{
Levels: []string{"high"},
},
}}); !updated {
t.Fatal("expected model cache update")
}
got := loadAntigravityPrimaryModels()
if len(got) != 1 {
t.Fatalf("expected one cached model, got %d", len(got))
}
got[0].ID = "mutated-id"
if len(got[0].SupportedGenerationMethods) > 0 {
got[0].SupportedGenerationMethods[0] = "mutated-method"
}
if len(got[0].SupportedParameters) > 0 {
got[0].SupportedParameters[0] = "mutated-parameter"
}
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
got[0].Thinking.Levels[0] = "mutated-level"
}
again := loadAntigravityPrimaryModels()
if len(again) != 1 {
t.Fatalf("expected one cached model after mutation, got %d", len(again))
}
if again[0].ID != "gpt-5" {
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
}
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
}
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
}
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
}
}

View File

@@ -255,7 +255,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
data,
&param,
)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -443,7 +443,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
&param,
)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
if errScan := scanner.Err(); errScan != nil {
@@ -561,7 +561,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "input_tokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
return cliproxyexecutor.Response{Payload: out, Headers: resp.Header.Clone()}, nil
}
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
@@ -1260,12 +1260,18 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
// TTL ordering violations with the prompt-caching-scope-2026-01-05 beta.
partJSON := part.Raw
if !part.Get("cache_control").Exists() {
partJSON, _ = sjson.Set(partJSON, "cache_control.type", "ephemeral")
updated, _ := sjson.SetBytes([]byte(partJSON), "cache_control.type", "ephemeral")
partJSON = string(updated)
}
result += "," + partJSON
}
return true
})
} else if system.Type == gjson.String && system.String() != "" {
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
updated, _ := sjson.SetBytes([]byte(partJSON), "text", system.String())
partJSON = string(updated)
result += "," + partJSON
}
result += "]"

View File

@@ -842,8 +842,8 @@ func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity
executor := NewClaudeExecutor(&config.Config{})
// Inject Accept-Encoding via the custom header attribute mechanism.
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
"api_key": "key-123",
"base_url": server.URL,
"header:Accept-Encoding": "gzip, deflate, br, zstd",
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
@@ -980,3 +980,87 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
}
}
// Test case 1: String system prompt is preserved and converted to a content block
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
system := gjson.GetBytes(out, "system")
if !system.IsArray() {
t.Fatalf("system should be an array, got %s", system.Type)
}
blocks := system.Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if !strings.HasPrefix(blocks[0].Get("text").String(), "x-anthropic-billing-header:") {
t.Fatalf("blocks[0] should be billing header, got %q", blocks[0].Get("text").String())
}
if blocks[1].Get("text").String() != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
t.Fatalf("blocks[1] should be agent block, got %q", blocks[1].Get("text").String())
}
if blocks[2].Get("text").String() != "You are a helpful assistant." {
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
}
if blocks[2].Get("cache_control.type").String() != "ephemeral" {
t.Fatalf("blocks[2] should have cache_control.type=ephemeral")
}
}
// Test case 2: Strict mode drops the string system prompt
func TestCheckSystemInstructionsWithMode_StringSystemStrict(t *testing.T) {
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, true)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 2 {
t.Fatalf("strict mode should produce 2 blocks, got %d", len(blocks))
}
}
// Test case 3: Empty string system prompt does not produce a spurious block
func TestCheckSystemInstructionsWithMode_EmptyStringSystemIgnored(t *testing.T) {
payload := []byte(`{"system":"","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 2 {
t.Fatalf("empty string system should produce 2 blocks, got %d", len(blocks))
}
}
// Test case 4: Array system prompt is unaffected by the string handling
func TestCheckSystemInstructionsWithMode_ArraySystemStillWorks(t *testing.T) {
payload := []byte(`{"system":[{"type":"text","text":"Be concise."}],"messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if blocks[2].Get("text").String() != "Be concise." {
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
}
}
// Test case 5: Special characters in string system prompt survive conversion
func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
payload := []byte(`{"system":"Use <xml> tags & \"quotes\" in output.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if blocks[2].Get("text").String() != `Use <xml> tags & "quotes" in output.` {
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
}
}

View File

@@ -122,7 +122,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
if err != nil {
return resp, err
}
applyCodexHeaders(httpReq, auth, apiKey, true)
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
@@ -226,7 +226,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
if err != nil {
return resp, err
}
applyCodexHeaders(httpReq, auth, apiKey, false)
applyCodexHeaders(httpReq, auth, apiKey, false, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -273,7 +273,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -321,7 +321,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
if err != nil {
return nil, err
}
applyCodexHeaders(httpReq, auth, apiKey, true)
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -387,7 +387,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
if errScan := scanner.Err(); errScan != nil {
@@ -432,7 +432,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, []byte(usageJSON))
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: translated}, nil
}
func tokenizerForCodexModel(model string) (tokenizer.Codec, error) {
@@ -636,7 +636,7 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
return httpReq, nil
}
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool) {
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, cfg *config.Config) {
r.Header.Set("Content-Type", "application/json")
r.Header.Set("Authorization", "Bearer "+token)
@@ -647,7 +647,8 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
if stream {
r.Header.Set("Accept", "text/event-stream")

View File

@@ -23,6 +23,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -190,7 +191,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
}
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
@@ -342,7 +343,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
}
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: out}
return resp, nil
}
}
@@ -385,7 +386,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
}
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
var authID, authLabel, authType, authValue string
authID = auth.ID
@@ -591,7 +592,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
line := encodeCodexWebsocketAsSSE(payload)
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, body, body, line, &param)
for i := range chunks {
if !send(cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}) {
if !send(cliproxyexecutor.StreamChunk{Payload: chunks[i]}) {
terminateReason = "context_done"
terminateErr = ctx.Err()
return
@@ -705,21 +706,30 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
return dialer
}
parsedURL, errParse := url.Parse(proxyURL)
setting, errParse := proxyutil.Parse(proxyURL)
if errParse != nil {
log.Errorf("codex websockets executor: parse proxy URL failed: %v", errParse)
log.Errorf("codex websockets executor: %v", errParse)
return dialer
}
switch parsedURL.Scheme {
switch setting.Mode {
case proxyutil.ModeDirect:
dialer.Proxy = nil
return dialer
case proxyutil.ModeProxy:
default:
return dialer
}
switch setting.URL.Scheme {
case "socks5":
var proxyAuth *proxy.Auth
if parsedURL.User != nil {
username := parsedURL.User.Username()
password, _ := parsedURL.User.Password()
if setting.URL.User != nil {
username := setting.URL.User.Username()
password, _ := setting.URL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5)
return dialer
@@ -729,9 +739,9 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
return socksDialer.Dial(network, addr)
}
case "http", "https":
dialer.Proxy = http.ProxyURL(parsedURL)
dialer.Proxy = http.ProxyURL(setting.URL)
default:
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", parsedURL.Scheme)
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", setting.URL.Scheme)
}
return dialer
@@ -787,7 +797,7 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
return rawJSON, headers
}
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string) http.Header {
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header {
if headers == nil {
headers = http.Header{}
}
@@ -800,7 +810,8 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
ginHeaders = ginCtx.Request.Header
}
misc.EnsureHeader(headers, ginHeaders, "x-codex-beta-features", "")
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
@@ -815,7 +826,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
}
headers.Set("OpenAI-Beta", betaHeader)
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(headers, ginHeaders, "User-Agent", codexUserAgent)
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
isAPIKey := false
if auth != nil && auth.Attributes != nil {
@@ -843,6 +854,62 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
return headers
}
func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) {
if cfg == nil || auth == nil {
return "", ""
}
if auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
return "", ""
}
}
return strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent), strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
}
func ensureHeaderWithPriority(target http.Header, source http.Header, key, configValue, fallbackValue string) {
if target == nil {
return
}
if strings.TrimSpace(target.Get(key)) != "" {
return
}
if source != nil {
if val := strings.TrimSpace(source.Get(key)); val != "" {
target.Set(key, val)
return
}
}
if val := strings.TrimSpace(configValue); val != "" {
target.Set(key, val)
return
}
if val := strings.TrimSpace(fallbackValue); val != "" {
target.Set(key, val)
}
}
func ensureHeaderWithConfigPrecedence(target http.Header, source http.Header, key, configValue, fallbackValue string) {
if target == nil {
return
}
if strings.TrimSpace(target.Get(key)) != "" {
return
}
if val := strings.TrimSpace(configValue); val != "" {
target.Set(key, val)
return
}
if source != nil {
if val := strings.TrimSpace(source.Get(key)); val != "" {
target.Set(key, val)
return
}
}
if val := strings.TrimSpace(fallbackValue); val != "" {
target.Set(key, val)
}
}
type statusErrWithHeaders struct {
statusErr
headers http.Header

View File

@@ -3,8 +3,13 @@ package executor
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/tidwall/gjson"
)
@@ -28,9 +33,171 @@ func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T)
}
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "")
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
}
if got := headers.Get("User-Agent"); got != codexUserAgent {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
}
if got := headers.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)
}
}
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "my-codex-client/1.0",
BetaFeatures: "feature-a,feature-b",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
}
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
}
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
}
}
func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *testing.T) {
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "config-ua",
BetaFeatures: "config-beta",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
ctx := contextWithGinHeaders(map[string]string{
"User-Agent": "client-ua",
"X-Codex-Beta-Features": "client-beta",
})
headers := http.Header{}
headers.Set("User-Agent", "existing-ua")
headers.Set("X-Codex-Beta-Features", "existing-beta")
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
}
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
}
}
func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testing.T) {
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "config-ua",
BetaFeatures: "config-beta",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
ctx := contextWithGinHeaders(map[string]string{
"User-Agent": "client-ua",
"X-Codex-Beta-Features": "client-beta",
})
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "config-ua" {
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
}
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
}
}
func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "config-ua",
BetaFeatures: "config-beta",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Attributes: map[string]string{"api_key": "sk-test"},
}
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg)
if got := headers.Get("User-Agent"); got != codexUserAgent {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
}
if got := headers.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)
}
}
func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
if err != nil {
t.Fatalf("NewRequest() error = %v", err)
}
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "config-ua",
BetaFeatures: "config-beta",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
req = req.WithContext(contextWithGinHeaders(map[string]string{
"User-Agent": "client-ua",
}))
applyCodexHeaders(req, auth, "oauth-token", true, cfg)
if got := req.Header.Get("User-Agent"); got != "config-ua" {
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
}
if got := req.Header.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)
}
}
func contextWithGinHeaders(headers map[string]string) context.Context {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)
ginCtx.Request.Header = make(http.Header, len(headers))
for key, value := range headers {
ginCtx.Request.Header.Set(key, value)
}
return context.WithValue(context.Background(), "gin", ginCtx)
}
func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) {
t.Parallel()
dialer := newProxyAwareWebsocketDialer(
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
&cliproxyauth.Auth{ProxyURL: "direct"},
)
if dialer.Proxy != nil {
t.Fatal("expected websocket proxy function to be nil for direct mode")
}
}

View File

@@ -224,7 +224,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
reporter.publish(ctx, parseGeminiCLIUsage(data))
var param any
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -401,14 +401,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
}
}
}
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
@@ -430,12 +430,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
}
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
}
}(httpResp, append([]byte(nil), payload...), attemptModel)
@@ -544,7 +544,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil
}
lastStatus = resp.StatusCode
lastBody = append([]byte(nil), data...)
@@ -811,18 +811,18 @@ func fixGeminiCLIImageAspectRatio(modelName string, rawJSON []byte) []byte {
if !hasInlineData {
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
emptyImagePart := `{"inlineData":{"mime_type":"image/png","data":""}}`
emptyImagePart, _ = sjson.Set(emptyImagePart, "inlineData.data", emptyImageBase64ed)
newPartsJson := `[]`
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", emptyImagePart)
emptyImagePart := []byte(`{"inlineData":{"mime_type":"image/png","data":""}}`)
emptyImagePart, _ = sjson.SetBytes(emptyImagePart, "inlineData.data", emptyImageBase64ed)
newPartsJson := []byte(`[]`)
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(`{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`))
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", emptyImagePart)
parts := contentArray[0].Get("parts").Array()
for j := 0; j < len(parts); j++ {
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", parts[j].Raw)
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(parts[j].Raw))
}
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", []byte(newPartsJson))
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", newPartsJson)
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
}
}

View File

@@ -205,7 +205,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -321,12 +321,12 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
@@ -415,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil
}
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
@@ -527,18 +527,18 @@ func fixGeminiImageAspectRatio(modelName string, rawJSON []byte) []byte {
if !hasInlineData {
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
emptyImagePart := `{"inlineData":{"mime_type":"image/png","data":""}}`
emptyImagePart, _ = sjson.Set(emptyImagePart, "inlineData.data", emptyImageBase64ed)
newPartsJson := `[]`
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", emptyImagePart)
emptyImagePart := []byte(`{"inlineData":{"mime_type":"image/png","data":""}}`)
emptyImagePart, _ = sjson.SetBytes(emptyImagePart, "inlineData.data", emptyImageBase64ed)
newPartsJson := []byte(`[]`)
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(`{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`))
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", emptyImagePart)
parts := contentArray[0].Get("parts").Array()
for j := 0; j < len(parts); j++ {
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", parts[j].Raw)
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(parts[j].Raw))
}
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", []byte(newPartsJson))
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", newPartsJson)
rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
}
}

View File

@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
to := sdktranslator.FromString("gemini")
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -524,7 +524,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -636,12 +636,12 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
@@ -760,12 +760,12 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
@@ -857,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
}
// countTokensWithAPIKey handles token counting using API key credentials.
@@ -941,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
}
// vertexCreds extracts project, location and raw service account JSON from auth metadata.

View File

@@ -221,13 +221,13 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
}
var param any
converted := ""
var converted []byte
if useResponses && from.String() == "claude" {
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
} else {
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
}
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
resp = cliproxyexecutor.Response{Payload: converted}
reporter.ensurePublished(ctx)
return resp, nil
}
@@ -374,14 +374,14 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
}
}
var chunks []string
var chunks [][]byte
if useResponses && from.String() == "claude" {
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), &param)
} else {
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
}
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(chunks[i])}
}
}
@@ -653,6 +653,7 @@ func normalizeGitHubCopilotChatTools(body []byte) []byte {
}
func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
body = stripGitHubCopilotResponsesUnsupportedFields(body)
input := gjson.GetBytes(body, "input")
if input.Exists() {
// If input is already a string or array, keep it as-is.
@@ -825,6 +826,12 @@ func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
return body
}
func stripGitHubCopilotResponsesUnsupportedFields(body []byte) []byte {
// GitHub Copilot /responses rejects service_tier, so always remove it.
body, _ = sjson.DeleteBytes(body, "service_tier")
return body
}
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() {
@@ -970,7 +977,7 @@ type githubCopilotResponsesStreamState struct {
ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
}
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) []byte {
root := gjson.ParseBytes(data)
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("id").String())
@@ -1060,10 +1067,10 @@ func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
} else {
out, _ = sjson.Set(out, "stop_reason", "end_turn")
}
return out
return []byte(out)
}
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string {
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) [][]byte {
if *param == nil {
*param = &githubCopilotResponsesStreamState{
TextBlockIndex: -1,
@@ -1085,7 +1092,10 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
}
event := gjson.GetBytes(payload, "type").String()
results := make([]string, 0, 4)
results := make([][]byte, 0, 4)
appendResult := func(chunk string) {
results = append(results, []byte(chunk))
}
ensureMessageStart := func() {
if state.MessageStarted {
return
@@ -1093,7 +1103,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String())
messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String())
results = append(results, "event: message_start\ndata: "+messageStart+"\n\n")
appendResult("event: message_start\ndata: " + messageStart + "\n\n")
state.MessageStarted = true
}
startTextBlockIfNeeded := func() {
@@ -1106,7 +1116,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
}
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex)
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
appendResult("event: content_block_start\ndata: " + contentBlockStart + "\n\n")
state.TextBlockStarted = true
}
stopTextBlockIfNeeded := func() {
@@ -1115,7 +1125,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
}
contentBlockStop := `{"type":"content_block_stop","index":0}`
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
appendResult("event: content_block_stop\ndata: " + contentBlockStop + "\n\n")
state.TextBlockStarted = false
state.TextBlockIndex = -1
}
@@ -1145,7 +1155,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex)
contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n")
appendResult("event: content_block_delta\ndata: " + contentDelta + "\n\n")
}
case "response.reasoning_summary_part.added":
ensureMessageStart()
@@ -1154,7 +1164,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
state.NextContentIndex++
thinkingStart := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
thinkingStart, _ = sjson.Set(thinkingStart, "index", state.ReasoningIndex)
results = append(results, "event: content_block_start\ndata: "+thinkingStart+"\n\n")
appendResult("event: content_block_start\ndata: " + thinkingStart + "\n\n")
case "response.reasoning_summary_text.delta":
if state.ReasoningActive {
delta := gjson.GetBytes(payload, "delta").String()
@@ -1162,14 +1172,14 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
thinkingDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
thinkingDelta, _ = sjson.Set(thinkingDelta, "index", state.ReasoningIndex)
thinkingDelta, _ = sjson.Set(thinkingDelta, "delta.thinking", delta)
results = append(results, "event: content_block_delta\ndata: "+thinkingDelta+"\n\n")
appendResult("event: content_block_delta\ndata: " + thinkingDelta + "\n\n")
}
}
case "response.reasoning_summary_part.done":
if state.ReasoningActive {
thinkingStop := `{"type":"content_block_stop","index":0}`
thinkingStop, _ = sjson.Set(thinkingStop, "index", state.ReasoningIndex)
results = append(results, "event: content_block_stop\ndata: "+thinkingStop+"\n\n")
appendResult("event: content_block_stop\ndata: " + thinkingStop + "\n\n")
state.ReasoningActive = false
}
case "response.output_item.added":
@@ -1197,7 +1207,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index)
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID)
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name)
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
appendResult("event: content_block_start\ndata: " + contentBlockStart + "\n\n")
case "response.output_item.delta":
item := gjson.GetBytes(payload, "item")
if item.Get("type").String() != "function_call" {
@@ -1217,7 +1227,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
appendResult("event: content_block_delta\ndata: " + inputDelta + "\n\n")
case "response.function_call_arguments.delta":
// Copilot sends tool call arguments via this event type (not response.output_item.delta).
// Data format: {"delta":"...", "item_id":"...", "output_index":N, ...}
@@ -1234,7 +1244,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
appendResult("event: content_block_delta\ndata: " + inputDelta + "\n\n")
case "response.output_item.done":
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
break
@@ -1245,7 +1255,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
}
contentBlockStop := `{"type":"content_block_stop","index":0}`
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
appendResult("event: content_block_stop\ndata: " + contentBlockStop + "\n\n")
case "response.completed":
ensureMessageStart()
stopTextBlockIfNeeded()
@@ -1269,8 +1279,8 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
if cachedTokens > 0 {
messageDelta, _ = sjson.Set(messageDelta, "usage.cache_read_input_tokens", cachedTokens)
}
results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n")
results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
appendResult("event: message_delta\ndata: " + messageDelta + "\n\n")
appendResult("event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
state.MessageStopSent = true
}
}

View File

@@ -132,6 +132,19 @@ func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testi
}
}
func TestNormalizeGitHubCopilotResponsesInput_StripsServiceTier(t *testing.T) {
t.Parallel()
body := []byte(`{"input":"user text","service_tier":"default"}`)
got := normalizeGitHubCopilotResponsesInput(body)
if gjson.GetBytes(got, "service_tier").Exists() {
t.Fatalf("service_tier should be removed, got %s", gjson.GetBytes(got, "service_tier").Raw)
}
if gjson.GetBytes(got, "input").String() != "user text" {
t.Fatalf("input = %q, want %q", gjson.GetBytes(got, "input").String(), "user text")
}
}
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,469 @@
package executor
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
func TestGitLabExecutorExecuteUsesChatEndpoint(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != gitLabChatEndpoint {
t.Fatalf("unexpected path %q", r.URL.Path)
}
_, _ = w.Write([]byte(`"chat response"`))
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"base_url": srv.URL,
"access_token": "oauth-access",
"model_name": "claude-sonnet-4-5",
},
}
req := cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`),
}
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "chat response" {
t.Fatalf("expected chat response, got %q", got)
}
if got := gjson.GetBytes(resp.Payload, "model").String(); got != "claude-sonnet-4-5" {
t.Fatalf("expected resolved model, got %q", got)
}
}
func TestGitLabExecutorExecuteFallsBackToCodeSuggestions(t *testing.T) {
chatCalls := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case gitLabChatEndpoint:
chatCalls++
http.Error(w, "feature unavailable", http.StatusForbidden)
case gitLabCodeSuggestionsEndpoint:
_ = json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{{
"text": "fallback response",
}},
})
default:
t.Fatalf("unexpected path %q", r.URL.Path)
}
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"base_url": srv.URL,
"personal_access_token": "glpat-token",
"auth_method": "pat",
},
}
req := cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"write code"}]}`),
}
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if chatCalls != 1 {
t.Fatalf("expected chat endpoint to be tried once, got %d", chatCalls)
}
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "fallback response" {
t.Fatalf("expected fallback response, got %q", got)
}
}
func TestGitLabExecutorExecuteUsesAnthropicGateway(t *testing.T) {
var gotAuthHeader, gotRealmHeader string
var gotPath string
var gotModel string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotAuthHeader = r.Header.Get("Authorization")
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[{"type":"tool_use","id":"toolu_1","name":"Bash","input":{"cmd":"ls"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":11,"output_tokens":4}}`))
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"duo_gateway_base_url": srv.URL,
"duo_gateway_token": "gateway-token",
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
}
req := cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{
"model":"gitlab-duo",
"messages":[{"role":"user","content":[{"type":"text","text":"list files"}]}],
"tools":[{"name":"Bash","description":"run bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}},"required":["cmd"]}}],
"max_tokens":128
}`),
}
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if gotPath != "/v1/proxy/anthropic/v1/messages" {
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
}
if gotAuthHeader != "Bearer gateway-token" {
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
}
if gotRealmHeader != "saas" {
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
}
if gotModel != "claude-sonnet-4-5" {
t.Fatalf("model = %q, want claude-sonnet-4-5", gotModel)
}
if got := gjson.GetBytes(resp.Payload, "content.0.type").String(); got != "tool_use" {
t.Fatalf("expected tool_use response, got %q", got)
}
if got := gjson.GetBytes(resp.Payload, "content.0.name").String(); got != "Bash" {
t.Fatalf("expected tool name Bash, got %q", got)
}
}
func TestGitLabExecutorExecuteUsesOpenAIGateway(t *testing.T) {
var gotAuthHeader, gotRealmHeader string
var gotPath string
var gotModel string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotAuthHeader = r.Header.Get("Authorization")
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\"}}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from openai gateway\"}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from openai gateway\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"duo_gateway_base_url": srv.URL,
"duo_gateway_token": "gateway-token",
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
"model_provider": "openai",
"model_name": "gpt-5-codex",
},
}
req := cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`),
}
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if gotPath != "/v1/proxy/openai/v1/responses" {
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
}
if gotAuthHeader != "Bearer gateway-token" {
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
}
if gotRealmHeader != "saas" {
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
}
if gotModel != "gpt-5-codex" {
t.Fatalf("model = %q, want gpt-5-codex", gotModel)
}
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from openai gateway" {
t.Fatalf("expected openai gateway response, got %q payload=%s", got, string(resp.Payload))
}
}
func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/oauth/token":
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "oauth-refreshed",
"refresh_token": "oauth-refresh",
"token_type": "Bearer",
"scope": "api read_user",
"created_at": 1710000000,
"expires_in": 3600,
})
case "/api/v4/code_suggestions/direct_access":
_ = json.NewEncoder(w).Encode(map[string]any{
"base_url": "https://cloud.gitlab.example.com",
"token": "gateway-token",
"expires_at": 1710003600,
"headers": map[string]string{"X-Gitlab-Realm": "saas"},
"model_details": map[string]any{
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
})
default:
t.Fatalf("unexpected path %q", r.URL.Path)
}
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
ID: "gitlab-auth.json",
Provider: "gitlab",
Metadata: map[string]any{
"base_url": srv.URL,
"access_token": "oauth-access",
"refresh_token": "oauth-refresh",
"oauth_client_id": "client-id",
"oauth_client_secret": "client-secret",
"auth_method": "oauth",
"oauth_expires_at": "2000-01-01T00:00:00Z",
},
}
updated, err := exec.Refresh(context.Background(), auth)
if err != nil {
t.Fatalf("Refresh() error = %v", err)
}
if got := updated.Metadata["access_token"]; got != "oauth-refreshed" {
t.Fatalf("expected refreshed access token, got %#v", got)
}
if got := updated.Metadata["model_name"]; got != "claude-sonnet-4-5" {
t.Fatalf("expected refreshed model metadata, got %#v", got)
}
}
func TestGitLabExecutorExecuteStreamUsesCodeSuggestionsSSE(t *testing.T) {
var gotAccept, gotStreamingHeader, gotEncoding string
var gotStreamFlag bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != gitLabCodeSuggestionsEndpoint {
t.Fatalf("unexpected path %q", r.URL.Path)
}
gotAccept = r.Header.Get("Accept")
gotStreamingHeader = r.Header.Get(gitLabSSEStreamingHeader)
gotEncoding = r.Header.Get("Accept-Encoding")
gotStreamFlag = gjson.GetBytes(readBody(t, r), "stream").Bool()
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("event: stream_start\n"))
_, _ = w.Write([]byte("data: {\"model\":{\"name\":\"claude-sonnet-4-5\"}}\n\n"))
_, _ = w.Write([]byte("event: content_chunk\n"))
_, _ = w.Write([]byte("data: {\"content\":\"hello\"}\n\n"))
_, _ = w.Write([]byte("event: content_chunk\n"))
_, _ = w.Write([]byte("data: {\"content\":\" world\"}\n\n"))
_, _ = w.Write([]byte("event: stream_end\n"))
_, _ = w.Write([]byte("data: {}\n\n"))
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"base_url": srv.URL,
"access_token": "oauth-access",
"model_name": "claude-sonnet-4-5",
},
}
req := cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{"model":"gitlab-duo","stream":true,"messages":[{"role":"user","content":"hello"}]}`),
}
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("ExecuteStream() error = %v", err)
}
lines := collectStreamLines(t, result)
if gotAccept != "text/event-stream" {
t.Fatalf("Accept = %q, want text/event-stream", gotAccept)
}
if gotStreamingHeader != "true" {
t.Fatalf("%s = %q, want true", gitLabSSEStreamingHeader, gotStreamingHeader)
}
if gotEncoding != "identity" {
t.Fatalf("Accept-Encoding = %q, want identity", gotEncoding)
}
if !gotStreamFlag {
t.Fatalf("expected upstream request to set stream=true")
}
if len(lines) < 4 {
t.Fatalf("expected translated stream chunks, got %d", len(lines))
}
if !strings.Contains(strings.Join(lines, "\n"), `"content":"hello"`) {
t.Fatalf("expected hello delta in stream, got %q", strings.Join(lines, "\n"))
}
if !strings.Contains(strings.Join(lines, "\n"), `"content":" world"`) {
t.Fatalf("expected world delta in stream, got %q", strings.Join(lines, "\n"))
}
last := lines[len(lines)-1]
if last != "data: [DONE]" && !strings.Contains(last, `"finish_reason":"stop"`) {
t.Fatalf("expected stream terminator, got %q", last)
}
}
func TestGitLabExecutorExecuteStreamFallsBackToSyntheticChat(t *testing.T) {
chatCalls := 0
streamCalls := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case gitLabCodeSuggestionsEndpoint:
streamCalls++
http.Error(w, "feature unavailable", http.StatusForbidden)
case gitLabChatEndpoint:
chatCalls++
_, _ = w.Write([]byte(`"chat fallback response"`))
default:
t.Fatalf("unexpected path %q", r.URL.Path)
}
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"base_url": srv.URL,
"access_token": "oauth-access",
"model_name": "claude-sonnet-4-5",
},
}
req := cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{"model":"gitlab-duo","stream":true,"messages":[{"role":"user","content":"hello"}]}`),
}
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("ExecuteStream() error = %v", err)
}
lines := collectStreamLines(t, result)
if streamCalls != 1 {
t.Fatalf("expected streaming endpoint once, got %d", streamCalls)
}
if chatCalls != 1 {
t.Fatalf("expected chat fallback once, got %d", chatCalls)
}
if !strings.Contains(strings.Join(lines, "\n"), `"content":"chat fallback response"`) {
t.Fatalf("expected fallback content in stream, got %q", strings.Join(lines, "\n"))
}
}
func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
var gotPath string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("event: message_start\n"))
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
_, _ = w.Write([]byte("event: content_block_start\n"))
_, _ = w.Write([]byte("data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"))
_, _ = w.Write([]byte("event: content_block_delta\n"))
_, _ = w.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello from gateway\"}}\n\n"))
_, _ = w.Write([]byte("event: message_delta\n"))
_, _ = w.Write([]byte("data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":10,\"output_tokens\":3}}\n\n"))
_, _ = w.Write([]byte("event: message_stop\n"))
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"duo_gateway_base_url": srv.URL,
"duo_gateway_token": "gateway-token",
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
}
req := cliproxyexecutor.Request{
Model: "gitlab-duo",
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}],"max_tokens":64}`),
}
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("ExecuteStream() error = %v", err)
}
lines := collectStreamLines(t, result)
if gotPath != "/v1/proxy/anthropic/v1/messages" {
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
}
if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") {
t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n"))
}
}
func collectStreamLines(t *testing.T, result *cliproxyexecutor.StreamResult) []string {
t.Helper()
lines := make([]string, 0, 8)
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected stream error: %v", chunk.Err)
}
lines = append(lines, string(chunk.Payload))
}
return lines
}
func readBody(t *testing.T, r *http.Request) []byte {
t.Helper()
defer func() { _ = r.Body.Close() }()
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("ReadAll() error = %v", err)
}
return body
}

View File

@@ -169,7 +169,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -281,7 +281,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
if errScan := scanner.Err(); errScan != nil {
@@ -315,7 +315,7 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
usageJSON := buildOpenAIUsageJSON(count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: translated}, nil
}
// Refresh refreshes OAuth tokens or cookie-based API keys and updates the stored API key.

View File

@@ -161,7 +161,7 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -271,12 +271,12 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range doneChunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)

View File

@@ -89,6 +89,13 @@ var endpointAliases = map[string]string{
"cli": "amazonq",
}
func enqueueTranslatedSSE(out chan<- cliproxyexecutor.StreamChunk, chunk []byte) {
if len(chunk) == 0 {
return
}
out <- cliproxyexecutor.StreamChunk{Payload: append(bytes.Clone(chunk), '\n', '\n')}
}
// retryConfig holds configuration for socket retry logic.
// Based on kiro2Api Python implementation patterns.
type retryConfig struct {
@@ -2573,9 +2580,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.ToolUseID, currentToolUse.Name)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
// Send tool input as delta
@@ -2583,18 +2588,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
// Close block
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
hasToolUses = true
@@ -2664,9 +2665,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
msgStart := kiroclaude.BuildClaudeMessageStartEvent(model, totalUsage.InputTokens)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
messageStartSent = true
}
@@ -2916,9 +2915,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
pingEvent := kiroclaude.BuildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
lastReportedOutputTokens = currentOutputTokens
@@ -2939,17 +2936,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processText, contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
continue
@@ -2978,18 +2971,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
// Send thinking delta
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
accumulatedThinkingContent.WriteString(thinkingText)
}
@@ -2998,9 +2987,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
isThinkingBlockOpen = false
}
@@ -3029,17 +3016,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(processContent, thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
accumulatedThinkingContent.WriteString(processContent)
}
@@ -3058,9 +3041,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
isThinkingBlockOpen = false
}
@@ -3071,18 +3052,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
// Send text delta
claudeEvent := kiroclaude.BuildClaudeStreamEvent(textBefore, contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
// Close text block before entering thinking
@@ -3090,9 +3067,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
isTextBlockOpen = false
}
@@ -3120,17 +3095,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processContent, contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
}
@@ -3158,9 +3129,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
isTextBlockOpen = false
}
@@ -3171,9 +3140,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
// Send input_json_delta with the tool input
@@ -3186,9 +3153,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
}
@@ -3197,9 +3162,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
@@ -3239,9 +3202,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
isTextBlockOpen = false
}
@@ -3254,9 +3215,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
@@ -3264,9 +3223,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
// Accumulate for token counting
@@ -3298,9 +3255,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
isTextBlockOpen = false
}
@@ -3310,9 +3265,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
if tu.Input != nil {
@@ -3323,9 +3276,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
}
@@ -3333,9 +3284,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
@@ -3522,9 +3471,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
@@ -3609,18 +3556,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
msgDelta := kiroclaude.BuildClaudeMessageDeltaEvent(stopReason, totalUsage)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
// Send message_stop event separately
msgStop := kiroclaude.BuildClaudeMessageStopOnlyEvent()
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
// reporter.publish is called via defer
}

View File

@@ -172,7 +172,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
// Translate response back to source format when needed
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -205,6 +205,10 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
return nil, err
}
// Request usage data in the final streaming chunk so that token statistics
// are captured even when the upstream is an OpenAI-compatible provider.
translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
@@ -286,7 +290,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
// Pass through translator; it yields one or more chunks for the target schema.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
if errScan := scanner.Err(); errScan != nil {
@@ -326,7 +330,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
usageJSON := buildOpenAIUsageJSON(count)
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: []byte(translatedUsage)}, nil
return cliproxyexecutor.Response{Payload: translatedUsage}, nil
}
// Refresh is a no-op for API-key based compatibility providers.

View File

@@ -2,17 +2,15 @@ package executor
import (
"context"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
)
// httpClientCache caches HTTP clients by proxy URL to enable connection reuse
@@ -111,45 +109,10 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
// Returns:
// - *http.Transport: A configured transport, or nil if the proxy URL is invalid
func buildProxyTransport(proxyURL string) *http.Transport {
if proxyURL == "" {
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyURL)
if errBuild != nil {
log.Errorf("%v", errBuild)
return nil
}
parsedURL, errParse := url.Parse(proxyURL)
if errParse != nil {
log.Errorf("parse proxy URL failed: %v", errParse)
return nil
}
var transport *http.Transport
// Handle different proxy schemes
if parsedURL.Scheme == "socks5" {
// Configure SOCKS5 proxy with optional authentication
var proxyAuth *proxy.Auth
if parsedURL.User != nil {
username := parsedURL.User.Username()
password, _ := parsedURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil
}
// Set up a custom transport using the SOCKS5 dialer
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" {
// Configure HTTP or HTTPS proxy
transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)}
} else {
log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
return nil
}
return transport
}

View File

@@ -0,0 +1,30 @@
package executor
import (
"context"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
t.Parallel()
client := newProxyAwareHTTPClient(
context.Background(),
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
&cliproxyauth.Auth{ProxyURL: "direct"},
0,
)
transport, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", client.Transport)
}
if transport.Proxy != nil {
t.Fatal("expected direct transport to disable proxy function")
}
}

View File

@@ -305,7 +305,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -421,12 +421,12 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range doneChunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
@@ -461,7 +461,7 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
usageJSON := buildOpenAIUsageJSON(count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: translated}, nil
}
func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {

View File

@@ -12,6 +12,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -39,35 +40,39 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
rawJSON := inputRawJSON
// system instruction
systemInstructionJSON := ""
var systemInstructionJSON []byte
hasSystemInstruction := false
systemResult := gjson.GetBytes(rawJSON, "system")
if systemResult.IsArray() {
systemResults := systemResult.Array()
systemInstructionJSON = `{"role":"user","parts":[]}`
systemInstructionJSON = []byte(`{"role":"user","parts":[]}`)
for i := 0; i < len(systemResults); i++ {
systemPromptResult := systemResults[i]
systemTypePromptResult := systemPromptResult.Get("type")
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
systemPrompt := systemPromptResult.Get("text").String()
partJSON := `{}`
partJSON := []byte(`{}`)
if systemPrompt != "" {
partJSON, _ = sjson.Set(partJSON, "text", systemPrompt)
partJSON, _ = sjson.SetBytes(partJSON, "text", systemPrompt)
}
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", partJSON)
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", partJSON)
hasSystemInstruction = true
}
}
} else if systemResult.Type == gjson.String {
systemInstructionJSON = `{"role":"user","parts":[{"text":""}]}`
systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.0.text", systemResult.String())
systemInstructionJSON = []byte(`{"role":"user","parts":[{"text":""}]}`)
systemInstructionJSON, _ = sjson.SetBytes(systemInstructionJSON, "parts.0.text", systemResult.String())
hasSystemInstruction = true
}
// contents
contentsJSON := "[]"
contentsJSON := []byte(`[]`)
hasContents := false
// tool_use_id → tool_name lookup, populated incrementally during the main loop.
// Claude's tool_result references tool_use by ID; Gemini requires functionResponse.name.
toolNameByID := make(map[string]string)
messagesResult := gjson.GetBytes(rawJSON, "messages")
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
@@ -83,8 +88,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if role == "assistant" {
role = "model"
}
clientContentJSON := `{"role":"","parts":[]}`
clientContentJSON, _ = sjson.Set(clientContentJSON, "role", role)
clientContentJSON := []byte(`{"role":"","parts":[]}`)
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "role", role)
contentsResult := messageResult.Get("content")
if contentsResult.IsArray() {
contentResults := contentsResult.Array()
@@ -143,15 +148,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
// Valid signature, send as thought block
partJSON := `{}`
partJSON, _ = sjson.Set(partJSON, "thought", true)
partJSON := []byte(`{}`)
partJSON, _ = sjson.SetBytes(partJSON, "thought", true)
if thinkingText != "" {
partJSON, _ = sjson.Set(partJSON, "text", thinkingText)
partJSON, _ = sjson.SetBytes(partJSON, "text", thinkingText)
}
if signature != "" {
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature)
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", signature)
}
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String()
// Skip empty text parts to avoid Gemini API error:
@@ -159,9 +164,9 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if prompt == "" {
continue
}
partJSON := `{}`
partJSON, _ = sjson.Set(partJSON, "text", prompt)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
partJSON := []byte(`{}`)
partJSON, _ = sjson.SetBytes(partJSON, "text", prompt)
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
// NOTE: Do NOT inject dummy thinking blocks here.
// Antigravity API validates signatures, so dummy values are rejected.
@@ -170,6 +175,10 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
argsResult := contentResult.Get("input")
functionID := contentResult.Get("id").String()
if functionID != "" && functionName != "" {
toolNameByID[functionID] = functionName
}
// Handle both object and string input formats
var argsRaw string
if argsResult.IsObject() {
@@ -183,138 +192,147 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
if argsRaw != "" {
partJSON := `{}`
partJSON := []byte(`{}`)
// Use skip_thought_signature_validator for tool calls without valid thinking signature
// This is the approach used in opencode-google-antigravity-auth for Gemini
// and also works for Claude through Antigravity API
const skipSentinel = "skip_thought_signature_validator"
if cache.HasValidSignature(modelName, currentMessageThinkingSignature) {
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", currentMessageThinkingSignature)
} else {
// No valid signature - use skip sentinel to bypass validation
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", skipSentinel)
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", skipSentinel)
}
if functionID != "" {
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID)
partJSON, _ = sjson.SetBytes(partJSON, "functionCall.id", functionID)
}
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsRaw)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
partJSON, _ = sjson.SetBytes(partJSON, "functionCall.name", functionName)
partJSON, _ = sjson.SetRawBytes(partJSON, "functionCall.args", []byte(argsRaw))
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" {
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-2], "-")
funcName, ok := toolNameByID[toolCallID]
if !ok {
// Fallback: derive a semantic name from the ID by stripping
// the last two dash-separated segments (e.g. "get_weather-call-123" → "get_weather").
// Only use the raw ID as a last resort when the heuristic produces an empty string.
parts := strings.Split(toolCallID, "-")
if len(parts) > 2 {
funcName = strings.Join(parts[:len(parts)-2], "-")
}
if funcName == "" {
funcName = toolCallID
}
log.Warnf("antigravity claude request: tool_result references unknown tool_use_id=%s, derived function name=%s", toolCallID, funcName)
}
functionResponseResult := contentResult.Get("content")
functionResponseJSON := `{}`
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "id", toolCallID)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "name", funcName)
functionResponseJSON := []byte(`{}`)
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "id", toolCallID)
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "name", funcName)
responseData := ""
if functionResponseResult.Type == gjson.String {
responseData = functionResponseResult.String()
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", responseData)
} else if functionResponseResult.IsArray() {
frResults := functionResponseResult.Array()
nonImageCount := 0
lastNonImageRaw := ""
filteredJSON := "[]"
imagePartsJSON := "[]"
filteredJSON := []byte(`[]`)
imagePartsJSON := []byte(`[]`)
for _, fr := range frResults {
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
inlineDataJSON := `{}`
inlineDataJSON := []byte(`{}`)
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
}
if data := fr.Get("source.data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
}
imagePartJSON := `{}`
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
imagePartJSON := []byte(`{}`)
imagePartJSON, _ = sjson.SetRawBytes(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON, _ = sjson.SetRawBytes(imagePartsJSON, "-1", imagePartJSON)
continue
}
nonImageCount++
lastNonImageRaw = fr.Raw
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
filteredJSON, _ = sjson.SetRawBytes(filteredJSON, "-1", []byte(fr.Raw))
}
if nonImageCount == 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(lastNonImageRaw))
} else if nonImageCount > 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", filteredJSON)
} else {
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
}
// Place image data inside functionResponse.parts as inlineData
// instead of as sibling parts in the outer content, to avoid
// base64 data bloating the text context.
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
if gjson.GetBytes(imagePartsJSON, "#").Int() > 0 {
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "parts", imagePartsJSON)
}
} else if functionResponseResult.IsObject() {
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
inlineDataJSON := `{}`
inlineDataJSON := []byte(`{}`)
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
}
if data := functionResponseResult.Get("source.data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
}
imagePartJSON := `{}`
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON := "[]"
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
imagePartJSON := []byte(`{}`)
imagePartJSON, _ = sjson.SetRawBytes(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON := []byte(`[]`)
imagePartsJSON, _ = sjson.SetRawBytes(imagePartsJSON, "-1", imagePartJSON)
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "parts", imagePartsJSON)
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
} else {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(functionResponseResult.Raw))
}
} else if functionResponseResult.Raw != "" {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(functionResponseResult.Raw))
} else {
// Content field is missing entirely — .Raw is empty which
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
}
partJSON := `{}`
partJSON, _ = sjson.SetRaw(partJSON, "functionResponse", functionResponseJSON)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
partJSON := []byte(`{}`)
partJSON, _ = sjson.SetRawBytes(partJSON, "functionResponse", functionResponseJSON)
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" {
sourceResult := contentResult.Get("source")
if sourceResult.Get("type").String() == "base64" {
inlineDataJSON := `{}`
inlineDataJSON := []byte(`{}`)
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
}
if data := sourceResult.Get("data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
}
partJSON := `{}`
partJSON, _ = sjson.SetRaw(partJSON, "inlineData", inlineDataJSON)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
partJSON := []byte(`{}`)
partJSON, _ = sjson.SetRawBytes(partJSON, "inlineData", inlineDataJSON)
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
}
}
}
// Reorder parts for 'model' role to ensure thinking block is first
if role == "model" {
partsResult := gjson.Get(clientContentJSON, "parts")
partsResult := gjson.GetBytes(clientContentJSON, "parts")
if partsResult.IsArray() {
parts := partsResult.Array()
var thinkingParts []gjson.Result
@@ -336,7 +354,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
for _, p := range otherParts {
newParts = append(newParts, p.Value())
}
clientContentJSON, _ = sjson.Set(clientContentJSON, "parts", newParts)
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
}
}
}
@@ -344,33 +362,33 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Skip messages with empty parts array to avoid Gemini API error:
// "required oneof field 'data' must have one initialized field"
partsCheck := gjson.Get(clientContentJSON, "parts")
partsCheck := gjson.GetBytes(clientContentJSON, "parts")
if !partsCheck.IsArray() || len(partsCheck.Array()) == 0 {
continue
}
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
contentsJSON, _ = sjson.SetRawBytes(contentsJSON, "-1", clientContentJSON)
hasContents = true
} else if contentsResult.Type == gjson.String {
prompt := contentsResult.String()
partJSON := `{}`
partJSON := []byte(`{}`)
if prompt != "" {
partJSON, _ = sjson.Set(partJSON, "text", prompt)
partJSON, _ = sjson.SetBytes(partJSON, "text", prompt)
}
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
contentsJSON, _ = sjson.SetRawBytes(contentsJSON, "-1", clientContentJSON)
hasContents = true
}
}
}
// tools
toolsJSON := ""
var toolsJSON []byte
toolDeclCount := 0
allowedToolKeys := []string{"name", "description", "behavior", "parameters", "parametersJsonSchema", "response", "responseJsonSchema"}
toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() {
toolsJSON = `[{"functionDeclarations":[]}]`
toolsJSON = []byte(`[{"functionDeclarations":[]}]`)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
@@ -378,23 +396,23 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
// Sanitize the input schema for Antigravity API compatibility
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
for toolKey := range gjson.Parse(tool).Map() {
tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
for toolKey := range gjson.ParseBytes(tool).Map() {
if util.InArray(allowedToolKeys, toolKey) {
continue
}
tool, _ = sjson.Delete(tool, toolKey)
tool, _ = sjson.DeleteBytes(tool, toolKey)
}
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "0.functionDeclarations.-1", tool)
toolDeclCount++
}
}
}
// Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}`
out, _ = sjson.Set(out, "model", modelName)
out := []byte(`{"model":"","request":{"contents":[]}}`)
out, _ = sjson.SetBytes(out, "model", modelName)
// Inject interleaved thinking hint when both tools and thinking are active
hasTools := toolDeclCount > 0
@@ -408,27 +426,27 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if hasSystemInstruction {
// Append hint as a new part to existing system instruction
hintPart := `{"text":""}`
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
hintPart := []byte(`{"text":""}`)
hintPart, _ = sjson.SetBytes(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", hintPart)
} else {
// Create new system instruction with hint
systemInstructionJSON = `{"role":"user","parts":[]}`
hintPart := `{"text":""}`
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
systemInstructionJSON = []byte(`{"role":"user","parts":[]}`)
hintPart := []byte(`{"text":""}`)
hintPart, _ = sjson.SetBytes(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", hintPart)
hasSystemInstruction = true
}
}
if hasSystemInstruction {
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON)
out, _ = sjson.SetRawBytes(out, "request.systemInstruction", systemInstructionJSON)
}
if hasContents {
out, _ = sjson.SetRaw(out, "request.contents", contentsJSON)
out, _ = sjson.SetRawBytes(out, "request.contents", contentsJSON)
}
if toolDeclCount > 0 {
out, _ = sjson.SetRaw(out, "request.tools", toolsJSON)
out, _ = sjson.SetRawBytes(out, "request.tools", toolsJSON)
}
// tool_choice
@@ -445,15 +463,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
switch toolChoiceType {
case "auto":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
case "none":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
case "any":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
case "tool":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
if toolChoiceName != "" {
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
}
}
}
@@ -464,8 +482,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
case "enabled":
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := int(b.Int())
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
}
case "adaptive", "auto":
// For adaptive thinking:
@@ -477,28 +495,27 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
effort = strings.ToLower(strings.TrimSpace(v.String()))
}
if effort != "" {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
} else {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
}
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
}
}
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.temperature", v.Num)
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", v.Num)
}
if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topP", v.Num)
out, _ = sjson.SetBytes(out, "request.generationConfig.topP", v.Num)
}
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num)
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", v.Num)
}
if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.maxOutputTokens", v.Num)
out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", v.Num)
}
outBytes := []byte(out)
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")
out = common.AttachDefaultSafetySettings(out, "request.safetySettings")
return outBytes
return out
}

View File

@@ -365,6 +365,17 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "get_weather-call-123",
"name": "get_weather",
"input": {"location": "Paris"}
}
]
},
{
"role": "user",
"content": [
@@ -382,13 +393,177 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
outputStr := string(output)
// Check function response conversion
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
if !funcResp.Exists() {
t.Error("functionResponse should exist")
}
if funcResp.Get("id").String() != "get_weather-call-123" {
t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String())
}
if funcResp.Get("name").String() != "get_weather" {
t.Errorf("Expected function name 'get_weather', got '%s'", funcResp.Get("name").String())
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_TouluFormat(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-haiku-4-5-20251001",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
"name": "Glob",
"input": {"pattern": "**/*.py"}
},
{
"type": "tool_use",
"id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
"name": "Bash",
"input": {"command": "ls"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
"content": "file1.py\nfile2.py"
},
{
"type": "tool_result",
"tool_use_id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
"content": "total 10"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
outputStr := string(output)
funcResp0 := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
if !funcResp0.Exists() {
t.Fatal("first functionResponse should exist")
}
if got := funcResp0.Get("name").String(); got != "Glob" {
t.Errorf("Expected name 'Glob' for toolu_ format, got '%s'", got)
}
funcResp1 := gjson.Get(outputStr, "request.contents.1.parts.1.functionResponse")
if !funcResp1.Exists() {
t.Fatal("second functionResponse should exist")
}
if got := funcResp1.Get("name").String(); got != "Bash" {
t.Errorf("Expected name 'Bash' for toolu_ format, got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_CustomFormat(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-haiku-4-5-20251001",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "Read-1773420180464065165-1327",
"name": "Read",
"input": {"file_path": "/tmp/test.py"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Read-1773420180464065165-1327",
"content": "file content here"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
outputStr := string(output)
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
if got := funcResp.Get("name").String(); got != "Read" {
t.Errorf("Expected name 'Read', got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_Heuristic(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "get_weather-call-123",
"content": "22C sunny"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
if got := funcResp.Get("name").String(); got != "get_weather" {
t.Errorf("Expected heuristic-derived name 'get_weather', got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_RawID(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
"content": "result data"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
got := funcResp.Get("name").String()
if got == "" {
t.Error("functionResponse.name must not be empty")
}
if got != "toolu_tool-48fca351f12844eabf49dad8b63886d2" {
t.Errorf("Expected raw ID as last-resort name, got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {

View File

@@ -15,6 +15,7 @@ import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
@@ -63,8 +64,8 @@ var toolUseIDCounter uint64
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of bytes, each containing a Claude Code-compatible SSE payload.
func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &Params{
HasFirstResponse: false,
@@ -77,44 +78,44 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
params := (*param).(*Params)
if bytes.Equal(rawJSON, []byte("[DONE]")) {
output := ""
output := make([]byte, 0, 256)
// Only send final events if we have actually output content
if params.HasContent {
appendFinalEvents(params, &output, true)
return []string{
output + "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n\n",
}
output = translatorcommon.AppendSSEEventString(output, "message_stop", `{"type":"message_stop"}`, 3)
return [][]byte{output}
}
return []string{}
return [][]byte{}
}
output := ""
output := make([]byte, 0, 1024)
appendEvent := func(event, payload string) {
output = translatorcommon.AppendSSEEventString(output, event, payload, 3)
}
// Initialize the streaming session with a message_start event
// This is only sent for the very first response chunk to establish the streaming session
if !params.HasFirstResponse {
output = "event: message_start\n"
// Create the initial message structure with default values according to Claude Code API specification
// This follows the Claude Code API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
messageStartTemplate := []byte(`{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`)
// Use cpaUsageMetadata within the message_start event for Claude.
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
}
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
}
// Override default values with actual response metadata if available from the Gemini CLI response
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.model", modelVersionResult.String())
}
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.id", responseIDResult.String())
}
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
appendEvent("message_start", string(messageStartTemplate))
params.HasFirstResponse = true
}
@@ -144,15 +145,13 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
params.CurrentThinkingText.Reset()
}
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex)), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
appendEvent("content_block_delta", string(data))
params.HasContent = true
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
params.CurrentThinkingText.WriteString(partTextResult.String())
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex)), "delta.thinking", partTextResult.String())
appendEvent("content_block_delta", string(data))
params.HasContent = true
} else {
// Transition from another state to thinking
@@ -163,19 +162,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
params.ResponseIndex++
}
// Start a new thinking content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex))
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex)), "delta.thinking", partTextResult.String())
appendEvent("content_block_delta", string(data))
params.ResponseType = 2 // Set state to thinking
params.HasContent = true
// Start accumulating thinking text for signature caching
@@ -188,9 +182,8 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Process regular text content (user-visible output)
// Continue existing text block if already in content state
if params.ResponseType == 1 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex)), "delta.text", partTextResult.String())
appendEvent("content_block_delta", string(data))
params.HasContent = true
} else {
// Transition from another state to text content
@@ -201,19 +194,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
params.ResponseIndex++
}
if partTextResult.String() != "" {
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex))
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex)), "delta.text", partTextResult.String())
appendEvent("content_block_delta", string(data))
params.ResponseType = 1 // Set state to content
params.HasContent = true
}
@@ -229,9 +217,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Handle state transitions when switching to function calls
// Close any existing function call block first
if params.ResponseType == 3 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
params.ResponseIndex++
params.ResponseType = 0
}
@@ -245,26 +231,21 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Close any other existing content block
if params.ResponseType != 0 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
params.ResponseIndex++
}
// Start a new tool use content block
// This creates the structure for a function call in Claude Code format
output = output + "event: content_block_start\n"
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data := []byte(fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex))
data, _ = sjson.SetBytes(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.SetBytes(data, "content_block.name", fcName)
appendEvent("content_block_start", string(data))
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n"
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ = sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex)), "delta.partial_json", fcArgsResult.Raw)
appendEvent("content_block_delta", string(data))
}
params.ResponseType = 3
params.HasContent = true
@@ -296,10 +277,10 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
appendFinalEvents(params, &output, false)
}
return []string{output}
return [][]byte{output}
}
func appendFinalEvents(params *Params, output *string, force bool) {
func appendFinalEvents(params *Params, output *[]byte, force bool) {
if params.HasSentFinalEvents {
return
}
@@ -314,9 +295,7 @@ func appendFinalEvents(params *Params, output *string, force bool) {
}
if params.ResponseType != 0 {
*output = *output + "event: content_block_stop\n"
*output = *output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
*output = *output + "\n\n\n"
*output = translatorcommon.AppendSSEEventString(*output, "content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex), 3)
params.ResponseType = 0
}
@@ -329,18 +308,16 @@ func appendFinalEvents(params *Params, output *string, force bool) {
}
}
*output = *output + "event: message_delta\n"
*output = *output + "data: "
delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
delta := []byte(fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens))
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
if params.CachedTokenCount > 0 {
var err error
delta, err = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
delta, err = sjson.SetBytes(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
if err != nil {
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
}
}
*output = *output + delta + "\n\n\n"
*output = translatorcommon.AppendSSEEventString(*output, "message_delta", string(delta), 3)
params.HasSentFinalEvents = true
}
@@ -369,8 +346,8 @@ func resolveStopReason(params *Params) string {
// - param: A pointer to a parameter object for the conversion.
//
// Returns:
// - string: A Claude-compatible JSON response.
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// - []byte: A Claude-compatible JSON response.
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
_ = originalRequestRawJSON
modelName := gjson.GetBytes(requestRawJSON, "model").String()
@@ -388,15 +365,15 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
}
}
responseJSON := `{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
responseJSON, _ = sjson.Set(responseJSON, "id", root.Get("response.responseId").String())
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
responseJSON := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
responseJSON, _ = sjson.SetBytes(responseJSON, "id", root.Get("response.responseId").String())
responseJSON, _ = sjson.SetBytes(responseJSON, "model", root.Get("response.modelVersion").String())
responseJSON, _ = sjson.SetBytes(responseJSON, "usage.input_tokens", promptTokens)
responseJSON, _ = sjson.SetBytes(responseJSON, "usage.output_tokens", outputTokens)
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
if cachedTokens > 0 {
var err error
responseJSON, err = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
responseJSON, err = sjson.SetBytes(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
if err != nil {
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
}
@@ -407,7 +384,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
if contentArrayInitialized {
return
}
responseJSON, _ = sjson.SetRaw(responseJSON, "content", "[]")
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content", []byte("[]"))
contentArrayInitialized = true
}
@@ -423,9 +400,9 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
return
}
ensureContentArray()
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", textBuilder.String())
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
block := []byte(`{"type":"text","text":""}`)
block, _ = sjson.SetBytes(block, "text", textBuilder.String())
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", block)
textBuilder.Reset()
}
@@ -434,12 +411,12 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
return
}
ensureContentArray()
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
block := []byte(`{"type":"thinking","thinking":""}`)
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
if thinkingSignature != "" {
block, _ = sjson.Set(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
block, _ = sjson.SetBytes(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
}
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", block)
thinkingBuilder.Reset()
thinkingSignature = ""
}
@@ -475,16 +452,16 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
name := functionCall.Get("name").String()
toolIDCounter++
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "name", name)
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) && args.IsObject() {
toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw)
toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(args.Raw))
}
ensureContentArray()
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", toolBlock)
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", toolBlock)
continue
}
}
@@ -508,17 +485,17 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
}
}
}
responseJSON, _ = sjson.Set(responseJSON, "stop_reason", stopReason)
responseJSON, _ = sjson.SetBytes(responseJSON, "stop_reason", stopReason)
if promptTokens == 0 && outputTokens == 0 {
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
responseJSON, _ = sjson.Delete(responseJSON, "usage")
responseJSON, _ = sjson.DeleteBytes(responseJSON, "usage")
}
}
return responseJSON
}
func ClaudeTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"input_tokens":%d}`, count)
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.ClaudeInputTokensJSON(count)
}

View File

@@ -34,10 +34,10 @@ import (
// - []byte: The transformed request data in Gemini API format
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON
template := ""
template = `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
template, _ = sjson.Set(template, "model", modelName)
template := `{"project":"","request":{},"model":""}`
templateBytes, _ := sjson.SetRawBytes([]byte(template), "request", rawJSON)
templateBytes, _ = sjson.SetBytes(templateBytes, "model", modelName)
template = string(templateBytes)
template, _ = sjson.Delete(template, "request.model")
template, errFixCLIToolResponse := fixCLIToolResponse(template)
@@ -47,7 +47,8 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
systemInstructionResult := gjson.Get(template, "request.system_instruction")
if systemInstructionResult.Exists() {
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
templateBytes, _ = sjson.SetRawBytes([]byte(template), "request.systemInstruction", []byte(systemInstructionResult.Raw))
template = string(templateBytes)
template, _ = sjson.Delete(template, "request.system_instruction")
}
rawJSON = []byte(template)
@@ -138,30 +139,47 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
// FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct {
ResponsesNeeded int
CallNames []string // ordered function call names for backfilling empty response names
}
// parseFunctionResponseRaw attempts to normalize a function response part into a JSON object string.
// Falls back to a minimal "functionResponse" object when parsing fails.
func parseFunctionResponseRaw(response gjson.Result) string {
// fallbackName is used when the response's own name is empty.
func parseFunctionResponseRaw(response gjson.Result, fallbackName string) string {
if response.IsObject() && gjson.Valid(response.Raw) {
return response.Raw
raw := response.Raw
name := response.Get("functionResponse.name").String()
if strings.TrimSpace(name) == "" && fallbackName != "" {
updated, _ := sjson.SetBytes([]byte(raw), "functionResponse.name", fallbackName)
raw = string(updated)
}
return raw
}
log.Debugf("parse function response failed, using fallback")
funcResp := response.Get("functionResponse")
if funcResp.Exists() {
fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
fr, _ = sjson.Set(fr, "functionResponse.name", funcResp.Get("name").String())
fr, _ = sjson.Set(fr, "functionResponse.response.result", funcResp.Get("response").String())
if id := funcResp.Get("id").String(); id != "" {
fr, _ = sjson.Set(fr, "functionResponse.id", id)
fr := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
name := funcResp.Get("name").String()
if strings.TrimSpace(name) == "" {
name = fallbackName
}
return fr
fr, _ = sjson.SetBytes(fr, "functionResponse.name", name)
fr, _ = sjson.SetBytes(fr, "functionResponse.response.result", funcResp.Get("response").String())
if id := funcResp.Get("id").String(); id != "" {
fr, _ = sjson.SetBytes(fr, "functionResponse.id", id)
}
return string(fr)
}
fr := `{"functionResponse":{"name":"unknown","response":{"result":""}}}`
fr, _ = sjson.Set(fr, "functionResponse.response.result", response.String())
return fr
useName := fallbackName
if useName == "" {
useName = "unknown"
}
fr := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
fr, _ = sjson.SetBytes(fr, "functionResponse.name", useName)
fr, _ = sjson.SetBytes(fr, "functionResponse.response.result", response.String())
return string(fr)
}
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
@@ -188,7 +206,7 @@ func fixCLIToolResponse(input string) (string, error) {
}
// Initialize data structures for processing and grouping
contentsWrapper := `{"contents":[]}`
contentsWrapper := []byte(`{"contents":[]}`)
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched
@@ -211,30 +229,26 @@ func fixCLIToolResponse(input string) (string, error) {
if len(responsePartsInThisContent) > 0 {
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
// Check if any pending groups can be satisfied
for i := len(pendingGroups) - 1; i >= 0; i-- {
group := pendingGroups[i]
if len(collectedResponses) >= group.ResponsesNeeded {
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Check if pending groups can be satisfied (FIFO: oldest group first)
for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded {
group := pendingGroups[0]
pendingGroups = pendingGroups[1:]
// Create merged function response content
functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses {
partRaw := parseFunctionResponseRaw(response)
if partRaw != "" {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
}
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
for ri, response := range groupResponses {
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
if partRaw != "" {
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(partRaw))
}
}
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
}
// Remove this group as it's been satisfied
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
break
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
}
}
@@ -243,25 +257,26 @@ func fixCLIToolResponse(input string) (string, error) {
// If this is a model with function calls, create a new group
if role == "model" {
functionCallsCount := 0
var callNames []string
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
functionCallsCount++
callNames = append(callNames, part.Get("functionCall.name").String())
}
return true
})
if functionCallsCount > 0 {
if len(callNames) > 0 {
// Add the model content
if !value.IsObject() {
log.Warnf("failed to parse model content")
return true
}
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
// Create a new group for tracking responses
group := &FunctionCallGroup{
ResponsesNeeded: functionCallsCount,
ResponsesNeeded: len(callNames),
CallNames: callNames,
}
pendingGroups = append(pendingGroups, group)
} else {
@@ -270,7 +285,7 @@ func fixCLIToolResponse(input string) (string, error) {
log.Warnf("failed to parse content")
return true
}
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
}
} else {
// Non-model content (user, etc.)
@@ -278,7 +293,7 @@ func fixCLIToolResponse(input string) (string, error) {
log.Warnf("failed to parse content")
return true
}
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
}
return true
@@ -290,23 +305,22 @@ func fixCLIToolResponse(input string) (string, error) {
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses {
partRaw := parseFunctionResponseRaw(response)
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
for ri, response := range groupResponses {
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
if partRaw != "" {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(partRaw))
}
}
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
}
}
}
// Update the original JSON with the new contents
result := input
result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
result, _ := sjson.SetRawBytes([]byte(input), "request.contents", []byte(gjson.GetBytes(contentsWrapper, "contents").Raw))
return result, nil
return string(result), nil
}

View File

@@ -171,3 +171,257 @@ func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
}
}
func TestFixCLIToolResponse_BackfillsEmptyFunctionResponseName(t *testing.T) {
// When the Amp client sends functionResponse with an empty name,
// fixCLIToolResponse should backfill it from the corresponding functionCall.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"output": "file1.txt"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
name := funcContent.Get("parts.0.functionResponse.name").String()
if name != "Bash" {
t.Errorf("Expected backfilled name 'Bash', got '%s'", name)
}
}
func TestFixCLIToolResponse_BackfillsMultipleEmptyNames(t *testing.T) {
// Parallel function calls: both responses have empty names.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Read", "args": {"path": "/a"}}},
{"functionCall": {"name": "Grep", "args": {"pattern": "x"}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "content a"}}},
{"functionResponse": {"name": "", "response": {"result": "match x"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
parts := funcContent.Get("parts").Array()
if len(parts) != 2 {
t.Fatalf("Expected 2 function response parts, got %d", len(parts))
}
name0 := parts[0].Get("functionResponse.name").String()
name1 := parts[1].Get("functionResponse.name").String()
if name0 != "Read" {
t.Errorf("Expected first response name 'Read', got '%s'", name0)
}
if name1 != "Grep" {
t.Errorf("Expected second response name 'Grep', got '%s'", name1)
}
}
func TestFixCLIToolResponse_PreservesExistingName(t *testing.T) {
// When functionResponse already has a valid name, it should be preserved.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "Bash", "response": {"result": "ok"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
name := funcContent.Get("parts.0.functionResponse.name").String()
if name != "Bash" {
t.Errorf("Expected preserved name 'Bash', got '%s'", name)
}
}
func TestFixCLIToolResponse_MoreResponsesThanCalls(t *testing.T) {
// If there are more function responses than calls, unmatched extras are discarded by grouping.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "ok"}}},
{"functionResponse": {"name": "", "response": {"result": "extra"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
// First response should be backfilled from the call
name0 := funcContent.Get("parts.0.functionResponse.name").String()
if name0 != "Bash" {
t.Errorf("Expected first response name 'Bash', got '%s'", name0)
}
}
func TestFixCLIToolResponse_MultipleGroupsFIFO(t *testing.T) {
// Two sequential function call groups should be matched FIFO.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Read", "args": {}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "file content"}}}
]
},
{
"role": "model",
"parts": [
{"functionCall": {"name": "Grep", "args": {}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "match"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContents []gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContents = append(funcContents, c)
}
}
if len(funcContents) != 2 {
t.Fatalf("Expected 2 function contents, got %d", len(funcContents))
}
name0 := funcContents[0].Get("parts.0.functionResponse.name").String()
name1 := funcContents[1].Get("parts.0.functionResponse.name").String()
if name0 != "Read" {
t.Errorf("Expected first group name 'Read', got '%s'", name0)
}
if name1 != "Grep" {
t.Errorf("Expected second group name 'Grep', got '%s'", name1)
}
}

View File

@@ -8,8 +8,8 @@ package gemini
import (
"bytes"
"context"
"fmt"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -29,8 +29,8 @@ import (
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - []string: The transformed request data in Gemini API format
func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string {
// - [][]byte: The transformed response data in Gemini API format.
func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) [][]byte {
if bytes.HasPrefix(rawJSON, []byte("data:")) {
rawJSON = bytes.TrimSpace(rawJSON[5:])
}
@@ -44,22 +44,22 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR
chunk = restoreUsageMetadata(chunk)
}
} else {
chunkTemplate := "[]"
chunkTemplate := []byte("[]")
responseResult := gjson.ParseBytes(chunk)
if responseResult.IsArray() {
responseResultItems := responseResult.Array()
for i := 0; i < len(responseResultItems); i++ {
responseResultItem := responseResultItems[i]
if responseResultItem.Get("response").Exists() {
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
chunkTemplate, _ = sjson.SetRawBytes(chunkTemplate, "-1", []byte(responseResultItem.Get("response").Raw))
}
}
}
chunk = []byte(chunkTemplate)
chunk = chunkTemplate
}
return []string{string(chunk)}
return [][]byte{chunk}
}
return []string{}
return [][]byte{}
}
// ConvertAntigravityResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response.
@@ -73,18 +73,18 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: A Gemini-compatible JSON response containing the response data
func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// - []byte: A Gemini-compatible JSON response containing the response data.
func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
chunk := restoreUsageMetadata([]byte(responseResult.Raw))
return string(chunk)
return chunk
}
return string(rawJSON)
return rawJSON
}
func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
func GeminiTokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.GeminiTokenCountJSON(count)
}
// restoreUsageMetadata renames cpaUsageMetadata back to usageMetadata.

View File

@@ -59,8 +59,8 @@ func TestConvertAntigravityResponseToGeminiNonStream(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertAntigravityResponseToGeminiNonStream(context.Background(), "", nil, nil, tt.input, nil)
if result != tt.expected {
t.Errorf("ConvertAntigravityResponseToGeminiNonStream() = %s, want %s", result, tt.expected)
if string(result) != tt.expected {
t.Errorf("ConvertAntigravityResponseToGeminiNonStream() = %s, want %s", string(result), tt.expected)
}
})
}
@@ -87,8 +87,8 @@ func TestConvertAntigravityResponseToGeminiStream(t *testing.T) {
if len(results) != 1 {
t.Fatalf("expected 1 result, got %d", len(results))
}
if results[0] != tt.expected {
t.Errorf("ConvertAntigravityResponseToGemini() = %s, want %s", results[0], tt.expected)
if string(results[0]) != tt.expected {
t.Errorf("ConvertAntigravityResponseToGemini() = %s, want %s", string(results[0]), tt.expected)
}
})
}

View File

@@ -354,31 +354,35 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if errRename != nil {
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
fnRawBytes, errSet := sjson.SetBytes([]byte(fnRaw), "parametersJsonSchema.type", "object")
if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
fnRaw = string(fnRawBytes)
fnRawBytes, errSet = sjson.SetRawBytes([]byte(fnRaw), "parametersJsonSchema.properties", []byte(`{}`))
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw = string(fnRawBytes)
} else {
fnRaw = renamed
}
} else {
var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
fnRawBytes, errSet := sjson.SetBytes([]byte(fnRaw), "parametersJsonSchema.type", "object")
if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
fnRaw = string(fnRawBytes)
fnRawBytes, errSet = sjson.SetRawBytes([]byte(fnRaw), "parametersJsonSchema.properties", []byte(`{}`))
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw = string(fnRawBytes)
}
fnRaw, _ = sjson.Delete(fnRaw, "strict")
if !hasFunction {

View File

@@ -44,8 +44,8 @@ var functionCallIDCounter uint64
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of OpenAI-compatible JSON responses
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &convertCliResponseToOpenAIChatParams{
UnixTimestamp: 0,
@@ -54,15 +54,15 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
}
if bytes.Equal(rawJSON, []byte("[DONE]")) {
return []string{}
return [][]byte{}
}
// Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
// Extract and set the model version.
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
template, _ = sjson.SetBytes(template, "model", modelVersionResult.String())
}
// Extract and set the creation timestamp.
@@ -71,14 +71,14 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
if err == nil {
(*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
} else {
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
}
// Extract and set the response ID.
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
template, _ = sjson.Set(template, "id", responseIDResult.String())
template, _ = sjson.SetBytes(template, "id", responseIDResult.String())
}
// Cache the finish reason - do NOT set it in output yet (will be set on final chunk)
@@ -90,21 +90,21 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
template, _ = sjson.SetBytes(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
// Include cached token count if present (indicates prompt caching is working)
if cachedTokenCount > 0 {
var err error
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
template, err = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
if err != nil {
log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
}
@@ -141,33 +141,33 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", textContent)
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", textContent)
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", textContent)
template, _ = sjson.SetBytes(template, "choices.0.delta.content", textContent)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
(*param).(*convertCliResponseToOpenAIChatParams).SawToolCall = true // Persist across chunks
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
toolCallsResult := gjson.GetBytes(template, "choices.0.delta.tool_calls")
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
if toolCallsResult.Exists() && toolCallsResult.IsArray() {
functionCallIndex = len(toolCallsResult.Array())
} else {
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
}
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
functionCallTemplate := []byte(`{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`)
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex)
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
} else if inlineDataResult.Exists() {
data := inlineDataResult.Get("data").String()
if data == "" {
@@ -181,16 +181,16 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagesResult := gjson.Get(template, "choices.0.delta.images")
imagesResult := gjson.GetBytes(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images", []byte(`[]`))
}
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
imageIndex := len(gjson.GetBytes(template, "choices.0.delta.images").Array())
imagePayload := []byte(`{"type":"image_url","image_url":{"url":""}}`)
imagePayload, _ = sjson.SetBytes(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.SetBytes(imagePayload, "image_url.url", imageURL)
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images.-1", imagePayload)
}
}
}
@@ -212,11 +212,11 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
} else {
finishReason = "stop"
}
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
}
return []string{template}
return [][]byte{template}
}
// ConvertAntigravityResponseToOpenAINonStream converts a non-streaming Gemini CLI response to a non-streaming OpenAI response.
@@ -231,11 +231,11 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
// - param: A pointer to a parameter object for the conversion
//
// Returns:
// - string: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertAntigravityResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertAntigravityResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
return ConvertGeminiResponseToOpenAINonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, []byte(responseResult.Raw), param)
}
return ""
return []byte{}
}

View File

@@ -19,7 +19,7 @@ func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
if len(result1) != 1 {
t.Fatalf("Expected 1 result from chunk1, got %d", len(result1))
}
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason")
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
t.Errorf("Expected finish_reason to be null in chunk1, got: %v", fr1.String())
}
@@ -33,13 +33,13 @@ func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
if len(result2) != 1 {
t.Fatalf("Expected 1 result from chunk2, got %d", len(result2))
}
fr2 := gjson.Get(result2[0], "choices.0.finish_reason").String()
fr2 := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
if fr2 != "tool_calls" {
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr2)
}
// Verify native_finish_reason is lowercase upstream value
nfr2 := gjson.Get(result2[0], "choices.0.native_finish_reason").String()
nfr2 := gjson.GetBytes(result2[0], "choices.0.native_finish_reason").String()
if nfr2 != "stop" {
t.Errorf("Expected native_finish_reason 'stop', got: %s", nfr2)
}
@@ -58,7 +58,7 @@ func TestFinishReasonStopForNormalText(t *testing.T) {
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param)
// Verify finish_reason is "stop" (no tool calls were made)
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
if fr != "stop" {
t.Errorf("Expected finish_reason 'stop', got: %s", fr)
}
@@ -77,7 +77,7 @@ func TestFinishReasonMaxTokens(t *testing.T) {
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param)
// Verify finish_reason is "max_tokens"
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
if fr != "max_tokens" {
t.Errorf("Expected finish_reason 'max_tokens', got: %s", fr)
}
@@ -96,7 +96,7 @@ func TestToolCallTakesPriorityOverMaxTokens(t *testing.T) {
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param)
// Verify finish_reason is "tool_calls" (takes priority over max_tokens)
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
if fr != "tool_calls" {
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr)
}
@@ -111,7 +111,7 @@ func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, &param)
// Verify no finish_reason on intermediate chunk
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason")
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr1)
}
@@ -121,7 +121,7 @@ func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param)
// Verify no finish_reason on intermediate chunk
fr2 := gjson.Get(result2[0], "choices.0.finish_reason")
fr2 := gjson.GetBytes(result2[0], "choices.0.finish_reason")
if fr2.Exists() && fr2.String() != "" && fr2.Type.String() != "Null" {
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr2)
}

View File

@@ -7,7 +7,7 @@ import (
"github.com/tidwall/gjson"
)
func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
rawJSON = []byte(responseResult.Raw)
@@ -15,7 +15,7 @@ func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName
return ConvertGeminiResponseToOpenAIResponses(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
func ConvertAntigravityResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
func ConvertAntigravityResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
rawJSON = []byte(responseResult.Raw)

View File

@@ -8,7 +8,7 @@ import (
"context"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini"
"github.com/tidwall/sjson"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
)
// ConvertClaudeResponseToGeminiCLI converts Claude Code streaming response format to Gemini CLI format.
@@ -23,15 +23,13 @@ import (
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response wrapped in a response object
func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of Gemini-compatible JSON responses wrapped in a response object
func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
outputs := ConvertClaudeResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
// Wrap each converted response in a "response" object to match Gemini CLI API structure
newOutputs := make([]string, 0)
newOutputs := make([][]byte, 0, len(outputs))
for i := 0; i < len(outputs); i++ {
json := `{"response": {}}`
output, _ := sjson.SetRaw(json, "response", outputs[i])
newOutputs = append(newOutputs, output)
newOutputs = append(newOutputs, translatorcommon.WrapGeminiCLIResponse(outputs[i]))
}
return newOutputs
}
@@ -47,15 +45,13 @@ func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, ori
// - param: A pointer to a parameter object for the conversion
//
// Returns:
// - string: A Gemini-compatible JSON response wrapped in a response object
func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
strJSON := ConvertClaudeResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
// - []byte: A Gemini-compatible JSON response wrapped in a response object
func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
out := ConvertClaudeResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
// Wrap the converted response in a "response" object to match Gemini CLI API structure
json := `{"response": {}}`
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
return strJSON
return translatorcommon.WrapGeminiCLIResponse(out)
}
func GeminiCLITokenCount(ctx context.Context, count int64) string {
func GeminiCLITokenCount(ctx context.Context, count int64) []byte {
return GeminiTokenCount(ctx, count)
}

View File

@@ -63,7 +63,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
// Base Claude message payload
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID)
out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
root := gjson.ParseBytes(rawJSON)
@@ -87,20 +87,20 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
var pendingToolIDs []string
// Model mapping to specify which Claude Code model to use
out, _ = sjson.Set(out, "model", modelName)
out, _ = sjson.SetBytes(out, "model", modelName)
// Generation config extraction from Gemini format
if genConfig := root.Get("generationConfig"); genConfig.Exists() {
// Max output tokens configuration
if maxTokens := genConfig.Get("maxOutputTokens"); maxTokens.Exists() {
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
out, _ = sjson.SetBytes(out, "max_tokens", maxTokens.Int())
}
// Temperature setting for controlling response randomness
if temp := genConfig.Get("temperature"); temp.Exists() {
out, _ = sjson.Set(out, "temperature", temp.Float())
out, _ = sjson.SetBytes(out, "temperature", temp.Float())
} else if topP := genConfig.Get("topP"); topP.Exists() {
// Top P setting for nucleus sampling (filtered out if temperature is set)
out, _ = sjson.Set(out, "top_p", topP.Float())
out, _ = sjson.SetBytes(out, "top_p", topP.Float())
}
// Stop sequences configuration for custom termination conditions
if stopSeqs := genConfig.Get("stopSequences"); stopSeqs.Exists() && stopSeqs.IsArray() {
@@ -110,7 +110,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
return true
})
if len(stopSequences) > 0 {
out, _ = sjson.Set(out, "stop_sequences", stopSequences)
out, _ = sjson.SetBytes(out, "stop_sequences", stopSequences)
}
}
// Include thoughts configuration for reasoning process visibility
@@ -132,30 +132,30 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
switch level {
case "":
case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.DeleteBytes(out, "output_config.effort")
default:
if mapped, ok := thinking.MapToClaudeEffort(level, supportsMax); ok {
level = mapped
}
out, _ = sjson.Set(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Set(out, "output_config.effort", level)
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "output_config.effort", level)
}
} else {
switch level {
case "":
case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
case "auto":
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
default:
if budget, ok := thinking.ConvertLevelToBudget(level); ok {
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
}
}
}
@@ -169,37 +169,37 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
if supportsAdaptive {
switch budget {
case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.DeleteBytes(out, "output_config.effort")
default:
level, ok := thinking.ConvertBudgetToLevel(budget)
if ok {
if mapped, okM := thinking.MapToClaudeEffort(level, supportsMax); okM {
level = mapped
}
out, _ = sjson.Set(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Set(out, "output_config.effort", level)
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "output_config.effort", level)
}
}
} else {
switch budget {
case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
case -1:
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
default:
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
}
}
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
}
}
}
@@ -220,9 +220,9 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
})
if systemText.Len() > 0 {
// Create system message in Claude Code format
systemMessage := `{"role":"user","content":[{"type":"text","text":""}]}`
systemMessage, _ = sjson.Set(systemMessage, "content.0.text", systemText.String())
out, _ = sjson.SetRaw(out, "messages.-1", systemMessage)
systemMessage := []byte(`{"role":"user","content":[{"type":"text","text":""}]}`)
systemMessage, _ = sjson.SetBytes(systemMessage, "content.0.text", systemText.String())
out, _ = sjson.SetRawBytes(out, "messages.-1", systemMessage)
}
}
}
@@ -245,42 +245,42 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
}
// Create message structure in Claude Code format
msg := `{"role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
msg := []byte(`{"role":"","content":[]}`)
msg, _ = sjson.SetBytes(msg, "role", role)
if parts := content.Get("parts"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
// Text content conversion
if text := part.Get("text"); text.Exists() {
textContent := `{"type":"text","text":""}`
textContent, _ = sjson.Set(textContent, "text", text.String())
msg, _ = sjson.SetRaw(msg, "content.-1", textContent)
textContent := []byte(`{"type":"text","text":""}`)
textContent, _ = sjson.SetBytes(textContent, "text", text.String())
msg, _ = sjson.SetRawBytes(msg, "content.-1", textContent)
return true
}
// Function call (from model/assistant) conversion to tool use
if fc := part.Get("functionCall"); fc.Exists() && role == "assistant" {
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
// Generate a unique tool ID and enqueue it for later matching
// with the corresponding functionResponse
toolID := genToolCallID()
pendingToolIDs = append(pendingToolIDs, toolID)
toolUse, _ = sjson.Set(toolUse, "id", toolID)
toolUse, _ = sjson.SetBytes(toolUse, "id", toolID)
if name := fc.Get("name"); name.Exists() {
toolUse, _ = sjson.Set(toolUse, "name", name.String())
toolUse, _ = sjson.SetBytes(toolUse, "name", name.String())
}
if args := fc.Get("args"); args.Exists() && args.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", args.Raw)
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(args.Raw))
}
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
msg, _ = sjson.SetRawBytes(msg, "content.-1", toolUse)
return true
}
// Function response (from user) conversion to tool result
if fr := part.Get("functionResponse"); fr.Exists() {
toolResult := `{"type":"tool_result","tool_use_id":"","content":""}`
toolResult := []byte(`{"type":"tool_result","tool_use_id":"","content":""}`)
// Attach the oldest queued tool_id to pair the response
// with its call. If the queue is empty, generate a new id.
@@ -293,41 +293,41 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
// Fallback: generate new ID if no pending tool_use found
toolID = genToolCallID()
}
toolResult, _ = sjson.Set(toolResult, "tool_use_id", toolID)
toolResult, _ = sjson.SetBytes(toolResult, "tool_use_id", toolID)
// Extract result content from the function response
if result := fr.Get("response.result"); result.Exists() {
toolResult, _ = sjson.Set(toolResult, "content", result.String())
toolResult, _ = sjson.SetBytes(toolResult, "content", result.String())
} else if response := fr.Get("response"); response.Exists() {
toolResult, _ = sjson.Set(toolResult, "content", response.Raw)
toolResult, _ = sjson.SetBytes(toolResult, "content", response.Raw)
}
msg, _ = sjson.SetRaw(msg, "content.-1", toolResult)
msg, _ = sjson.SetRawBytes(msg, "content.-1", toolResult)
return true
}
// Image content (inline_data) conversion to Claude Code format
if inlineData := part.Get("inline_data"); inlineData.Exists() {
imageContent := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
imageContent := []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
if mimeType := inlineData.Get("mime_type"); mimeType.Exists() {
imageContent, _ = sjson.Set(imageContent, "source.media_type", mimeType.String())
imageContent, _ = sjson.SetBytes(imageContent, "source.media_type", mimeType.String())
}
if data := inlineData.Get("data"); data.Exists() {
imageContent, _ = sjson.Set(imageContent, "source.data", data.String())
imageContent, _ = sjson.SetBytes(imageContent, "source.data", data.String())
}
msg, _ = sjson.SetRaw(msg, "content.-1", imageContent)
msg, _ = sjson.SetRawBytes(msg, "content.-1", imageContent)
return true
}
// File data conversion to text content with file info
if fileData := part.Get("file_data"); fileData.Exists() {
// For file data, we'll convert to text content with file info
textContent := `{"type":"text","text":""}`
textContent := []byte(`{"type":"text","text":""}`)
fileInfo := "File: " + fileData.Get("file_uri").String()
if mimeType := fileData.Get("mime_type"); mimeType.Exists() {
fileInfo += " (Type: " + mimeType.String() + ")"
}
textContent, _ = sjson.Set(textContent, "text", fileInfo)
msg, _ = sjson.SetRaw(msg, "content.-1", textContent)
textContent, _ = sjson.SetBytes(textContent, "text", fileInfo)
msg, _ = sjson.SetRawBytes(msg, "content.-1", textContent)
return true
}
@@ -336,8 +336,8 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
}
// Only add message if it has content
if contentArray := gjson.Get(msg, "content"); contentArray.Exists() && len(contentArray.Array()) > 0 {
out, _ = sjson.SetRaw(out, "messages.-1", msg)
if contentArray := gjson.GetBytes(msg, "content"); contentArray.Exists() && len(contentArray.Array()) > 0 {
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
}
return true
@@ -351,29 +351,29 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
tools.ForEach(func(_, tool gjson.Result) bool {
if funcDecls := tool.Get("functionDeclarations"); funcDecls.Exists() && funcDecls.IsArray() {
funcDecls.ForEach(func(_, funcDecl gjson.Result) bool {
anthropicTool := `{"name":"","description":"","input_schema":{}}`
anthropicTool := []byte(`{"name":"","description":"","input_schema":{}}`)
if name := funcDecl.Get("name"); name.Exists() {
anthropicTool, _ = sjson.Set(anthropicTool, "name", name.String())
anthropicTool, _ = sjson.SetBytes(anthropicTool, "name", name.String())
}
if desc := funcDecl.Get("description"); desc.Exists() {
anthropicTool, _ = sjson.Set(anthropicTool, "description", desc.String())
anthropicTool, _ = sjson.SetBytes(anthropicTool, "description", desc.String())
}
if params := funcDecl.Get("parameters"); params.Exists() {
// Clean up the parameters schema for Claude Code compatibility
cleaned := params.Raw
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
cleaned, _ = sjson.Set(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", cleaned)
cleaned := []byte(params.Raw)
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
cleaned, _ = sjson.SetBytes(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", cleaned)
} else if params = funcDecl.Get("parametersJsonSchema"); params.Exists() {
// Clean up the parameters schema for Claude Code compatibility
cleaned := params.Raw
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
cleaned, _ = sjson.Set(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", cleaned)
cleaned := []byte(params.Raw)
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
cleaned, _ = sjson.SetBytes(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", cleaned)
}
anthropicTools = append(anthropicTools, gjson.Parse(anthropicTool).Value())
anthropicTools = append(anthropicTools, gjson.ParseBytes(anthropicTool).Value())
return true
})
}
@@ -381,7 +381,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
})
if len(anthropicTools) > 0 {
out, _ = sjson.Set(out, "tools", anthropicTools)
out, _ = sjson.SetBytes(out, "tools", anthropicTools)
}
}
@@ -391,27 +391,27 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
if mode := funcCalling.Get("mode"); mode.Exists() {
switch mode.String() {
case "AUTO":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
case "NONE":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"none"}`)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"none"}`))
case "ANY":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
}
}
}
}
// Stream setting configuration
out, _ = sjson.Set(out, "stream", stream)
out, _ = sjson.SetBytes(out, "stream", stream)
// Convert tool parameter types to lowercase for Claude Code compatibility
var pathsToLower []string
toolsResult := gjson.Get(out, "tools")
toolsResult := gjson.GetBytes(out, "tools")
util.Walk(toolsResult, "", "type", &pathsToLower)
for _, p := range pathsToLower {
fullPath := fmt.Sprintf("tools.%s", p)
out, _ = sjson.Set(out, fullPath, strings.ToLower(gjson.Get(out, fullPath).String()))
out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(gjson.GetBytes(out, fullPath).String()))
}
return []byte(out)
return out
}

View File

@@ -9,10 +9,10 @@ import (
"bufio"
"bytes"
"context"
"fmt"
"strings"
"time"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -30,7 +30,7 @@ type ConvertAnthropicResponseToGeminiParams struct {
Model string
CreatedAt int64
ResponseID string
LastStorageOutput string
LastStorageOutput []byte
IsStreaming bool
// Streaming state for tool_use assembly
@@ -52,8 +52,8 @@ type ConvertAnthropicResponseToGeminiParams struct {
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response
func ConvertClaudeResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of Gemini-compatible JSON responses
func ConvertClaudeResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &ConvertAnthropicResponseToGeminiParams{
Model: modelName,
@@ -63,7 +63,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
}
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
return [][]byte{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
@@ -71,24 +71,24 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
eventType := root.Get("type").String()
// Base Gemini response template with default values
template := `{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
// Set model version
if (*param).(*ConvertAnthropicResponseToGeminiParams).Model != "" {
// Map Claude model names back to Gemini model names
template, _ = sjson.Set(template, "modelVersion", (*param).(*ConvertAnthropicResponseToGeminiParams).Model)
template, _ = sjson.SetBytes(template, "modelVersion", (*param).(*ConvertAnthropicResponseToGeminiParams).Model)
}
// Set response ID and creation time
if (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID != "" {
template, _ = sjson.Set(template, "responseId", (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID)
template, _ = sjson.SetBytes(template, "responseId", (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID)
}
// Set creation time to current time if not provided
if (*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt == 0 {
(*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt = time.Now().Unix()
}
template, _ = sjson.Set(template, "createTime", time.Unix((*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
template, _ = sjson.SetBytes(template, "createTime", time.Unix((*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
switch eventType {
case "message_start":
@@ -97,7 +97,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
(*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID = message.Get("id").String()
(*param).(*ConvertAnthropicResponseToGeminiParams).Model = message.Get("model").String()
}
return []string{}
return [][]byte{}
case "content_block_start":
// Start of a content block - record tool_use name by index for functionCall assembly
@@ -112,7 +112,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
}
}
}
return []string{}
return [][]byte{}
case "content_block_delta":
// Handle content delta (text, thinking, or tool use arguments)
@@ -123,16 +123,16 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
case "text_delta":
// Regular text content delta for normal response text
if text := delta.Get("text"); text.Exists() && text.String() != "" {
textPart := `{"text":""}`
textPart, _ = sjson.Set(textPart, "text", text.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", textPart)
textPart := []byte(`{"text":""}`)
textPart, _ = sjson.SetBytes(textPart, "text", text.String())
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", textPart)
}
case "thinking_delta":
// Thinking/reasoning content delta for models with reasoning capabilities
if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
thinkingPart := `{"thought":true,"text":""}`
thinkingPart, _ = sjson.Set(thinkingPart, "text", text.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", thinkingPart)
thinkingPart := []byte(`{"thought":true,"text":""}`)
thinkingPart, _ = sjson.SetBytes(thinkingPart, "text", text.String())
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", thinkingPart)
}
case "input_json_delta":
// Tool use input delta - accumulate partial_json by index for later assembly at content_block_stop
@@ -149,10 +149,10 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
if pj := delta.Get("partial_json"); pj.Exists() {
b.WriteString(pj.String())
}
return []string{}
return [][]byte{}
}
}
return []string{template}
return [][]byte{template}
case "content_block_stop":
// End of content block - finalize tool calls if any
@@ -170,16 +170,16 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
}
}
if name != "" || argsTrim != "" {
functionCall := `{"functionCall":{"name":"","args":{}}}`
functionCall := []byte(`{"functionCall":{"name":"","args":{}}}`)
if name != "" {
functionCall, _ = sjson.Set(functionCall, "functionCall.name", name)
functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", name)
}
if argsTrim != "" {
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsTrim)
functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsTrim))
}
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", functionCall)
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
(*param).(*ConvertAnthropicResponseToGeminiParams).LastStorageOutput = template
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall)
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
(*param).(*ConvertAnthropicResponseToGeminiParams).LastStorageOutput = append([]byte(nil), template...)
// cleanup used state for this index
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs != nil {
delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs, idx)
@@ -187,9 +187,9 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames != nil {
delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames, idx)
}
return []string{template}
return [][]byte{template}
}
return []string{}
return [][]byte{}
case "message_delta":
// Handle message-level changes (like stop reason and usage information)
@@ -197,15 +197,15 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
if stopReason := delta.Get("stop_reason"); stopReason.Exists() {
switch stopReason.String() {
case "end_turn":
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
case "tool_use":
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
case "max_tokens":
template, _ = sjson.Set(template, "candidates.0.finishReason", "MAX_TOKENS")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "MAX_TOKENS")
case "stop_sequence":
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
default:
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
}
}
}
@@ -216,35 +216,35 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
outputTokens := usage.Get("output_tokens").Int()
// Set basic usage metadata according to Gemini API specification
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens)
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens)
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", inputTokens+outputTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", inputTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", outputTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", inputTokens+outputTokens)
// Add cache-related token counts if present (Claude Code API cache fields)
if cacheCreationTokens := usage.Get("cache_creation_input_tokens"); cacheCreationTokens.Exists() {
template, _ = sjson.Set(template, "usageMetadata.cachedContentTokenCount", cacheCreationTokens.Int())
template, _ = sjson.SetBytes(template, "usageMetadata.cachedContentTokenCount", cacheCreationTokens.Int())
}
if cacheReadTokens := usage.Get("cache_read_input_tokens"); cacheReadTokens.Exists() {
// Add cache read tokens to cached content count
existingCacheTokens := usage.Get("cache_creation_input_tokens").Int()
totalCacheTokens := existingCacheTokens + cacheReadTokens.Int()
template, _ = sjson.Set(template, "usageMetadata.cachedContentTokenCount", totalCacheTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.cachedContentTokenCount", totalCacheTokens)
}
// Add thinking tokens if present (for models with reasoning capabilities)
if thinkingTokens := usage.Get("thinking_tokens"); thinkingTokens.Exists() {
template, _ = sjson.Set(template, "usageMetadata.thoughtsTokenCount", thinkingTokens.Int())
template, _ = sjson.SetBytes(template, "usageMetadata.thoughtsTokenCount", thinkingTokens.Int())
}
// Set traffic type (required by Gemini API)
template, _ = sjson.Set(template, "usageMetadata.trafficType", "PROVISIONED_THROUGHPUT")
template, _ = sjson.SetBytes(template, "usageMetadata.trafficType", "PROVISIONED_THROUGHPUT")
}
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
return []string{template}
return [][]byte{template}
case "message_stop":
// Final message with usage information - no additional output needed
return []string{}
return [][]byte{}
case "error":
// Handle error responses and convert to Gemini error format
errorMsg := root.Get("error.message").String()
@@ -253,13 +253,13 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
}
// Create error response in Gemini format
errorResponse := `{"error":{"code":400,"message":"","status":"INVALID_ARGUMENT"}}`
errorResponse, _ = sjson.Set(errorResponse, "error.message", errorMsg)
return []string{errorResponse}
errorResponse := []byte(`{"error":{"code":400,"message":"","status":"INVALID_ARGUMENT"}}`)
errorResponse, _ = sjson.SetBytes(errorResponse, "error.message", errorMsg)
return [][]byte{errorResponse}
default:
// Unknown event type, return empty response
return []string{}
return [][]byte{}
}
}
@@ -275,13 +275,13 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: A Gemini-compatible JSON response containing all message content and metadata
func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// - []byte: A Gemini-compatible JSON response containing all message content and metadata
func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
// Base Gemini response template for non-streaming with default values
template := `{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
// Set model version
template, _ = sjson.Set(template, "modelVersion", modelName)
template, _ = sjson.SetBytes(template, "modelVersion", modelName)
streamingEvents := make([][]byte, 0)
@@ -304,15 +304,15 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
Model: modelName,
CreatedAt: 0,
ResponseID: "",
LastStorageOutput: "",
LastStorageOutput: nil,
IsStreaming: false,
ToolUseNames: nil,
ToolUseArgs: nil,
}
// Process each streaming event and collect parts
var allParts []string
var finalUsageJSON string
var allParts [][]byte
var finalUsageJSON []byte
var responseID string
var createdAt int64
@@ -360,15 +360,15 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
case "text_delta":
// Process regular text content
if text := delta.Get("text"); text.Exists() && text.String() != "" {
partJSON := `{"text":""}`
partJSON, _ = sjson.Set(partJSON, "text", text.String())
partJSON := []byte(`{"text":""}`)
partJSON, _ = sjson.SetBytes(partJSON, "text", text.String())
allParts = append(allParts, partJSON)
}
case "thinking_delta":
// Process reasoning/thinking content
if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
partJSON := `{"thought":true,"text":""}`
partJSON, _ = sjson.Set(partJSON, "text", text.String())
partJSON := []byte(`{"thought":true,"text":""}`)
partJSON, _ = sjson.SetBytes(partJSON, "text", text.String())
allParts = append(allParts, partJSON)
}
case "input_json_delta":
@@ -402,12 +402,12 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
}
}
if name != "" || argsTrim != "" {
functionCallJSON := `{"functionCall":{"name":"","args":{}}}`
functionCallJSON := []byte(`{"functionCall":{"name":"","args":{}}}`)
if name != "" {
functionCallJSON, _ = sjson.Set(functionCallJSON, "functionCall.name", name)
functionCallJSON, _ = sjson.SetBytes(functionCallJSON, "functionCall.name", name)
}
if argsTrim != "" {
functionCallJSON, _ = sjson.SetRaw(functionCallJSON, "functionCall.args", argsTrim)
functionCallJSON, _ = sjson.SetRawBytes(functionCallJSON, "functionCall.args", []byte(argsTrim))
}
allParts = append(allParts, functionCallJSON)
// cleanup used state for this index
@@ -422,35 +422,35 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
case "message_delta":
// Extract final usage information using sjson for token counts and metadata
if usage := root.Get("usage"); usage.Exists() {
usageJSON := `{}`
usageJSON := []byte(`{}`)
// Basic token counts for prompt and completion
inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int()
// Set basic usage metadata according to Gemini API specification
usageJSON, _ = sjson.Set(usageJSON, "promptTokenCount", inputTokens)
usageJSON, _ = sjson.Set(usageJSON, "candidatesTokenCount", outputTokens)
usageJSON, _ = sjson.Set(usageJSON, "totalTokenCount", inputTokens+outputTokens)
usageJSON, _ = sjson.SetBytes(usageJSON, "promptTokenCount", inputTokens)
usageJSON, _ = sjson.SetBytes(usageJSON, "candidatesTokenCount", outputTokens)
usageJSON, _ = sjson.SetBytes(usageJSON, "totalTokenCount", inputTokens+outputTokens)
// Add cache-related token counts if present (Claude Code API cache fields)
if cacheCreationTokens := usage.Get("cache_creation_input_tokens"); cacheCreationTokens.Exists() {
usageJSON, _ = sjson.Set(usageJSON, "cachedContentTokenCount", cacheCreationTokens.Int())
usageJSON, _ = sjson.SetBytes(usageJSON, "cachedContentTokenCount", cacheCreationTokens.Int())
}
if cacheReadTokens := usage.Get("cache_read_input_tokens"); cacheReadTokens.Exists() {
// Add cache read tokens to cached content count
existingCacheTokens := usage.Get("cache_creation_input_tokens").Int()
totalCacheTokens := existingCacheTokens + cacheReadTokens.Int()
usageJSON, _ = sjson.Set(usageJSON, "cachedContentTokenCount", totalCacheTokens)
usageJSON, _ = sjson.SetBytes(usageJSON, "cachedContentTokenCount", totalCacheTokens)
}
// Add thinking tokens if present (for models with reasoning capabilities)
if thinkingTokens := usage.Get("thinking_tokens"); thinkingTokens.Exists() {
usageJSON, _ = sjson.Set(usageJSON, "thoughtsTokenCount", thinkingTokens.Int())
usageJSON, _ = sjson.SetBytes(usageJSON, "thoughtsTokenCount", thinkingTokens.Int())
}
// Set traffic type (required by Gemini API)
usageJSON, _ = sjson.Set(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT")
usageJSON, _ = sjson.SetBytes(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT")
finalUsageJSON = usageJSON
}
@@ -459,10 +459,10 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
// Set response metadata
if responseID != "" {
template, _ = sjson.Set(template, "responseId", responseID)
template, _ = sjson.SetBytes(template, "responseId", responseID)
}
if createdAt > 0 {
template, _ = sjson.Set(template, "createTime", time.Unix(createdAt, 0).Format(time.RFC3339Nano))
template, _ = sjson.SetBytes(template, "createTime", time.Unix(createdAt, 0).Format(time.RFC3339Nano))
}
// Consolidate consecutive text parts and thinking parts for cleaner output
@@ -470,35 +470,35 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
// Set the consolidated parts array
if len(consolidatedParts) > 0 {
partsJSON := "[]"
partsJSON := []byte(`[]`)
for _, partJSON := range consolidatedParts {
partsJSON, _ = sjson.SetRaw(partsJSON, "-1", partJSON)
partsJSON, _ = sjson.SetRawBytes(partsJSON, "-1", partJSON)
}
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", partsJSON)
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts", partsJSON)
}
// Set usage metadata
if finalUsageJSON != "" {
template, _ = sjson.SetRaw(template, "usageMetadata", finalUsageJSON)
if len(finalUsageJSON) > 0 {
template, _ = sjson.SetRawBytes(template, "usageMetadata", finalUsageJSON)
}
return template
}
func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
func GeminiTokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.GeminiTokenCountJSON(count)
}
// consolidateParts merges consecutive text parts and thinking parts to create a cleaner response.
// This function processes the parts array to combine adjacent text elements and thinking elements
// into single consolidated parts, which results in a more readable and efficient response structure.
// Tool calls and other non-text parts are preserved as separate elements.
func consolidateParts(parts []string) []string {
func consolidateParts(parts [][]byte) [][]byte {
if len(parts) == 0 {
return parts
}
var consolidated []string
var consolidated [][]byte
var currentTextPart strings.Builder
var currentThoughtPart strings.Builder
var hasText, hasThought bool
@@ -506,8 +506,8 @@ func consolidateParts(parts []string) []string {
flushText := func() {
// Flush accumulated text content to the consolidated parts array
if hasText && currentTextPart.Len() > 0 {
textPartJSON := `{"text":""}`
textPartJSON, _ = sjson.Set(textPartJSON, "text", currentTextPart.String())
textPartJSON := []byte(`{"text":""}`)
textPartJSON, _ = sjson.SetBytes(textPartJSON, "text", currentTextPart.String())
consolidated = append(consolidated, textPartJSON)
currentTextPart.Reset()
hasText = false
@@ -517,8 +517,8 @@ func consolidateParts(parts []string) []string {
flushThought := func() {
// Flush accumulated thinking content to the consolidated parts array
if hasThought && currentThoughtPart.Len() > 0 {
thoughtPartJSON := `{"thought":true,"text":""}`
thoughtPartJSON, _ = sjson.Set(thoughtPartJSON, "text", currentThoughtPart.String())
thoughtPartJSON := []byte(`{"thought":true,"text":""}`)
thoughtPartJSON, _ = sjson.SetBytes(thoughtPartJSON, "text", currentThoughtPart.String())
consolidated = append(consolidated, thoughtPartJSON)
currentThoughtPart.Reset()
hasThought = false
@@ -526,7 +526,7 @@ func consolidateParts(parts []string) []string {
}
for _, partJSON := range parts {
part := gjson.Parse(partJSON)
part := gjson.ParseBytes(partJSON)
if !part.Exists() || !part.IsObject() {
// Flush any pending parts and add this non-text part
flushText()

View File

@@ -61,7 +61,7 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
// Base Claude Code API template with default max_tokens value
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID)
out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
root := gjson.ParseBytes(rawJSON)
@@ -79,20 +79,20 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
if supportsAdaptive {
switch effort {
case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.DeleteBytes(out, "output_config.effort")
case "auto":
out, _ = sjson.Set(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort")
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.DeleteBytes(out, "output_config.effort")
default:
if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
effort = mapped
}
out, _ = sjson.Set(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Set(out, "output_config.effort", effort)
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "output_config.effort", effort)
}
} else {
// Legacy/manual thinking (budget_tokens).
@@ -100,13 +100,13 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
if ok {
switch budget {
case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
case -1:
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
default:
if budget > 0 {
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
}
}
}
@@ -128,19 +128,19 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
}
// Model mapping to specify which Claude Code model to use
out, _ = sjson.Set(out, "model", modelName)
out, _ = sjson.SetBytes(out, "model", modelName)
// Max tokens configuration with fallback to default value
if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
out, _ = sjson.SetBytes(out, "max_tokens", maxTokens.Int())
}
// Temperature setting for controlling response randomness
if temp := root.Get("temperature"); temp.Exists() {
out, _ = sjson.Set(out, "temperature", temp.Float())
out, _ = sjson.SetBytes(out, "temperature", temp.Float())
} else if topP := root.Get("top_p"); topP.Exists() {
// Top P setting for nucleus sampling (filtered out if temperature is set)
out, _ = sjson.Set(out, "top_p", topP.Float())
out, _ = sjson.SetBytes(out, "top_p", topP.Float())
}
// Stop sequences configuration for custom termination conditions
@@ -152,15 +152,15 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
return true
})
if len(stopSequences) > 0 {
out, _ = sjson.Set(out, "stop_sequences", stopSequences)
out, _ = sjson.SetBytes(out, "stop_sequences", stopSequences)
}
} else {
out, _ = sjson.Set(out, "stop_sequences", []string{stop.String()})
out, _ = sjson.SetBytes(out, "stop_sequences", []string{stop.String()})
}
}
// Stream configuration to enable or disable streaming responses
out, _ = sjson.Set(out, "stream", stream)
out, _ = sjson.SetBytes(out, "stream", stream)
// Process messages and transform them to Claude Code format
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
@@ -173,39 +173,39 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
switch role {
case "system":
if systemMessageIndex == -1 {
systemMsg := `{"role":"user","content":[]}`
out, _ = sjson.SetRaw(out, "messages.-1", systemMsg)
systemMsg := []byte(`{"role":"user","content":[]}`)
out, _ = sjson.SetRawBytes(out, "messages.-1", systemMsg)
systemMessageIndex = messageIndex
messageIndex++
}
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", contentResult.String())
out, _ = sjson.SetRaw(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.SetBytes(textPart, "text", contentResult.String())
out, _ = sjson.SetRawBytes(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
} else if contentResult.Exists() && contentResult.IsArray() {
contentResult.ForEach(func(_, part gjson.Result) bool {
if part.Get("type").String() == "text" {
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
out, _ = sjson.SetRaw(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.SetBytes(textPart, "text", part.Get("text").String())
out, _ = sjson.SetRawBytes(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
}
return true
})
}
case "user", "assistant":
msg := `{"role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
msg := []byte(`{"role":"","content":[]}`)
msg, _ = sjson.SetBytes(msg, "role", role)
// Handle content based on its type (string or array)
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
part := `{"type":"text","text":""}`
part, _ = sjson.Set(part, "text", contentResult.String())
msg, _ = sjson.SetRaw(msg, "content.-1", part)
part := []byte(`{"type":"text","text":""}`)
part, _ = sjson.SetBytes(part, "text", contentResult.String())
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
} else if contentResult.Exists() && contentResult.IsArray() {
contentResult.ForEach(func(_, part gjson.Result) bool {
claudePart := convertOpenAIContentPartToClaudePart(part)
if claudePart != "" {
msg, _ = sjson.SetRaw(msg, "content.-1", claudePart)
msg, _ = sjson.SetRawBytes(msg, "content.-1", []byte(claudePart))
}
return true
})
@@ -221,9 +221,9 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
}
function := toolCall.Get("function")
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUse, _ = sjson.Set(toolUse, "id", toolCallID)
toolUse, _ = sjson.Set(toolUse, "name", function.Get("name").String())
toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolUse, _ = sjson.SetBytes(toolUse, "id", toolCallID)
toolUse, _ = sjson.SetBytes(toolUse, "name", function.Get("name").String())
// Parse arguments for the tool call
if args := function.Get("arguments"); args.Exists() {
@@ -231,24 +231,24 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
if argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(argsJSON.Raw))
} else {
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
}
} else {
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
}
} else {
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
}
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
msg, _ = sjson.SetRawBytes(msg, "content.-1", toolUse)
}
return true
})
}
out, _ = sjson.SetRaw(out, "messages.-1", msg)
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
messageIndex++
case "tool":
@@ -256,15 +256,15 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
toolCallID := message.Get("tool_call_id").String()
toolContentResult := message.Get("content")
msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
msg := []byte(`{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`)
msg, _ = sjson.SetBytes(msg, "content.0.tool_use_id", toolCallID)
toolResultContent, toolResultContentRaw := convertOpenAIToolResultContent(toolContentResult)
if toolResultContentRaw {
msg, _ = sjson.SetRaw(msg, "content.0.content", toolResultContent)
msg, _ = sjson.SetRawBytes(msg, "content.0.content", []byte(toolResultContent))
} else {
msg, _ = sjson.Set(msg, "content.0.content", toolResultContent)
msg, _ = sjson.SetBytes(msg, "content.0.content", toolResultContent)
}
out, _ = sjson.SetRaw(out, "messages.-1", msg)
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
messageIndex++
}
return true
@@ -277,25 +277,25 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
tools.ForEach(func(_, tool gjson.Result) bool {
if tool.Get("type").String() == "function" {
function := tool.Get("function")
anthropicTool := `{"name":"","description":""}`
anthropicTool, _ = sjson.Set(anthropicTool, "name", function.Get("name").String())
anthropicTool, _ = sjson.Set(anthropicTool, "description", function.Get("description").String())
anthropicTool := []byte(`{"name":"","description":""}`)
anthropicTool, _ = sjson.SetBytes(anthropicTool, "name", function.Get("name").String())
anthropicTool, _ = sjson.SetBytes(anthropicTool, "description", function.Get("description").String())
// Convert parameters schema for the tool
if parameters := function.Get("parameters"); parameters.Exists() {
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", []byte(parameters.Raw))
} else if parameters := function.Get("parametersJsonSchema"); parameters.Exists() {
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", []byte(parameters.Raw))
}
out, _ = sjson.SetRaw(out, "tools.-1", anthropicTool)
out, _ = sjson.SetRawBytes(out, "tools.-1", anthropicTool)
hasAnthropicTools = true
}
return true
})
if !hasAnthropicTools {
out, _ = sjson.Delete(out, "tools")
out, _ = sjson.DeleteBytes(out, "tools")
}
}
@@ -308,31 +308,31 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
case "none":
// Don't set tool_choice, Claude Code will not use tools
case "auto":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
case "required":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
}
case gjson.JSON:
// Specific tool choice mapping
if toolChoice.Get("type").String() == "function" {
functionName := toolChoice.Get("function.name").String()
toolChoiceJSON := `{"type":"tool","name":""}`
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", functionName)
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
toolChoiceJSON := []byte(`{"type":"tool","name":""}`)
toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", functionName)
out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
}
default:
}
}
return []byte(out)
return out
}
func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
switch part.Get("type").String() {
case "text":
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
return textPart
textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.SetBytes(textPart, "text", part.Get("text").String())
return string(textPart)
case "image_url":
return convertOpenAIImageURLToClaudePart(part.Get("image_url.url").String())
@@ -345,10 +345,10 @@ func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
data := fileData[commaIdx+1:]
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
docPart, _ = sjson.Set(docPart, "source.data", data)
return docPart
docPart := []byte(`{"type":"document","source":{"type":"base64","media_type":"","data":""}}`)
docPart, _ = sjson.SetBytes(docPart, "source.media_type", mediaType)
docPart, _ = sjson.SetBytes(docPart, "source.data", data)
return string(docPart)
}
}
}
@@ -373,15 +373,15 @@ func convertOpenAIImageURLToClaudePart(imageURL string) string {
mediaType = "application/octet-stream"
}
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
imagePart, _ = sjson.Set(imagePart, "source.data", parts[1])
return imagePart
imagePart := []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
imagePart, _ = sjson.SetBytes(imagePart, "source.media_type", mediaType)
imagePart, _ = sjson.SetBytes(imagePart, "source.data", parts[1])
return string(imagePart)
}
imagePart := `{"type":"image","source":{"type":"url","url":""}}`
imagePart, _ = sjson.Set(imagePart, "source.url", imageURL)
return imagePart
imagePart := []byte(`{"type":"image","source":{"type":"url","url":""}}`)
imagePart, _ = sjson.SetBytes(imagePart, "source.url", imageURL)
return string(imagePart)
}
func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
@@ -394,28 +394,28 @@ func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
}
if content.IsArray() {
claudeContent := "[]"
claudeContent := []byte("[]")
partCount := 0
content.ForEach(func(_, part gjson.Result) bool {
if part.Type == gjson.String {
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.String())
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", textPart)
textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.SetBytes(textPart, "text", part.String())
claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", textPart)
partCount++
return true
}
claudePart := convertOpenAIContentPartToClaudePart(part)
if claudePart != "" {
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", []byte(claudePart))
partCount++
}
return true
})
if partCount > 0 || len(content.Array()) == 0 {
return claudeContent, true
return string(claudeContent), true
}
return content.Raw, false
@@ -424,9 +424,9 @@ func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
if content.IsObject() {
claudePart := convertOpenAIContentPartToClaudePart(content)
if claudePart != "" {
claudeContent := "[]"
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
return claudeContent, true
claudeContent := []byte("[]")
claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", []byte(claudePart))
return string(claudeContent), true
}
return content.Raw, false
}

View File

@@ -48,8 +48,8 @@ type ToolCallAccumulator struct {
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of OpenAI-compatible JSON responses
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
var localParam any
if param == nil {
param = &localParam
@@ -63,7 +63,7 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
}
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
return [][]byte{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
@@ -71,19 +71,19 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
eventType := root.Get("type").String()
// Base OpenAI streaming response template
template := `{"id":"","object":"chat.completion.chunk","created":0,"model":"","choices":[{"index":0,"delta":{},"finish_reason":null}]}`
template := []byte(`{"id":"","object":"chat.completion.chunk","created":0,"model":"","choices":[{"index":0,"delta":{},"finish_reason":null}]}`)
// Set model
if modelName != "" {
template, _ = sjson.Set(template, "model", modelName)
template, _ = sjson.SetBytes(template, "model", modelName)
}
// Set response ID and creation time
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID != "" {
template, _ = sjson.Set(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
}
if (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt > 0 {
template, _ = sjson.Set(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
}
switch eventType {
@@ -93,19 +93,19 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
(*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID = message.Get("id").String()
(*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt = time.Now().Unix()
template, _ = sjson.Set(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
template, _ = sjson.Set(template, "model", modelName)
template, _ = sjson.Set(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
template, _ = sjson.SetBytes(template, "model", modelName)
template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
// Set initial role to assistant for the response
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
// Initialize tool calls accumulator for tracking tool call progress
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator == nil {
(*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
}
}
return []string{template}
return [][]byte{template}
case "content_block_start":
// Start of a content block (text, tool use, or reasoning)
@@ -128,10 +128,10 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
}
// Don't output anything yet - wait for complete tool call
return []string{}
return [][]byte{}
}
}
return []string{}
return [][]byte{}
case "content_block_delta":
// Handle content delta (text, tool use arguments, or reasoning content)
@@ -143,13 +143,13 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
case "text_delta":
// Text content delta - send incremental text updates
if text := delta.Get("text"); text.Exists() {
template, _ = sjson.Set(template, "choices.0.delta.content", text.String())
template, _ = sjson.SetBytes(template, "choices.0.delta.content", text.String())
hasContent = true
}
case "thinking_delta":
// Accumulate reasoning/thinking content
if thinking := delta.Get("thinking"); thinking.Exists() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", thinking.String())
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", thinking.String())
hasContent = true
}
case "input_json_delta":
@@ -163,13 +163,13 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
}
}
// Don't output anything yet - wait for complete tool call
return []string{}
return [][]byte{}
}
}
if hasContent {
return []string{template}
return [][]byte{template}
} else {
return []string{}
return [][]byte{}
}
case "content_block_stop":
@@ -182,26 +182,26 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
if arguments == "" {
arguments = "{}"
}
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.index", index)
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.id", accumulator.ID)
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.type", "function")
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name)
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.arguments", arguments)
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.index", index)
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.id", accumulator.ID)
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.type", "function")
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name)
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.function.arguments", arguments)
// Clean up the accumulator for this index
delete((*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator, index)
return []string{template}
return [][]byte{template}
}
}
return []string{}
return [][]byte{}
case "message_delta":
// Handle message-level changes including stop reason and usage
if delta := root.Get("delta"); delta.Exists() {
if stopReason := delta.Get("stop_reason"); stopReason.Exists() {
(*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason = mapAnthropicStopReasonToOpenAI(stopReason.String())
template, _ = sjson.Set(template, "choices.0.finish_reason", (*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason)
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", (*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason)
}
}
@@ -211,34 +211,34 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
outputTokens := usage.Get("output_tokens").Int()
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokens)
template, _ = sjson.SetBytes(template, "usage.total_tokens", inputTokens+outputTokens)
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
}
return []string{template}
return [][]byte{template}
case "message_stop":
// Final message event - no additional output needed
return []string{}
return [][]byte{}
case "ping":
// Ping events for keeping connection alive - no output needed
return []string{}
return [][]byte{}
case "error":
// Error event - format and return error response
if errorData := root.Get("error"); errorData.Exists() {
errorJSON := `{"error":{"message":"","type":""}}`
errorJSON, _ = sjson.Set(errorJSON, "error.message", errorData.Get("message").String())
errorJSON, _ = sjson.Set(errorJSON, "error.type", errorData.Get("type").String())
return []string{errorJSON}
errorJSON := []byte(`{"error":{"message":"","type":""}}`)
errorJSON, _ = sjson.SetBytes(errorJSON, "error.message", errorData.Get("message").String())
errorJSON, _ = sjson.SetBytes(errorJSON, "error.type", errorData.Get("type").String())
return [][]byte{errorJSON}
}
return []string{}
return [][]byte{}
default:
// Unknown event type - ignore
return []string{}
return [][]byte{}
}
}
@@ -270,8 +270,8 @@ func mapAnthropicStopReasonToOpenAI(anthropicReason string) string {
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
chunks := make([][]byte, 0)
lines := bytes.Split(rawJSON, []byte("\n"))
@@ -283,7 +283,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
}
// Base OpenAI non-streaming response template
out := `{"id":"","object":"chat.completion","created":0,"model":"","choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
out := []byte(`{"id":"","object":"chat.completion","created":0,"model":"","choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`)
var messageID string
var model string
@@ -370,28 +370,28 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
outputTokens := usage.Get("output_tokens").Int()
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", inputTokens+outputTokens)
out, _ = sjson.Set(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
out, _ = sjson.SetBytes(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.SetBytes(out, "usage.total_tokens", inputTokens+outputTokens)
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
}
}
}
// Set basic response fields including message ID, creation time, and model
out, _ = sjson.Set(out, "id", messageID)
out, _ = sjson.Set(out, "created", createdAt)
out, _ = sjson.Set(out, "model", model)
out, _ = sjson.SetBytes(out, "id", messageID)
out, _ = sjson.SetBytes(out, "created", createdAt)
out, _ = sjson.SetBytes(out, "model", model)
// Set message content by combining all text parts
messageContent := strings.Join(contentParts, "")
out, _ = sjson.Set(out, "choices.0.message.content", messageContent)
out, _ = sjson.SetBytes(out, "choices.0.message.content", messageContent)
// Add reasoning content if available (following OpenAI reasoning format)
if len(reasoningParts) > 0 {
reasoningContent := strings.Join(reasoningParts, "")
// Add reasoning as a separate field in the message
out, _ = sjson.Set(out, "choices.0.message.reasoning", reasoningContent)
out, _ = sjson.SetBytes(out, "choices.0.message.reasoning", reasoningContent)
}
// Set tool calls if any were accumulated during processing
@@ -417,19 +417,19 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
namePath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.name", toolCallsCount)
argumentsPath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.arguments", toolCallsCount)
out, _ = sjson.Set(out, idPath, accumulator.ID)
out, _ = sjson.Set(out, typePath, "function")
out, _ = sjson.Set(out, namePath, accumulator.Name)
out, _ = sjson.Set(out, argumentsPath, arguments)
out, _ = sjson.SetBytes(out, idPath, accumulator.ID)
out, _ = sjson.SetBytes(out, typePath, "function")
out, _ = sjson.SetBytes(out, namePath, accumulator.Name)
out, _ = sjson.SetBytes(out, argumentsPath, arguments)
toolCallsCount++
}
if toolCallsCount > 0 {
out, _ = sjson.Set(out, "choices.0.finish_reason", "tool_calls")
out, _ = sjson.SetBytes(out, "choices.0.finish_reason", "tool_calls")
} else {
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
out, _ = sjson.SetBytes(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
}
} else {
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
out, _ = sjson.SetBytes(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
}
return out

View File

@@ -49,7 +49,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
// Base Claude message payload
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID)
out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
root := gjson.ParseBytes(rawJSON)
@@ -67,20 +67,20 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if supportsAdaptive {
switch effort {
case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.DeleteBytes(out, "output_config.effort")
case "auto":
out, _ = sjson.Set(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort")
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.DeleteBytes(out, "output_config.effort")
default:
if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
effort = mapped
}
out, _ = sjson.Set(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Set(out, "output_config.effort", effort)
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "output_config.effort", effort)
}
} else {
// Legacy/manual thinking (budget_tokens).
@@ -88,13 +88,13 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if ok {
switch budget {
case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
case -1:
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
default:
if budget > 0 {
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
}
}
}
@@ -114,15 +114,15 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
}
// Model
out, _ = sjson.Set(out, "model", modelName)
out, _ = sjson.SetBytes(out, "model", modelName)
// Max tokens
if mot := root.Get("max_output_tokens"); mot.Exists() {
out, _ = sjson.Set(out, "max_tokens", mot.Int())
out, _ = sjson.SetBytes(out, "max_tokens", mot.Int())
}
// Stream
out, _ = sjson.Set(out, "stream", stream)
out, _ = sjson.SetBytes(out, "stream", stream)
// instructions -> as a leading message (use role user for Claude API compatibility)
instructionsText := ""
@@ -130,9 +130,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if instr := root.Get("instructions"); instr.Exists() && instr.Type == gjson.String {
instructionsText = instr.String()
if instructionsText != "" {
sysMsg := `{"role":"user","content":""}`
sysMsg, _ = sjson.Set(sysMsg, "content", instructionsText)
out, _ = sjson.SetRaw(out, "messages.-1", sysMsg)
sysMsg := []byte(`{"role":"user","content":""}`)
sysMsg, _ = sjson.SetBytes(sysMsg, "content", instructionsText)
out, _ = sjson.SetRawBytes(out, "messages.-1", sysMsg)
}
}
@@ -156,9 +156,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
}
instructionsText = builder.String()
if instructionsText != "" {
sysMsg := `{"role":"user","content":""}`
sysMsg, _ = sjson.Set(sysMsg, "content", instructionsText)
out, _ = sjson.SetRaw(out, "messages.-1", sysMsg)
sysMsg := []byte(`{"role":"user","content":""}`)
sysMsg, _ = sjson.SetBytes(sysMsg, "content", instructionsText)
out, _ = sjson.SetRawBytes(out, "messages.-1", sysMsg)
extractedFromSystem = true
}
}
@@ -193,9 +193,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if t := part.Get("text"); t.Exists() {
txt := t.String()
textAggregate.WriteString(txt)
contentPart := `{"type":"text","text":""}`
contentPart, _ = sjson.Set(contentPart, "text", txt)
partsJSON = append(partsJSON, contentPart)
contentPart := []byte(`{"type":"text","text":""}`)
contentPart, _ = sjson.SetBytes(contentPart, "text", txt)
partsJSON = append(partsJSON, string(contentPart))
}
if ptype == "input_text" {
role = "user"
@@ -208,7 +208,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
url = part.Get("url").String()
}
if url != "" {
var contentPart string
var contentPart []byte
if strings.HasPrefix(url, "data:") {
trimmed := strings.TrimPrefix(url, "data:")
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
@@ -221,16 +221,16 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
data = mediaAndData[1]
}
if data != "" {
contentPart = `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.Set(contentPart, "source.data", data)
contentPart = []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
contentPart, _ = sjson.SetBytes(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.SetBytes(contentPart, "source.data", data)
}
} else {
contentPart = `{"type":"image","source":{"type":"url","url":""}}`
contentPart, _ = sjson.Set(contentPart, "source.url", url)
contentPart = []byte(`{"type":"image","source":{"type":"url","url":""}}`)
contentPart, _ = sjson.SetBytes(contentPart, "source.url", url)
}
if contentPart != "" {
partsJSON = append(partsJSON, contentPart)
if len(contentPart) > 0 {
partsJSON = append(partsJSON, string(contentPart))
if role == "" {
role = "user"
}
@@ -252,10 +252,10 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
data = mediaAndData[1]
}
}
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.Set(contentPart, "source.data", data)
partsJSON = append(partsJSON, contentPart)
contentPart := []byte(`{"type":"document","source":{"type":"base64","media_type":"","data":""}}`)
contentPart, _ = sjson.SetBytes(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.SetBytes(contentPart, "source.data", data)
partsJSON = append(partsJSON, string(contentPart))
if role == "" {
role = "user"
}
@@ -280,24 +280,24 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
}
if len(partsJSON) > 0 {
msg := `{"role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
msg := []byte(`{"role":"","content":[]}`)
msg, _ = sjson.SetBytes(msg, "role", role)
if len(partsJSON) == 1 && !hasImage && !hasFile {
// Preserve legacy behavior for single text content
msg, _ = sjson.Delete(msg, "content")
msg, _ = sjson.DeleteBytes(msg, "content")
textPart := gjson.Parse(partsJSON[0])
msg, _ = sjson.Set(msg, "content", textPart.Get("text").String())
msg, _ = sjson.SetBytes(msg, "content", textPart.Get("text").String())
} else {
for _, partJSON := range partsJSON {
msg, _ = sjson.SetRaw(msg, "content.-1", partJSON)
msg, _ = sjson.SetRawBytes(msg, "content.-1", []byte(partJSON))
}
}
out, _ = sjson.SetRaw(out, "messages.-1", msg)
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
} else if textAggregate.Len() > 0 || role == "system" {
msg := `{"role":"","content":""}`
msg, _ = sjson.Set(msg, "role", role)
msg, _ = sjson.Set(msg, "content", textAggregate.String())
out, _ = sjson.SetRaw(out, "messages.-1", msg)
msg := []byte(`{"role":"","content":""}`)
msg, _ = sjson.SetBytes(msg, "role", role)
msg, _ = sjson.SetBytes(msg, "content", textAggregate.String())
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
}
case "function_call":
@@ -309,31 +309,31 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
name := item.Get("name").String()
argsStr := item.Get("arguments").String()
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUse, _ = sjson.Set(toolUse, "id", callID)
toolUse, _ = sjson.Set(toolUse, "name", name)
toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolUse, _ = sjson.SetBytes(toolUse, "id", callID)
toolUse, _ = sjson.SetBytes(toolUse, "name", name)
if argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(argsJSON.Raw))
}
}
asst := `{"role":"assistant","content":[]}`
asst, _ = sjson.SetRaw(asst, "content.-1", toolUse)
out, _ = sjson.SetRaw(out, "messages.-1", asst)
asst := []byte(`{"role":"assistant","content":[]}`)
asst, _ = sjson.SetRawBytes(asst, "content.-1", toolUse)
out, _ = sjson.SetRawBytes(out, "messages.-1", asst)
case "function_call_output":
// Map to user tool_result
callID := item.Get("call_id").String()
outputStr := item.Get("output").String()
toolResult := `{"type":"tool_result","tool_use_id":"","content":""}`
toolResult, _ = sjson.Set(toolResult, "tool_use_id", callID)
toolResult, _ = sjson.Set(toolResult, "content", outputStr)
toolResult := []byte(`{"type":"tool_result","tool_use_id":"","content":""}`)
toolResult, _ = sjson.SetBytes(toolResult, "tool_use_id", callID)
toolResult, _ = sjson.SetBytes(toolResult, "content", outputStr)
usr := `{"role":"user","content":[]}`
usr, _ = sjson.SetRaw(usr, "content.-1", toolResult)
out, _ = sjson.SetRaw(out, "messages.-1", usr)
usr := []byte(`{"role":"user","content":[]}`)
usr, _ = sjson.SetRawBytes(usr, "content.-1", toolResult)
out, _ = sjson.SetRawBytes(out, "messages.-1", usr)
}
return true
})
@@ -341,27 +341,27 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
// tools mapping: parameters -> input_schema
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
toolsJSON := "[]"
toolsJSON := []byte("[]")
tools.ForEach(func(_, tool gjson.Result) bool {
tJSON := `{"name":"","description":"","input_schema":{}}`
tJSON := []byte(`{"name":"","description":"","input_schema":{}}`)
if n := tool.Get("name"); n.Exists() {
tJSON, _ = sjson.Set(tJSON, "name", n.String())
tJSON, _ = sjson.SetBytes(tJSON, "name", n.String())
}
if d := tool.Get("description"); d.Exists() {
tJSON, _ = sjson.Set(tJSON, "description", d.String())
tJSON, _ = sjson.SetBytes(tJSON, "description", d.String())
}
if params := tool.Get("parameters"); params.Exists() {
tJSON, _ = sjson.SetRaw(tJSON, "input_schema", params.Raw)
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
} else if params = tool.Get("parametersJsonSchema"); params.Exists() {
tJSON, _ = sjson.SetRaw(tJSON, "input_schema", params.Raw)
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
}
toolsJSON, _ = sjson.SetRaw(toolsJSON, "-1", tJSON)
toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "-1", tJSON)
return true
})
if gjson.Parse(toolsJSON).IsArray() && len(gjson.Parse(toolsJSON).Array()) > 0 {
out, _ = sjson.SetRaw(out, "tools", toolsJSON)
if parsedTools := gjson.ParseBytes(toolsJSON); parsedTools.IsArray() && len(parsedTools.Array()) > 0 {
out, _ = sjson.SetRawBytes(out, "tools", toolsJSON)
}
}
@@ -371,23 +371,23 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
case gjson.String:
switch toolChoice.String() {
case "auto":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
case "none":
// Leave unset; implies no tools
case "required":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
}
case gjson.JSON:
if toolChoice.Get("type").String() == "function" {
fn := toolChoice.Get("function.name").String()
toolChoiceJSON := `{"name":"","type":"tool"}`
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", fn)
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
toolChoiceJSON := []byte(`{"name":"","type":"tool"}`)
toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", fn)
out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
}
default:
}
}
return []byte(out)
return out
}

View File

@@ -8,6 +8,7 @@ import (
"strings"
"time"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -50,12 +51,12 @@ func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte {
return nil
}
func emitEvent(event string, payload string) string {
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
func emitEvent(event string, payload []byte) []byte {
return translatorcommon.SSEEventData(event, payload)
}
// ConvertClaudeResponseToOpenAIResponses converts Claude SSE to OpenAI Responses SSE events.
func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &claudeToResponsesState{FuncArgsBuf: make(map[int]*strings.Builder), FuncNames: make(map[int]string), FuncCallIDs: make(map[int]string)}
}
@@ -63,12 +64,12 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
// Expect `data: {..}` from Claude clients
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
return [][]byte{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
root := gjson.ParseBytes(rawJSON)
ev := root.Get("type").String()
var out []string
var out [][]byte
nextSeq := func() int { st.Seq++; return st.Seq }
@@ -105,16 +106,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
}
}
// response.created
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
created, _ = sjson.Set(created, "sequence_number", nextSeq())
created, _ = sjson.Set(created, "response.id", st.ResponseID)
created, _ = sjson.Set(created, "response.created_at", st.CreatedAt)
created := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
created, _ = sjson.SetBytes(created, "sequence_number", nextSeq())
created, _ = sjson.SetBytes(created, "response.id", st.ResponseID)
created, _ = sjson.SetBytes(created, "response.created_at", st.CreatedAt)
out = append(out, emitEvent("response.created", created))
// response.in_progress
inprog := `{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}`
inprog, _ = sjson.Set(inprog, "sequence_number", nextSeq())
inprog, _ = sjson.Set(inprog, "response.id", st.ResponseID)
inprog, _ = sjson.Set(inprog, "response.created_at", st.CreatedAt)
inprog := []byte(`{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}`)
inprog, _ = sjson.SetBytes(inprog, "sequence_number", nextSeq())
inprog, _ = sjson.SetBytes(inprog, "response.id", st.ResponseID)
inprog, _ = sjson.SetBytes(inprog, "response.created_at", st.CreatedAt)
out = append(out, emitEvent("response.in_progress", inprog))
}
case "content_block_start":
@@ -128,25 +129,25 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
// open message item + content part
st.InTextBlock = true
st.CurrentMsgID = fmt.Sprintf("msg_%s_0", st.ResponseID)
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`
item, _ = sjson.Set(item, "sequence_number", nextSeq())
item, _ = sjson.Set(item, "item.id", st.CurrentMsgID)
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
item, _ = sjson.SetBytes(item, "item.id", st.CurrentMsgID)
out = append(out, emitEvent("response.output_item.added", item))
part := `{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
part, _ = sjson.Set(part, "sequence_number", nextSeq())
part, _ = sjson.Set(part, "item_id", st.CurrentMsgID)
part := []byte(`{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
part, _ = sjson.SetBytes(part, "item_id", st.CurrentMsgID)
out = append(out, emitEvent("response.content_part.added", part))
} else if typ == "tool_use" {
st.InFuncBlock = true
st.CurrentFCID = cb.Get("id").String()
name := cb.Get("name").String()
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`
item, _ = sjson.Set(item, "sequence_number", nextSeq())
item, _ = sjson.Set(item, "output_index", idx)
item, _ = sjson.Set(item, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
item, _ = sjson.Set(item, "item.call_id", st.CurrentFCID)
item, _ = sjson.Set(item, "item.name", name)
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
item, _ = sjson.SetBytes(item, "output_index", idx)
item, _ = sjson.SetBytes(item, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
item, _ = sjson.SetBytes(item, "item.call_id", st.CurrentFCID)
item, _ = sjson.SetBytes(item, "item.name", name)
out = append(out, emitEvent("response.output_item.added", item))
if st.FuncArgsBuf[idx] == nil {
st.FuncArgsBuf[idx] = &strings.Builder{}
@@ -160,16 +161,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
st.ReasoningIndex = idx
st.ReasoningBuf.Reset()
st.ReasoningItemID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`
item, _ = sjson.Set(item, "sequence_number", nextSeq())
item, _ = sjson.Set(item, "output_index", idx)
item, _ = sjson.Set(item, "item.id", st.ReasoningItemID)
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`)
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
item, _ = sjson.SetBytes(item, "output_index", idx)
item, _ = sjson.SetBytes(item, "item.id", st.ReasoningItemID)
out = append(out, emitEvent("response.output_item.added", item))
// add a summary part placeholder
part := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
part, _ = sjson.Set(part, "sequence_number", nextSeq())
part, _ = sjson.Set(part, "item_id", st.ReasoningItemID)
part, _ = sjson.Set(part, "output_index", idx)
part := []byte(`{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
part, _ = sjson.SetBytes(part, "item_id", st.ReasoningItemID)
part, _ = sjson.SetBytes(part, "output_index", idx)
out = append(out, emitEvent("response.reasoning_summary_part.added", part))
st.ReasoningPartAdded = true
}
@@ -181,10 +182,10 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
dt := d.Get("type").String()
if dt == "text_delta" {
if t := d.Get("text"); t.Exists() {
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", st.CurrentMsgID)
msg, _ = sjson.Set(msg, "delta", t.String())
msg := []byte(`{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`)
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
msg, _ = sjson.SetBytes(msg, "item_id", st.CurrentMsgID)
msg, _ = sjson.SetBytes(msg, "delta", t.String())
out = append(out, emitEvent("response.output_text.delta", msg))
// aggregate text for response.output
st.TextBuf.WriteString(t.String())
@@ -196,22 +197,22 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
st.FuncArgsBuf[idx] = &strings.Builder{}
}
st.FuncArgsBuf[idx].WriteString(pj.String())
msg := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
msg, _ = sjson.Set(msg, "output_index", idx)
msg, _ = sjson.Set(msg, "delta", pj.String())
msg := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
msg, _ = sjson.SetBytes(msg, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
msg, _ = sjson.SetBytes(msg, "output_index", idx)
msg, _ = sjson.SetBytes(msg, "delta", pj.String())
out = append(out, emitEvent("response.function_call_arguments.delta", msg))
}
} else if dt == "thinking_delta" {
if st.ReasoningActive {
if t := d.Get("thinking"); t.Exists() {
st.ReasoningBuf.WriteString(t.String())
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID)
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
msg, _ = sjson.Set(msg, "delta", t.String())
msg := []byte(`{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`)
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
msg, _ = sjson.SetBytes(msg, "item_id", st.ReasoningItemID)
msg, _ = sjson.SetBytes(msg, "output_index", st.ReasoningIndex)
msg, _ = sjson.SetBytes(msg, "delta", t.String())
out = append(out, emitEvent("response.reasoning_summary_text.delta", msg))
}
}
@@ -219,17 +220,17 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
case "content_block_stop":
idx := int(root.Get("index").Int())
if st.InTextBlock {
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
done, _ = sjson.Set(done, "sequence_number", nextSeq())
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
done, _ = sjson.SetBytes(done, "item_id", st.CurrentMsgID)
out = append(out, emitEvent("response.output_text.done", done))
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.SetBytes(partDone, "item_id", st.CurrentMsgID)
out = append(out, emitEvent("response.content_part.done", partDone))
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
final, _ = sjson.Set(final, "sequence_number", nextSeq())
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
final := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`)
final, _ = sjson.SetBytes(final, "sequence_number", nextSeq())
final, _ = sjson.SetBytes(final, "item.id", st.CurrentMsgID)
out = append(out, emitEvent("response.output_item.done", final))
st.InTextBlock = false
} else if st.InFuncBlock {
@@ -239,34 +240,34 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
args = buf.String()
}
}
fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`
fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq())
fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
fcDone, _ = sjson.Set(fcDone, "output_index", idx)
fcDone, _ = sjson.Set(fcDone, "arguments", args)
fcDone := []byte(`{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`)
fcDone, _ = sjson.SetBytes(fcDone, "sequence_number", nextSeq())
fcDone, _ = sjson.SetBytes(fcDone, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
fcDone, _ = sjson.SetBytes(fcDone, "output_index", idx)
fcDone, _ = sjson.SetBytes(fcDone, "arguments", args)
out = append(out, emitEvent("response.function_call_arguments.done", fcDone))
itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.Set(itemDone, "output_index", idx)
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
itemDone, _ = sjson.Set(itemDone, "item.arguments", args)
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.CurrentFCID)
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`)
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.SetBytes(itemDone, "output_index", idx)
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
itemDone, _ = sjson.SetBytes(itemDone, "item.arguments", args)
itemDone, _ = sjson.SetBytes(itemDone, "item.call_id", st.CurrentFCID)
itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[idx])
out = append(out, emitEvent("response.output_item.done", itemDone))
st.InFuncBlock = false
} else if st.ReasoningActive {
full := st.ReasoningBuf.String()
textDone := `{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
textDone, _ = sjson.Set(textDone, "sequence_number", nextSeq())
textDone, _ = sjson.Set(textDone, "item_id", st.ReasoningItemID)
textDone, _ = sjson.Set(textDone, "output_index", st.ReasoningIndex)
textDone, _ = sjson.Set(textDone, "text", full)
textDone := []byte(`{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`)
textDone, _ = sjson.SetBytes(textDone, "sequence_number", nextSeq())
textDone, _ = sjson.SetBytes(textDone, "item_id", st.ReasoningItemID)
textDone, _ = sjson.SetBytes(textDone, "output_index", st.ReasoningIndex)
textDone, _ = sjson.SetBytes(textDone, "text", full)
out = append(out, emitEvent("response.reasoning_summary_text.done", textDone))
partDone := `{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", st.ReasoningItemID)
partDone, _ = sjson.Set(partDone, "output_index", st.ReasoningIndex)
partDone, _ = sjson.Set(partDone, "part.text", full)
partDone := []byte(`{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.SetBytes(partDone, "item_id", st.ReasoningItemID)
partDone, _ = sjson.SetBytes(partDone, "output_index", st.ReasoningIndex)
partDone, _ = sjson.SetBytes(partDone, "part.text", full)
out = append(out, emitEvent("response.reasoning_summary_part.done", partDone))
st.ReasoningActive = false
st.ReasoningPartAdded = false
@@ -284,92 +285,92 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
}
case "message_stop":
completed := `{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`
completed, _ = sjson.Set(completed, "sequence_number", nextSeq())
completed, _ = sjson.Set(completed, "response.id", st.ResponseID)
completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt)
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
completed, _ = sjson.SetBytes(completed, "response.created_at", st.CreatedAt)
// Inject original request fields into response as per docs/response.completed.json
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
if len(reqBytes) > 0 {
req := gjson.ParseBytes(reqBytes)
if v := req.Get("instructions"); v.Exists() {
completed, _ = sjson.Set(completed, "response.instructions", v.String())
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
}
if v := req.Get("max_output_tokens"); v.Exists() {
completed, _ = sjson.Set(completed, "response.max_output_tokens", v.Int())
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
}
if v := req.Get("max_tool_calls"); v.Exists() {
completed, _ = sjson.Set(completed, "response.max_tool_calls", v.Int())
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
}
if v := req.Get("model"); v.Exists() {
completed, _ = sjson.Set(completed, "response.model", v.String())
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
}
if v := req.Get("parallel_tool_calls"); v.Exists() {
completed, _ = sjson.Set(completed, "response.parallel_tool_calls", v.Bool())
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
}
if v := req.Get("previous_response_id"); v.Exists() {
completed, _ = sjson.Set(completed, "response.previous_response_id", v.String())
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
}
if v := req.Get("prompt_cache_key"); v.Exists() {
completed, _ = sjson.Set(completed, "response.prompt_cache_key", v.String())
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
}
if v := req.Get("reasoning"); v.Exists() {
completed, _ = sjson.Set(completed, "response.reasoning", v.Value())
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
}
if v := req.Get("safety_identifier"); v.Exists() {
completed, _ = sjson.Set(completed, "response.safety_identifier", v.String())
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
}
if v := req.Get("service_tier"); v.Exists() {
completed, _ = sjson.Set(completed, "response.service_tier", v.String())
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
}
if v := req.Get("store"); v.Exists() {
completed, _ = sjson.Set(completed, "response.store", v.Bool())
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
}
if v := req.Get("temperature"); v.Exists() {
completed, _ = sjson.Set(completed, "response.temperature", v.Float())
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
}
if v := req.Get("text"); v.Exists() {
completed, _ = sjson.Set(completed, "response.text", v.Value())
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
}
if v := req.Get("tool_choice"); v.Exists() {
completed, _ = sjson.Set(completed, "response.tool_choice", v.Value())
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
}
if v := req.Get("tools"); v.Exists() {
completed, _ = sjson.Set(completed, "response.tools", v.Value())
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
}
if v := req.Get("top_logprobs"); v.Exists() {
completed, _ = sjson.Set(completed, "response.top_logprobs", v.Int())
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
}
if v := req.Get("top_p"); v.Exists() {
completed, _ = sjson.Set(completed, "response.top_p", v.Float())
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
}
if v := req.Get("truncation"); v.Exists() {
completed, _ = sjson.Set(completed, "response.truncation", v.String())
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
}
if v := req.Get("user"); v.Exists() {
completed, _ = sjson.Set(completed, "response.user", v.Value())
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
}
if v := req.Get("metadata"); v.Exists() {
completed, _ = sjson.Set(completed, "response.metadata", v.Value())
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
}
}
// Build response.output from aggregated state
outputsWrapper := `{"arr":[]}`
outputsWrapper := []byte(`{"arr":[]}`)
// reasoning item (if any)
if st.ReasoningBuf.Len() > 0 || st.ReasoningPartAdded {
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
item, _ = sjson.SetBytes(item, "id", st.ReasoningItemID)
item, _ = sjson.SetBytes(item, "summary.0.text", st.ReasoningBuf.String())
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
// assistant message item (if any text)
if st.TextBuf.Len() > 0 || st.InTextBlock || st.CurrentMsgID != "" {
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
item, _ = sjson.SetBytes(item, "id", st.CurrentMsgID)
item, _ = sjson.SetBytes(item, "content.0.text", st.TextBuf.String())
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
// function_call items (in ascending index order for determinism)
if len(st.FuncArgsBuf) > 0 {
@@ -396,16 +397,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
if callID == "" && st.CurrentFCID != "" {
callID = st.CurrentFCID
}
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.Set(item, "arguments", args)
item, _ = sjson.Set(item, "call_id", callID)
item, _ = sjson.Set(item, "name", name)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.SetBytes(item, "arguments", args)
item, _ = sjson.SetBytes(item, "call_id", callID)
item, _ = sjson.SetBytes(item, "name", name)
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
}
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
}
reasoningTokens := int64(0)
@@ -414,15 +415,15 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
}
usagePresent := st.UsageSeen || reasoningTokens > 0
if usagePresent {
completed, _ = sjson.Set(completed, "response.usage.input_tokens", st.InputTokens)
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", 0)
completed, _ = sjson.Set(completed, "response.usage.output_tokens", st.OutputTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.InputTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", 0)
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.OutputTokens)
if reasoningTokens > 0 {
completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", reasoningTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", reasoningTokens)
}
total := st.InputTokens + st.OutputTokens
if total > 0 || st.UsageSeen {
completed, _ = sjson.Set(completed, "response.usage.total_tokens", total)
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
}
}
out = append(out, emitEvent("response.completed", completed))
@@ -432,7 +433,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
}
// ConvertClaudeResponseToOpenAIResponsesNonStream aggregates Claude SSE into a single OpenAI Responses JSON.
func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
// Aggregate Claude SSE lines into a single OpenAI Responses JSON (non-stream)
// We follow the same aggregation logic as the streaming variant but produce
// one final object matching docs/out.json structure.
@@ -455,7 +456,7 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
}
// Base OpenAI Responses (non-stream) object
out := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null,"output":[],"usage":{"input_tokens":0,"input_tokens_details":{"cached_tokens":0},"output_tokens":0,"output_tokens_details":{},"total_tokens":0}}`
out := []byte(`{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null,"output":[],"usage":{"input_tokens":0,"input_tokens_details":{"cached_tokens":0},"output_tokens":0,"output_tokens_details":{},"total_tokens":0}}`)
// Aggregation state
var (
@@ -557,88 +558,88 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
}
// Populate base fields
out, _ = sjson.Set(out, "id", responseID)
out, _ = sjson.Set(out, "created_at", createdAt)
out, _ = sjson.SetBytes(out, "id", responseID)
out, _ = sjson.SetBytes(out, "created_at", createdAt)
// Inject request echo fields as top-level (similar to streaming variant)
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
if len(reqBytes) > 0 {
req := gjson.ParseBytes(reqBytes)
if v := req.Get("instructions"); v.Exists() {
out, _ = sjson.Set(out, "instructions", v.String())
out, _ = sjson.SetBytes(out, "instructions", v.String())
}
if v := req.Get("max_output_tokens"); v.Exists() {
out, _ = sjson.Set(out, "max_output_tokens", v.Int())
out, _ = sjson.SetBytes(out, "max_output_tokens", v.Int())
}
if v := req.Get("max_tool_calls"); v.Exists() {
out, _ = sjson.Set(out, "max_tool_calls", v.Int())
out, _ = sjson.SetBytes(out, "max_tool_calls", v.Int())
}
if v := req.Get("model"); v.Exists() {
out, _ = sjson.Set(out, "model", v.String())
out, _ = sjson.SetBytes(out, "model", v.String())
}
if v := req.Get("parallel_tool_calls"); v.Exists() {
out, _ = sjson.Set(out, "parallel_tool_calls", v.Bool())
out, _ = sjson.SetBytes(out, "parallel_tool_calls", v.Bool())
}
if v := req.Get("previous_response_id"); v.Exists() {
out, _ = sjson.Set(out, "previous_response_id", v.String())
out, _ = sjson.SetBytes(out, "previous_response_id", v.String())
}
if v := req.Get("prompt_cache_key"); v.Exists() {
out, _ = sjson.Set(out, "prompt_cache_key", v.String())
out, _ = sjson.SetBytes(out, "prompt_cache_key", v.String())
}
if v := req.Get("reasoning"); v.Exists() {
out, _ = sjson.Set(out, "reasoning", v.Value())
out, _ = sjson.SetBytes(out, "reasoning", v.Value())
}
if v := req.Get("safety_identifier"); v.Exists() {
out, _ = sjson.Set(out, "safety_identifier", v.String())
out, _ = sjson.SetBytes(out, "safety_identifier", v.String())
}
if v := req.Get("service_tier"); v.Exists() {
out, _ = sjson.Set(out, "service_tier", v.String())
out, _ = sjson.SetBytes(out, "service_tier", v.String())
}
if v := req.Get("store"); v.Exists() {
out, _ = sjson.Set(out, "store", v.Bool())
out, _ = sjson.SetBytes(out, "store", v.Bool())
}
if v := req.Get("temperature"); v.Exists() {
out, _ = sjson.Set(out, "temperature", v.Float())
out, _ = sjson.SetBytes(out, "temperature", v.Float())
}
if v := req.Get("text"); v.Exists() {
out, _ = sjson.Set(out, "text", v.Value())
out, _ = sjson.SetBytes(out, "text", v.Value())
}
if v := req.Get("tool_choice"); v.Exists() {
out, _ = sjson.Set(out, "tool_choice", v.Value())
out, _ = sjson.SetBytes(out, "tool_choice", v.Value())
}
if v := req.Get("tools"); v.Exists() {
out, _ = sjson.Set(out, "tools", v.Value())
out, _ = sjson.SetBytes(out, "tools", v.Value())
}
if v := req.Get("top_logprobs"); v.Exists() {
out, _ = sjson.Set(out, "top_logprobs", v.Int())
out, _ = sjson.SetBytes(out, "top_logprobs", v.Int())
}
if v := req.Get("top_p"); v.Exists() {
out, _ = sjson.Set(out, "top_p", v.Float())
out, _ = sjson.SetBytes(out, "top_p", v.Float())
}
if v := req.Get("truncation"); v.Exists() {
out, _ = sjson.Set(out, "truncation", v.String())
out, _ = sjson.SetBytes(out, "truncation", v.String())
}
if v := req.Get("user"); v.Exists() {
out, _ = sjson.Set(out, "user", v.Value())
out, _ = sjson.SetBytes(out, "user", v.Value())
}
if v := req.Get("metadata"); v.Exists() {
out, _ = sjson.Set(out, "metadata", v.Value())
out, _ = sjson.SetBytes(out, "metadata", v.Value())
}
}
// Build output array
outputsWrapper := `{"arr":[]}`
outputsWrapper := []byte(`{"arr":[]}`)
if reasoningBuf.Len() > 0 {
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
item, _ = sjson.Set(item, "id", reasoningItemID)
item, _ = sjson.Set(item, "summary.0.text", reasoningBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
item, _ = sjson.SetBytes(item, "id", reasoningItemID)
item, _ = sjson.SetBytes(item, "summary.0.text", reasoningBuf.String())
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
if currentMsgID != "" || textBuf.Len() > 0 {
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
item, _ = sjson.Set(item, "id", currentMsgID)
item, _ = sjson.Set(item, "content.0.text", textBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
item, _ = sjson.SetBytes(item, "id", currentMsgID)
item, _ = sjson.SetBytes(item, "content.0.text", textBuf.String())
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
if len(toolCalls) > 0 {
// Preserve index order
@@ -659,28 +660,28 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
if args == "" {
args = "{}"
}
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.id))
item, _ = sjson.Set(item, "arguments", args)
item, _ = sjson.Set(item, "call_id", st.id)
item, _ = sjson.Set(item, "name", st.name)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", st.id))
item, _ = sjson.SetBytes(item, "arguments", args)
item, _ = sjson.SetBytes(item, "call_id", st.id)
item, _ = sjson.SetBytes(item, "name", st.name)
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
}
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
out, _ = sjson.SetRaw(out, "output", gjson.Get(outputsWrapper, "arr").Raw)
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
out, _ = sjson.SetRawBytes(out, "output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
}
// Usage
total := inputTokens + outputTokens
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", total)
out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens)
out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens)
out, _ = sjson.SetBytes(out, "usage.total_tokens", total)
if reasoningBuf.Len() > 0 {
// Rough estimate similar to chat completions
reasoningTokens := int64(len(reasoningBuf.String()) / 4)
if reasoningTokens > 0 {
out, _ = sjson.Set(out, "usage.output_tokens_details.reasoning_tokens", reasoningTokens)
out, _ = sjson.SetBytes(out, "usage.output_tokens_details.reasoning_tokens", reasoningTokens)
}
}

View File

@@ -36,32 +36,41 @@ import (
func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON
template := `{"model":"","instructions":"","input":[]}`
template := []byte(`{"model":"","instructions":"","input":[]}`)
rootResult := gjson.ParseBytes(rawJSON)
template, _ = sjson.Set(template, "model", modelName)
template, _ = sjson.SetBytes(template, "model", modelName)
// Process system messages and convert them to input content format.
systemsResult := rootResult.Get("system")
if systemsResult.IsArray() {
systemResults := systemsResult.Array()
message := `{"type":"message","role":"developer","content":[]}`
if systemsResult.Exists() {
message := []byte(`{"type":"message","role":"developer","content":[]}`)
contentIndex := 0
for i := 0; i < len(systemResults); i++ {
systemResult := systemResults[i]
systemTypeResult := systemResult.Get("type")
if systemTypeResult.String() == "text" {
text := systemResult.Get("text").String()
if strings.HasPrefix(text, "x-anthropic-billing-header: ") {
continue
appendSystemText := func(text string) {
if text == "" || strings.HasPrefix(text, "x-anthropic-billing-header: ") {
return
}
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.text", contentIndex), text)
contentIndex++
}
if systemsResult.Type == gjson.String {
appendSystemText(systemsResult.String())
} else if systemsResult.IsArray() {
systemResults := systemsResult.Array()
for i := 0; i < len(systemResults); i++ {
systemResult := systemResults[i]
if systemResult.Get("type").String() == "text" {
appendSystemText(systemResult.Get("text").String())
}
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
contentIndex++
}
}
if contentIndex > 0 {
template, _ = sjson.SetRaw(template, "input.-1", message)
template, _ = sjson.SetRawBytes(template, "input.-1", message)
}
}
@@ -74,9 +83,9 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
messageResult := messageResults[i]
messageRole := messageResult.Get("role").String()
newMessage := func() string {
msg := `{"type": "message","role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", messageRole)
newMessage := func() []byte {
msg := []byte(`{"type":"message","role":"","content":[]}`)
msg, _ = sjson.SetBytes(msg, "role", messageRole)
return msg
}
@@ -86,7 +95,7 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
flushMessage := func() {
if hasContent {
template, _ = sjson.SetRaw(template, "input.-1", message)
template, _ = sjson.SetRawBytes(template, "input.-1", message)
message = newMessage()
contentIndex = 0
hasContent = false
@@ -98,15 +107,15 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
if messageRole == "assistant" {
partType = "output_text"
}
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), partType)
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), partType)
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.text", contentIndex), text)
contentIndex++
hasContent = true
}
appendImageContent := func(dataURL string) {
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_image")
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.image_url", contentIndex), dataURL)
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), "input_image")
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.image_url", contentIndex), dataURL)
contentIndex++
hasContent = true
}
@@ -142,8 +151,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
}
case "tool_use":
flushMessage()
functionCallMessage := `{"type":"function_call"}`
functionCallMessage, _ = sjson.Set(functionCallMessage, "call_id", messageContentResult.Get("id").String())
functionCallMessage := []byte(`{"type":"function_call"}`)
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "call_id", messageContentResult.Get("id").String())
{
name := messageContentResult.Get("name").String()
toolMap := buildReverseMapFromClaudeOriginalToShort(rawJSON)
@@ -152,19 +161,19 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else {
name = shortenNameIfNeeded(name)
}
functionCallMessage, _ = sjson.Set(functionCallMessage, "name", name)
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "name", name)
}
functionCallMessage, _ = sjson.Set(functionCallMessage, "arguments", messageContentResult.Get("input").Raw)
template, _ = sjson.SetRaw(template, "input.-1", functionCallMessage)
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "arguments", messageContentResult.Get("input").Raw)
template, _ = sjson.SetRawBytes(template, "input.-1", functionCallMessage)
case "tool_result":
flushMessage()
functionCallOutputMessage := `{"type":"function_call_output"}`
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
functionCallOutputMessage := []byte(`{"type":"function_call_output"}`)
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
contentResult := messageContentResult.Get("content")
if contentResult.IsArray() {
toolResultContentIndex := 0
toolResultContent := `[]`
toolResultContent := []byte(`[]`)
contentResults := contentResult.Array()
for k := 0; k < len(contentResults); k++ {
toolResultContentType := contentResults[k].Get("type").String()
@@ -185,27 +194,27 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
}
dataURL := fmt.Sprintf("data:%s;base64,%s", mediaType, data)
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_image")
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.image_url", toolResultContentIndex), dataURL)
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_image")
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.image_url", toolResultContentIndex), dataURL)
toolResultContentIndex++
}
}
} else if toolResultContentType == "text" {
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_text")
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.text", toolResultContentIndex), contentResults[k].Get("text").String())
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_text")
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.text", toolResultContentIndex), contentResults[k].Get("text").String())
toolResultContentIndex++
}
}
if toolResultContent != `[]` {
functionCallOutputMessage, _ = sjson.SetRaw(functionCallOutputMessage, "output", toolResultContent)
if toolResultContentIndex > 0 {
functionCallOutputMessage, _ = sjson.SetRawBytes(functionCallOutputMessage, "output", toolResultContent)
} else {
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
}
} else {
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
}
template, _ = sjson.SetRaw(template, "input.-1", functionCallOutputMessage)
template, _ = sjson.SetRawBytes(template, "input.-1", functionCallOutputMessage)
}
}
flushMessage()
@@ -220,8 +229,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
// Convert tools declarations to the expected format for the Codex API.
toolsResult := rootResult.Get("tools")
if toolsResult.IsArray() {
template, _ = sjson.SetRaw(template, "tools", `[]`)
template, _ = sjson.Set(template, "tool_choice", `auto`)
template, _ = sjson.SetRawBytes(template, "tools", []byte(`[]`))
template, _ = sjson.SetBytes(template, "tool_choice", `auto`)
toolResults := toolsResult.Array()
// Build short name map from declared tools
var names []string
@@ -237,11 +246,11 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
// Special handling: map Claude web search tool to Codex web_search
if toolResult.Get("type").String() == "web_search_20250305" {
// Replace the tool content entirely with {"type":"web_search"}
template, _ = sjson.SetRaw(template, "tools.-1", `{"type":"web_search"}`)
template, _ = sjson.SetRawBytes(template, "tools.-1", []byte(`{"type":"web_search"}`))
continue
}
tool := toolResult.Raw
tool, _ = sjson.Set(tool, "type", "function")
tool := []byte(toolResult.Raw)
tool, _ = sjson.SetBytes(tool, "type", "function")
// Apply shortened name if needed
if v := toolResult.Get("name"); v.Exists() {
name := v.String()
@@ -250,20 +259,26 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else {
name = shortenNameIfNeeded(name)
}
tool, _ = sjson.Set(tool, "name", name)
tool, _ = sjson.SetBytes(tool, "name", name)
}
tool, _ = sjson.SetRaw(tool, "parameters", normalizeToolParameters(toolResult.Get("input_schema").Raw))
tool, _ = sjson.Delete(tool, "input_schema")
tool, _ = sjson.Delete(tool, "parameters.$schema")
tool, _ = sjson.Delete(tool, "cache_control")
tool, _ = sjson.Delete(tool, "defer_loading")
tool, _ = sjson.Set(tool, "strict", false)
template, _ = sjson.SetRaw(template, "tools.-1", tool)
tool, _ = sjson.SetRawBytes(tool, "parameters", []byte(normalizeToolParameters(toolResult.Get("input_schema").Raw)))
tool, _ = sjson.DeleteBytes(tool, "input_schema")
tool, _ = sjson.DeleteBytes(tool, "parameters.$schema")
tool, _ = sjson.DeleteBytes(tool, "cache_control")
tool, _ = sjson.DeleteBytes(tool, "defer_loading")
tool, _ = sjson.SetBytes(tool, "strict", false)
template, _ = sjson.SetRawBytes(template, "tools.-1", tool)
}
}
// Default to parallel tool calls unless tool_choice explicitly disables them.
parallelToolCalls := true
if disableParallelToolUse := rootResult.Get("tool_choice.disable_parallel_tool_use"); disableParallelToolUse.Exists() {
parallelToolCalls = !disableParallelToolUse.Bool()
}
// Add additional configuration parameters for the Codex API.
template, _ = sjson.Set(template, "parallel_tool_calls", true)
template, _ = sjson.SetBytes(template, "parallel_tool_calls", parallelToolCalls)
// Convert thinking.budget_tokens to reasoning.effort.
reasoningEffort := "medium"
@@ -294,13 +309,13 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
}
}
}
template, _ = sjson.Set(template, "reasoning.effort", reasoningEffort)
template, _ = sjson.Set(template, "reasoning.summary", "auto")
template, _ = sjson.Set(template, "stream", true)
template, _ = sjson.Set(template, "store", false)
template, _ = sjson.Set(template, "include", []string{"reasoning.encrypted_content"})
template, _ = sjson.SetBytes(template, "reasoning.effort", reasoningEffort)
template, _ = sjson.SetBytes(template, "reasoning.summary", "auto")
template, _ = sjson.SetBytes(template, "stream", true)
template, _ = sjson.SetBytes(template, "store", false)
template, _ = sjson.SetBytes(template, "include", []string{"reasoning.encrypted_content"})
return []byte(template)
return template
}
// shortenNameIfNeeded applies a simple shortening rule for a single name.
@@ -403,15 +418,15 @@ func normalizeToolParameters(raw string) string {
if raw == "" || raw == "null" || !gjson.Valid(raw) {
return `{"type":"object","properties":{}}`
}
schema := raw
result := gjson.Parse(raw)
schema := []byte(raw)
schemaType := result.Get("type").String()
if schemaType == "" {
schema, _ = sjson.Set(schema, "type", "object")
schema, _ = sjson.SetBytes(schema, "type", "object")
schemaType = "object"
}
if schemaType == "object" && !result.Get("properties").Exists() {
schema, _ = sjson.SetRaw(schema, "properties", `{}`)
schema, _ = sjson.SetRawBytes(schema, "properties", []byte(`{}`))
}
return schema
return string(schema)
}

View File

@@ -0,0 +1,135 @@
package claude
import (
"testing"
"github.com/tidwall/gjson"
)
func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) {
tests := []struct {
name string
inputJSON string
wantHasDeveloper bool
wantTexts []string
}{
{
name: "No system field",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: false,
},
{
name: "Empty string system field",
inputJSON: `{
"model": "claude-3-opus",
"system": "",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: false,
},
{
name: "String system field",
inputJSON: `{
"model": "claude-3-opus",
"system": "Be helpful",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: true,
wantTexts: []string{"Be helpful"},
},
{
name: "Array system field with filtered billing header",
inputJSON: `{
"model": "claude-3-opus",
"system": [
{"type": "text", "text": "x-anthropic-billing-header: tenant-123"},
{"type": "text", "text": "Block 1"},
{"type": "text", "text": "Block 2"}
],
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: true,
wantTexts: []string{"Block 1", "Block 2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false)
resultJSON := gjson.ParseBytes(result)
inputs := resultJSON.Get("input").Array()
hasDeveloper := len(inputs) > 0 && inputs[0].Get("role").String() == "developer"
if hasDeveloper != tt.wantHasDeveloper {
t.Fatalf("got hasDeveloper = %v, want %v. Output: %s", hasDeveloper, tt.wantHasDeveloper, resultJSON.Get("input").Raw)
}
if !tt.wantHasDeveloper {
return
}
content := inputs[0].Get("content").Array()
if len(content) != len(tt.wantTexts) {
t.Fatalf("got %d system content items, want %d. Content: %s", len(content), len(tt.wantTexts), inputs[0].Get("content").Raw)
}
for i, wantText := range tt.wantTexts {
if gotType := content[i].Get("type").String(); gotType != "input_text" {
t.Fatalf("content[%d] type = %q, want %q", i, gotType, "input_text")
}
if gotText := content[i].Get("text").String(); gotText != wantText {
t.Fatalf("content[%d] text = %q, want %q", i, gotText, wantText)
}
}
})
}
}
func TestConvertClaudeRequestToCodex_ParallelToolCalls(t *testing.T) {
tests := []struct {
name string
inputJSON string
wantParallelToolCalls bool
}{
{
name: "Default to true when tool_choice.disable_parallel_tool_use is absent",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantParallelToolCalls: true,
},
{
name: "Disable parallel tool calls when client opts out",
inputJSON: `{
"model": "claude-3-opus",
"tool_choice": {"disable_parallel_tool_use": true},
"messages": [{"role": "user", "content": "hello"}]
}`,
wantParallelToolCalls: false,
},
{
name: "Keep parallel tool calls enabled when client explicitly allows them",
inputJSON: `{
"model": "claude-3-opus",
"tool_choice": {"disable_parallel_tool_use": false},
"messages": [{"role": "user", "content": "hello"}]
}`,
wantParallelToolCalls: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false)
resultJSON := gjson.ParseBytes(result)
if got := resultJSON.Get("parallel_tool_calls").Bool(); got != tt.wantParallelToolCalls {
t.Fatalf("parallel_tool_calls = %v, want %v. Output: %s", got, tt.wantParallelToolCalls, string(result))
}
})
}
}

View File

@@ -9,9 +9,9 @@ package claude
import (
"bytes"
"context"
"fmt"
"strings"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -43,8 +43,8 @@ type ConvertCodexResponseToClaudeParams struct {
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of Claude Code-compatible JSON responses
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &ConvertCodexResponseToClaudeParams{
HasToolCall: false,
@@ -54,95 +54,85 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
// log.Debugf("rawJSON: %s", string(rawJSON))
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
return [][]byte{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
output := ""
output := make([]byte, 0, 512)
rootResult := gjson.ParseBytes(rawJSON)
typeResult := rootResult.Get("type")
typeStr := typeResult.String()
template := ""
var template []byte
if typeStr == "response.created" {
template = `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`
template, _ = sjson.Set(template, "message.model", rootResult.Get("response.model").String())
template, _ = sjson.Set(template, "message.id", rootResult.Get("response.id").String())
template = []byte(`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`)
template, _ = sjson.SetBytes(template, "message.model", rootResult.Get("response.model").String())
template, _ = sjson.SetBytes(template, "message.id", rootResult.Get("response.id").String())
output = "event: message_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "message_start", template, 2)
} else if typeStr == "response.reasoning_summary_part.added" {
template = `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output = "event: content_block_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
} else if typeStr == "response.reasoning_summary_text.delta" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.thinking", rootResult.Get("delta").String())
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "delta.thinking", rootResult.Get("delta").String())
output = "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.reasoning_summary_part.done" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
} else if typeStr == "response.content_part.added" {
template = `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output = "event: content_block_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
} else if typeStr == "response.output_text.delta" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.text", rootResult.Get("delta").String())
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "delta.text", rootResult.Get("delta").String())
output = "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.content_part.done" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
} else if typeStr == "response.completed" {
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
stopReason := rootResult.Get("response.stop_reason").String()
if p {
template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
template, _ = sjson.SetBytes(template, "delta.stop_reason", "tool_use")
} else if stopReason == "max_tokens" || stopReason == "stop" {
template, _ = sjson.Set(template, "delta.stop_reason", stopReason)
template, _ = sjson.SetBytes(template, "delta.stop_reason", stopReason)
} else {
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
template, _ = sjson.SetBytes(template, "delta.stop_reason", "end_turn")
}
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(rootResult.Get("response.usage"))
template, _ = sjson.Set(template, "usage.input_tokens", inputTokens)
template, _ = sjson.Set(template, "usage.output_tokens", outputTokens)
template, _ = sjson.SetBytes(template, "usage.input_tokens", inputTokens)
template, _ = sjson.SetBytes(template, "usage.output_tokens", outputTokens)
if cachedTokens > 0 {
template, _ = sjson.Set(template, "usage.cache_read_input_tokens", cachedTokens)
template, _ = sjson.SetBytes(template, "usage.cache_read_input_tokens", cachedTokens)
}
output = "event: message_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
output += "event: message_stop\n"
output += `data: {"type":"message_stop"}`
output += "\n\n"
output = translatorcommon.AppendSSEEventBytes(output, "message_delta", template, 2)
output = translatorcommon.AppendSSEEventBytes(output, "message_stop", []byte(`{"type":"message_stop"}`), 2)
} else if typeStr == "response.output_item.added" {
itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
{
// Restore original tool name if shortened
name := itemResult.Get("name").String()
@@ -150,37 +140,33 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
if orig, ok := rev[name]; ok {
name = orig
}
template, _ = sjson.Set(template, "content_block.name", name)
template, _ = sjson.SetBytes(template, "content_block.name", name)
}
output = "event: content_block_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
}
} else if typeStr == "response.output_item.done" {
itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
}
} else if typeStr == "response.function_call_arguments.delta" {
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "delta.partial_json", rootResult.Get("delta").String())
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.function_call_arguments.done" {
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
// in a single "done" event without preceding "delta" events.
@@ -189,17 +175,16 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
// When delta events were already received, skip to avoid duplicating arguments.
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
if args := rootResult.Get("arguments").String(); args != "" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", args)
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "delta.partial_json", args)
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
}
}
}
return []string{output}
return [][]byte{output}
}
// ConvertCodexResponseToClaudeNonStream converts a non-streaming Codex response to a non-streaming Claude Code response.
@@ -214,28 +199,28 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: A Claude Code-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) string {
// - []byte: A Claude Code-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) []byte {
revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
rootResult := gjson.ParseBytes(rawJSON)
if rootResult.Get("type").String() != "response.completed" {
return ""
return []byte{}
}
responseData := rootResult.Get("response")
if !responseData.Exists() {
return ""
return []byte{}
}
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", responseData.Get("id").String())
out, _ = sjson.Set(out, "model", responseData.Get("model").String())
out := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
out, _ = sjson.SetBytes(out, "id", responseData.Get("id").String())
out, _ = sjson.SetBytes(out, "model", responseData.Get("model").String())
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(responseData.Get("usage"))
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens)
out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens)
if cachedTokens > 0 {
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
out, _ = sjson.SetBytes(out, "usage.cache_read_input_tokens", cachedTokens)
}
hasToolCall := false
@@ -276,9 +261,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
}
}
if thinkingBuilder.Len() > 0 {
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
block := []byte(`{"type":"thinking","thinking":""}`)
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRawBytes(out, "content.-1", block)
}
case "message":
if content := item.Get("content"); content.Exists() {
@@ -287,9 +272,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
if part.Get("type").String() == "output_text" {
text := part.Get("text").String()
if text != "" {
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", text)
out, _ = sjson.SetRaw(out, "content.-1", block)
block := []byte(`{"type":"text","text":""}`)
block, _ = sjson.SetBytes(block, "text", text)
out, _ = sjson.SetRawBytes(out, "content.-1", block)
}
}
return true
@@ -297,9 +282,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
} else {
text := content.String()
if text != "" {
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", text)
out, _ = sjson.SetRaw(out, "content.-1", block)
block := []byte(`{"type":"text","text":""}`)
block, _ = sjson.SetBytes(block, "text", text)
out, _ = sjson.SetRawBytes(out, "content.-1", block)
}
}
}
@@ -310,9 +295,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
name = original
}
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
toolBlock, _ = sjson.Set(toolBlock, "name", name)
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolBlock, _ = sjson.SetBytes(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
inputRaw := "{}"
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr)
@@ -320,23 +305,23 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
inputRaw = argsJSON.Raw
}
}
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(inputRaw))
out, _ = sjson.SetRawBytes(out, "content.-1", toolBlock)
}
return true
})
}
if stopReason := responseData.Get("stop_reason"); stopReason.Exists() && stopReason.String() != "" {
out, _ = sjson.Set(out, "stop_reason", stopReason.String())
out, _ = sjson.SetBytes(out, "stop_reason", stopReason.String())
} else if hasToolCall {
out, _ = sjson.Set(out, "stop_reason", "tool_use")
out, _ = sjson.SetBytes(out, "stop_reason", "tool_use")
} else {
out, _ = sjson.Set(out, "stop_reason", "end_turn")
out, _ = sjson.SetBytes(out, "stop_reason", "end_turn")
}
if stopSequence := responseData.Get("stop_sequence"); stopSequence.Exists() && stopSequence.String() != "" {
out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw)
out, _ = sjson.SetRawBytes(out, "stop_sequence", []byte(stopSequence.Raw))
}
return out
@@ -386,6 +371,6 @@ func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[strin
return rev
}
func ClaudeTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"input_tokens":%d}`, count)
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.ClaudeInputTokensJSON(count)
}

View File

@@ -6,10 +6,9 @@ package geminiCLI
import (
"context"
"fmt"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini"
"github.com/tidwall/sjson"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
)
// ConvertCodexResponseToGeminiCLI converts Codex streaming response format to Gemini CLI format.
@@ -24,14 +23,12 @@ import (
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response wrapped in a response object
func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of Gemini-compatible JSON responses wrapped in a response object
func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
outputs := ConvertCodexResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
newOutputs := make([]string, 0)
newOutputs := make([][]byte, 0, len(outputs))
for i := 0; i < len(outputs); i++ {
json := `{"response": {}}`
output, _ := sjson.SetRaw(json, "response", outputs[i])
newOutputs = append(newOutputs, output)
newOutputs = append(newOutputs, translatorcommon.WrapGeminiCLIResponse(outputs[i]))
}
return newOutputs
}
@@ -47,15 +44,12 @@ func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, orig
// - param: A pointer to a parameter object for the conversion
//
// Returns:
// - string: A Gemini-compatible JSON response wrapped in a response object
func ConvertCodexResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
// log.Debug(string(rawJSON))
strJSON := ConvertCodexResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
json := `{"response": {}}`
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
return strJSON
// - []byte: A Gemini-compatible JSON response wrapped in a response object
func ConvertCodexResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
out := ConvertCodexResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
return translatorcommon.WrapGeminiCLIResponse(out)
}
func GeminiCLITokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
func GeminiCLITokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.GeminiTokenCountJSON(count)
}

View File

@@ -38,7 +38,7 @@ import (
func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON
// Base template
out := `{"model":"","instructions":"","input":[]}`
out := []byte(`{"model":"","instructions":"","input":[]}`)
root := gjson.ParseBytes(rawJSON)
@@ -82,24 +82,24 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
}
// Model
out, _ = sjson.Set(out, "model", modelName)
out, _ = sjson.SetBytes(out, "model", modelName)
// System instruction -> as a user message with input_text parts
sysParts := root.Get("system_instruction.parts")
if sysParts.IsArray() {
msg := `{"type":"message","role":"developer","content":[]}`
msg := []byte(`{"type":"message","role":"developer","content":[]}`)
arr := sysParts.Array()
for i := 0; i < len(arr); i++ {
p := arr[i]
if t := p.Get("text"); t.Exists() {
part := `{}`
part, _ = sjson.Set(part, "type", "input_text")
part, _ = sjson.Set(part, "text", t.String())
msg, _ = sjson.SetRaw(msg, "content.-1", part)
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", "input_text")
part, _ = sjson.SetBytes(part, "text", t.String())
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
}
}
if len(gjson.Get(msg, "content").Array()) > 0 {
out, _ = sjson.SetRaw(out, "input.-1", msg)
if len(gjson.GetBytes(msg, "content").Array()) > 0 {
out, _ = sjson.SetRawBytes(out, "input.-1", msg)
}
}
@@ -123,23 +123,23 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
p := parr[j]
// text part
if t := p.Get("text"); t.Exists() {
msg := `{"type":"message","role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
msg := []byte(`{"type":"message","role":"","content":[]}`)
msg, _ = sjson.SetBytes(msg, "role", role)
partType := "input_text"
if role == "assistant" {
partType = "output_text"
}
part := `{}`
part, _ = sjson.Set(part, "type", partType)
part, _ = sjson.Set(part, "text", t.String())
msg, _ = sjson.SetRaw(msg, "content.-1", part)
out, _ = sjson.SetRaw(out, "input.-1", msg)
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", partType)
part, _ = sjson.SetBytes(part, "text", t.String())
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
out, _ = sjson.SetRawBytes(out, "input.-1", msg)
continue
}
// function call from model
if fc := p.Get("functionCall"); fc.Exists() {
fn := `{"type":"function_call"}`
fn := []byte(`{"type":"function_call"}`)
if name := fc.Get("name"); name.Exists() {
n := name.String()
if short, ok := shortMap[n]; ok {
@@ -147,31 +147,31 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else {
n = shortenNameIfNeeded(n)
}
fn, _ = sjson.Set(fn, "name", n)
fn, _ = sjson.SetBytes(fn, "name", n)
}
if args := fc.Get("args"); args.Exists() {
fn, _ = sjson.Set(fn, "arguments", args.Raw)
fn, _ = sjson.SetBytes(fn, "arguments", args.Raw)
}
// generate a paired random call_id and enqueue it so the
// corresponding functionResponse can pop the earliest id
// to preserve ordering when multiple calls are present.
id := genCallID()
fn, _ = sjson.Set(fn, "call_id", id)
fn, _ = sjson.SetBytes(fn, "call_id", id)
pendingCallIDs = append(pendingCallIDs, id)
out, _ = sjson.SetRaw(out, "input.-1", fn)
out, _ = sjson.SetRawBytes(out, "input.-1", fn)
continue
}
// function response from user
if fr := p.Get("functionResponse"); fr.Exists() {
fno := `{"type":"function_call_output"}`
fno := []byte(`{"type":"function_call_output"}`)
// Prefer a string result if present; otherwise embed the raw response as a string
if res := fr.Get("response.result"); res.Exists() {
fno, _ = sjson.Set(fno, "output", res.String())
fno, _ = sjson.SetBytes(fno, "output", res.String())
} else if resp := fr.Get("response"); resp.Exists() {
fno, _ = sjson.Set(fno, "output", resp.Raw)
fno, _ = sjson.SetBytes(fno, "output", resp.Raw)
}
// fno, _ = sjson.Set(fno, "call_id", "call_W6nRJzFXyPM2LFBbfo98qAbq")
// fno, _ = sjson.SetBytes(fno, "call_id", "call_W6nRJzFXyPM2LFBbfo98qAbq")
// attach the oldest queued call_id to pair the response
// with its call. If the queue is empty, generate a new id.
var id string
@@ -182,8 +182,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else {
id = genCallID()
}
fno, _ = sjson.Set(fno, "call_id", id)
out, _ = sjson.SetRaw(out, "input.-1", fno)
fno, _ = sjson.SetBytes(fno, "call_id", id)
out, _ = sjson.SetRawBytes(out, "input.-1", fno)
continue
}
}
@@ -193,8 +193,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
// Tools mapping: Gemini functionDeclarations -> Codex tools
tools := root.Get("tools")
if tools.IsArray() {
out, _ = sjson.SetRaw(out, "tools", `[]`)
out, _ = sjson.Set(out, "tool_choice", "auto")
out, _ = sjson.SetRawBytes(out, "tools", []byte(`[]`))
out, _ = sjson.SetBytes(out, "tool_choice", "auto")
tarr := tools.Array()
for i := 0; i < len(tarr); i++ {
td := tarr[i]
@@ -205,8 +205,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
farr := fns.Array()
for j := 0; j < len(farr); j++ {
fn := farr[j]
tool := `{}`
tool, _ = sjson.Set(tool, "type", "function")
tool := []byte(`{}`)
tool, _ = sjson.SetBytes(tool, "type", "function")
if v := fn.Get("name"); v.Exists() {
name := v.String()
if short, ok := shortMap[name]; ok {
@@ -214,32 +214,32 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else {
name = shortenNameIfNeeded(name)
}
tool, _ = sjson.Set(tool, "name", name)
tool, _ = sjson.SetBytes(tool, "name", name)
}
if v := fn.Get("description"); v.Exists() {
tool, _ = sjson.Set(tool, "description", v.String())
tool, _ = sjson.SetBytes(tool, "description", v.String())
}
if prm := fn.Get("parameters"); prm.Exists() {
// Remove optional $schema field if present
cleaned := prm.Raw
cleaned, _ = sjson.Delete(cleaned, "$schema")
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
tool, _ = sjson.SetRaw(tool, "parameters", cleaned)
cleaned := []byte(prm.Raw)
cleaned, _ = sjson.DeleteBytes(cleaned, "$schema")
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
tool, _ = sjson.SetRawBytes(tool, "parameters", cleaned)
} else if prm = fn.Get("parametersJsonSchema"); prm.Exists() {
// Remove optional $schema field if present
cleaned := prm.Raw
cleaned, _ = sjson.Delete(cleaned, "$schema")
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
tool, _ = sjson.SetRaw(tool, "parameters", cleaned)
cleaned := []byte(prm.Raw)
cleaned, _ = sjson.DeleteBytes(cleaned, "$schema")
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
tool, _ = sjson.SetRawBytes(tool, "parameters", cleaned)
}
tool, _ = sjson.Set(tool, "strict", false)
out, _ = sjson.SetRaw(out, "tools.-1", tool)
tool, _ = sjson.SetBytes(tool, "strict", false)
out, _ = sjson.SetRawBytes(out, "tools.-1", tool)
}
}
}
// Fixed flags aligning with Codex expectations
out, _ = sjson.Set(out, "parallel_tool_calls", true)
out, _ = sjson.SetBytes(out, "parallel_tool_calls", true)
// Convert Gemini thinkingConfig to Codex reasoning.effort.
// Note: Google official Python SDK sends snake_case fields (thinking_level/thinking_budget).
@@ -253,7 +253,7 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
if thinkingLevel.Exists() {
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
if effort != "" {
out, _ = sjson.Set(out, "reasoning.effort", effort)
out, _ = sjson.SetBytes(out, "reasoning.effort", effort)
effortSet = true
}
} else {
@@ -263,7 +263,7 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
}
if thinkingBudget.Exists() {
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
out, _ = sjson.Set(out, "reasoning.effort", effort)
out, _ = sjson.SetBytes(out, "reasoning.effort", effort)
effortSet = true
}
}
@@ -272,22 +272,22 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
}
if !effortSet {
// No thinking config, set default effort
out, _ = sjson.Set(out, "reasoning.effort", "medium")
out, _ = sjson.SetBytes(out, "reasoning.effort", "medium")
}
out, _ = sjson.Set(out, "reasoning.summary", "auto")
out, _ = sjson.Set(out, "stream", true)
out, _ = sjson.Set(out, "store", false)
out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"})
out, _ = sjson.SetBytes(out, "reasoning.summary", "auto")
out, _ = sjson.SetBytes(out, "stream", true)
out, _ = sjson.SetBytes(out, "store", false)
out, _ = sjson.SetBytes(out, "include", []string{"reasoning.encrypted_content"})
var pathsToLower []string
toolsResult := gjson.Get(out, "tools")
toolsResult := gjson.GetBytes(out, "tools")
util.Walk(toolsResult, "", "type", &pathsToLower)
for _, p := range pathsToLower {
fullPath := fmt.Sprintf("tools.%s", p)
out, _ = sjson.Set(out, fullPath, strings.ToLower(gjson.Get(out, fullPath).String()))
out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(gjson.GetBytes(out, fullPath).String()))
}
return []byte(out)
return out
}
// shortenNameIfNeeded applies the simple shortening rule for a single name.

View File

@@ -7,9 +7,9 @@ package gemini
import (
"bytes"
"context"
"fmt"
"time"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -23,7 +23,7 @@ type ConvertCodexResponseToGeminiParams struct {
Model string
CreatedAt int64
ResponseID string
LastStorageOutput string
LastStorageOutput []byte
}
// ConvertCodexResponseToGemini converts Codex streaming response format to Gemini format.
@@ -38,19 +38,19 @@ type ConvertCodexResponseToGeminiParams struct {
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response
func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of Gemini-compatible JSON responses
func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &ConvertCodexResponseToGeminiParams{
Model: modelName,
CreatedAt: 0,
ResponseID: "",
LastStorageOutput: "",
LastStorageOutput: nil,
}
}
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
return [][]byte{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
@@ -59,17 +59,17 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
typeStr := typeResult.String()
// Base Gemini response template
template := `{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`
if (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput != "" && typeStr == "response.output_item.done" {
template = (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`)
if len((*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput) > 0 && typeStr == "response.output_item.done" {
template = append([]byte(nil), (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput...)
} else {
template, _ = sjson.Set(template, "modelVersion", (*param).(*ConvertCodexResponseToGeminiParams).Model)
template, _ = sjson.SetBytes(template, "modelVersion", (*param).(*ConvertCodexResponseToGeminiParams).Model)
createdAtResult := rootResult.Get("response.created_at")
if createdAtResult.Exists() {
(*param).(*ConvertCodexResponseToGeminiParams).CreatedAt = createdAtResult.Int()
template, _ = sjson.Set(template, "createTime", time.Unix((*param).(*ConvertCodexResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
template, _ = sjson.SetBytes(template, "createTime", time.Unix((*param).(*ConvertCodexResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
}
template, _ = sjson.Set(template, "responseId", (*param).(*ConvertCodexResponseToGeminiParams).ResponseID)
template, _ = sjson.SetBytes(template, "responseId", (*param).(*ConvertCodexResponseToGeminiParams).ResponseID)
}
// Handle function call completion
@@ -78,7 +78,7 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
// Create function call part
functionCall := `{"functionCall":{"name":"","args":{}}}`
functionCall := []byte(`{"functionCall":{"name":"","args":{}}}`)
{
// Restore original tool name if shortened
n := itemResult.Get("name").String()
@@ -86,7 +86,7 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
if orig, ok := rev[n]; ok {
n = orig
}
functionCall, _ = sjson.Set(functionCall, "functionCall.name", n)
functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", n)
}
// Parse and set arguments
@@ -94,47 +94,48 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
if argsStr != "" {
argsResult := gjson.Parse(argsStr)
if argsResult.IsObject() {
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsStr))
}
}
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", functionCall)
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall)
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput = template
(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput = append([]byte(nil), template...)
// Use this return to storage message
return []string{}
return [][]byte{}
}
}
if typeStr == "response.created" { // Handle response creation - set model and response ID
template, _ = sjson.Set(template, "modelVersion", rootResult.Get("response.model").String())
template, _ = sjson.Set(template, "responseId", rootResult.Get("response.id").String())
template, _ = sjson.SetBytes(template, "modelVersion", rootResult.Get("response.model").String())
template, _ = sjson.SetBytes(template, "responseId", rootResult.Get("response.id").String())
(*param).(*ConvertCodexResponseToGeminiParams).ResponseID = rootResult.Get("response.id").String()
} else if typeStr == "response.reasoning_summary_text.delta" { // Handle reasoning/thinking content delta
part := `{"thought":true,"text":""}`
part, _ = sjson.Set(part, "text", rootResult.Get("delta").String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
part := []byte(`{"thought":true,"text":""}`)
part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
} else if typeStr == "response.output_text.delta" { // Handle regular text content delta
part := `{"text":""}`
part, _ = sjson.Set(part, "text", rootResult.Get("delta").String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
} else if typeStr == "response.completed" { // Handle response completion with usage metadata
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int())
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int())
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int())
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int())
totalTokens := rootResult.Get("response.usage.input_tokens").Int() + rootResult.Get("response.usage.output_tokens").Int()
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", totalTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", totalTokens)
} else {
return []string{}
return [][]byte{}
}
if (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput != "" {
return []string{(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput, template}
} else {
return []string{template}
if len((*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput) > 0 {
return [][]byte{
append([]byte(nil), (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput...),
template,
}
}
return [][]byte{template}
}
// ConvertCodexResponseToGeminiNonStream converts a non-streaming Codex response to a non-streaming Gemini response.
@@ -149,32 +150,32 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: A Gemini-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// - []byte: A Gemini-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
rootResult := gjson.ParseBytes(rawJSON)
// Verify this is a response.completed event
if rootResult.Get("type").String() != "response.completed" {
return ""
return []byte{}
}
// Base Gemini response template for non-streaming
template := `{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
// Set model version
template, _ = sjson.Set(template, "modelVersion", modelName)
template, _ = sjson.SetBytes(template, "modelVersion", modelName)
// Set response metadata from the completed response
responseData := rootResult.Get("response")
if responseData.Exists() {
// Set response ID
if responseId := responseData.Get("id"); responseId.Exists() {
template, _ = sjson.Set(template, "responseId", responseId.String())
template, _ = sjson.SetBytes(template, "responseId", responseId.String())
}
// Set creation time
if createdAt := responseData.Get("created_at"); createdAt.Exists() {
template, _ = sjson.Set(template, "createTime", time.Unix(createdAt.Int(), 0).Format(time.RFC3339Nano))
template, _ = sjson.SetBytes(template, "createTime", time.Unix(createdAt.Int(), 0).Format(time.RFC3339Nano))
}
// Set usage metadata
@@ -183,14 +184,14 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
outputTokens := usage.Get("output_tokens").Int()
totalTokens := inputTokens + outputTokens
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens)
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens)
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", totalTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", inputTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", outputTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", totalTokens)
}
// Process output content to build parts array
hasToolCall := false
var pendingFunctionCalls []string
var pendingFunctionCalls [][]byte
flushPendingFunctionCalls := func() {
if len(pendingFunctionCalls) == 0 {
@@ -199,7 +200,7 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
// Add all pending function calls as individual parts
// This maintains the original Gemini API format while ensuring consecutive calls are grouped together
for _, fc := range pendingFunctionCalls {
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", fc)
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", fc)
}
pendingFunctionCalls = nil
}
@@ -215,9 +216,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
// Add thinking content
if content := value.Get("content"); content.Exists() {
part := `{"text":"","thought":true}`
part, _ = sjson.Set(part, "text", content.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
part := []byte(`{"text":"","thought":true}`)
part, _ = sjson.SetBytes(part, "text", content.String())
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
}
case "message":
@@ -229,9 +230,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
content.ForEach(func(_, contentItem gjson.Result) bool {
if contentItem.Get("type").String() == "output_text" {
if text := contentItem.Get("text"); text.Exists() {
part := `{"text":""}`
part, _ = sjson.Set(part, "text", text.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", text.String())
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
}
}
return true
@@ -241,21 +242,21 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
case "function_call":
// Collect function call for potential merging with consecutive ones
hasToolCall = true
functionCall := `{"functionCall":{"args":{},"name":""}}`
functionCall := []byte(`{"functionCall":{"args":{},"name":""}}`)
{
n := value.Get("name").String()
rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON)
if orig, ok := rev[n]; ok {
n = orig
}
functionCall, _ = sjson.Set(functionCall, "functionCall.name", n)
functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", n)
}
// Parse and set arguments
if argsStr := value.Get("arguments").String(); argsStr != "" {
argsResult := gjson.Parse(argsStr)
if argsResult.IsObject() {
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsStr))
}
}
@@ -270,9 +271,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
// Set finish reason based on whether there were tool calls
if hasToolCall {
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
} else {
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
}
}
return template
@@ -307,6 +308,6 @@ func buildReverseMapFromGeminiOriginal(original []byte) map[string]string {
return rev
}
func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
func GeminiTokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.GeminiTokenCountJSON(count)
}

View File

@@ -29,42 +29,42 @@ import (
func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
rawJSON := inputRawJSON
// Start with empty JSON object
out := `{"instructions":""}`
out := []byte(`{"instructions":""}`)
// Stream must be set to true
out, _ = sjson.Set(out, "stream", stream)
out, _ = sjson.SetBytes(out, "stream", stream)
// Codex not support temperature, top_p, top_k, max_output_tokens, so comment them
// if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() {
// out, _ = sjson.Set(out, "temperature", v.Value())
// out, _ = sjson.SetBytes(out, "temperature", v.Value())
// }
// if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() {
// out, _ = sjson.Set(out, "top_p", v.Value())
// out, _ = sjson.SetBytes(out, "top_p", v.Value())
// }
// if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() {
// out, _ = sjson.Set(out, "top_k", v.Value())
// out, _ = sjson.SetBytes(out, "top_k", v.Value())
// }
// Map token limits
// if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() {
// out, _ = sjson.Set(out, "max_output_tokens", v.Value())
// out, _ = sjson.SetBytes(out, "max_output_tokens", v.Value())
// }
// if v := gjson.GetBytes(rawJSON, "max_completion_tokens"); v.Exists() {
// out, _ = sjson.Set(out, "max_output_tokens", v.Value())
// out, _ = sjson.SetBytes(out, "max_output_tokens", v.Value())
// }
// Map reasoning effort
if v := gjson.GetBytes(rawJSON, "reasoning_effort"); v.Exists() {
out, _ = sjson.Set(out, "reasoning.effort", v.Value())
out, _ = sjson.SetBytes(out, "reasoning.effort", v.Value())
} else {
out, _ = sjson.Set(out, "reasoning.effort", "medium")
out, _ = sjson.SetBytes(out, "reasoning.effort", "medium")
}
out, _ = sjson.Set(out, "parallel_tool_calls", true)
out, _ = sjson.Set(out, "reasoning.summary", "auto")
out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"})
out, _ = sjson.SetBytes(out, "parallel_tool_calls", true)
out, _ = sjson.SetBytes(out, "reasoning.summary", "auto")
out, _ = sjson.SetBytes(out, "include", []string{"reasoning.encrypted_content"})
// Model
out, _ = sjson.Set(out, "model", modelName)
out, _ = sjson.SetBytes(out, "model", modelName)
// Build tool name shortening map from original tools (if any)
originalToolNameMap := map[string]string{}
@@ -100,9 +100,9 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
// if m.Get("role").String() == "system" {
// c := m.Get("content")
// if c.Type == gjson.String {
// out, _ = sjson.Set(out, "instructions", c.String())
// out, _ = sjson.SetBytes(out, "instructions", c.String())
// } else if c.IsObject() && c.Get("type").String() == "text" {
// out, _ = sjson.Set(out, "instructions", c.Get("text").String())
// out, _ = sjson.SetBytes(out, "instructions", c.Get("text").String())
// }
// break
// }
@@ -110,7 +110,7 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
// }
// Build input from messages, handling all message types including tool calls
out, _ = sjson.SetRaw(out, "input", `[]`)
out, _ = sjson.SetRawBytes(out, "input", []byte(`[]`))
if messages.IsArray() {
arr := messages.Array()
for i := 0; i < len(arr); i++ {
@@ -124,23 +124,23 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
content := m.Get("content").String()
// Create function_call_output object
funcOutput := `{}`
funcOutput, _ = sjson.Set(funcOutput, "type", "function_call_output")
funcOutput, _ = sjson.Set(funcOutput, "call_id", toolCallID)
funcOutput, _ = sjson.Set(funcOutput, "output", content)
out, _ = sjson.SetRaw(out, "input.-1", funcOutput)
funcOutput := []byte(`{}`)
funcOutput, _ = sjson.SetBytes(funcOutput, "type", "function_call_output")
funcOutput, _ = sjson.SetBytes(funcOutput, "call_id", toolCallID)
funcOutput, _ = sjson.SetBytes(funcOutput, "output", content)
out, _ = sjson.SetRawBytes(out, "input.-1", funcOutput)
default:
// Handle regular messages
msg := `{}`
msg, _ = sjson.Set(msg, "type", "message")
msg := []byte(`{}`)
msg, _ = sjson.SetBytes(msg, "type", "message")
if role == "system" {
msg, _ = sjson.Set(msg, "role", "developer")
msg, _ = sjson.SetBytes(msg, "role", "developer")
} else {
msg, _ = sjson.Set(msg, "role", role)
msg, _ = sjson.SetBytes(msg, "role", role)
}
msg, _ = sjson.SetRaw(msg, "content", `[]`)
msg, _ = sjson.SetRawBytes(msg, "content", []byte(`[]`))
// Handle regular content
c := m.Get("content")
@@ -150,10 +150,10 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
if role == "assistant" {
partType = "output_text"
}
part := `{}`
part, _ = sjson.Set(part, "type", partType)
part, _ = sjson.Set(part, "text", c.String())
msg, _ = sjson.SetRaw(msg, "content.-1", part)
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", partType)
part, _ = sjson.SetBytes(part, "text", c.String())
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
} else if c.Exists() && c.IsArray() {
items := c.Array()
for j := 0; j < len(items); j++ {
@@ -165,39 +165,44 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
if role == "assistant" {
partType = "output_text"
}
part := `{}`
part, _ = sjson.Set(part, "type", partType)
part, _ = sjson.Set(part, "text", it.Get("text").String())
msg, _ = sjson.SetRaw(msg, "content.-1", part)
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", partType)
part, _ = sjson.SetBytes(part, "text", it.Get("text").String())
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
case "image_url":
// Map image inputs to input_image for Responses API
if role == "user" {
part := `{}`
part, _ = sjson.Set(part, "type", "input_image")
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", "input_image")
if u := it.Get("image_url.url"); u.Exists() {
part, _ = sjson.Set(part, "image_url", u.String())
part, _ = sjson.SetBytes(part, "image_url", u.String())
}
msg, _ = sjson.SetRaw(msg, "content.-1", part)
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
}
case "file":
if role == "user" {
fileData := it.Get("file.file_data").String()
filename := it.Get("file.filename").String()
if fileData != "" {
part := `{}`
part, _ = sjson.Set(part, "type", "input_file")
part, _ = sjson.Set(part, "file_data", fileData)
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", "input_file")
part, _ = sjson.SetBytes(part, "file_data", fileData)
if filename != "" {
part, _ = sjson.Set(part, "filename", filename)
part, _ = sjson.SetBytes(part, "filename", filename)
}
msg, _ = sjson.SetRaw(msg, "content.-1", part)
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
}
}
}
}
}
out, _ = sjson.SetRaw(out, "input.-1", msg)
// Don't emit empty assistant messages when only tool_calls
// are present — Responses API needs function_call items
// directly, otherwise call_id matching fails (#2132).
if role != "assistant" || len(gjson.GetBytes(msg, "content").Array()) > 0 {
out, _ = sjson.SetRawBytes(out, "input.-1", msg)
}
// Handle tool calls for assistant messages as separate top-level objects
if role == "assistant" {
@@ -208,9 +213,9 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
tc := toolCallsArr[j]
if tc.Get("type").String() == "function" {
// Create function_call as top-level object
funcCall := `{}`
funcCall, _ = sjson.Set(funcCall, "type", "function_call")
funcCall, _ = sjson.Set(funcCall, "call_id", tc.Get("id").String())
funcCall := []byte(`{}`)
funcCall, _ = sjson.SetBytes(funcCall, "type", "function_call")
funcCall, _ = sjson.SetBytes(funcCall, "call_id", tc.Get("id").String())
{
name := tc.Get("function.name").String()
if short, ok := originalToolNameMap[name]; ok {
@@ -218,10 +223,10 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
} else {
name = shortenNameIfNeeded(name)
}
funcCall, _ = sjson.Set(funcCall, "name", name)
funcCall, _ = sjson.SetBytes(funcCall, "name", name)
}
funcCall, _ = sjson.Set(funcCall, "arguments", tc.Get("function.arguments").String())
out, _ = sjson.SetRaw(out, "input.-1", funcCall)
funcCall, _ = sjson.SetBytes(funcCall, "arguments", tc.Get("function.arguments").String())
out, _ = sjson.SetRawBytes(out, "input.-1", funcCall)
}
}
}
@@ -235,26 +240,26 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
text := gjson.GetBytes(rawJSON, "text")
if rf.Exists() {
// Always create text object when response_format provided
if !gjson.Get(out, "text").Exists() {
out, _ = sjson.SetRaw(out, "text", `{}`)
if !gjson.GetBytes(out, "text").Exists() {
out, _ = sjson.SetRawBytes(out, "text", []byte(`{}`))
}
rft := rf.Get("type").String()
switch rft {
case "text":
out, _ = sjson.Set(out, "text.format.type", "text")
out, _ = sjson.SetBytes(out, "text.format.type", "text")
case "json_schema":
js := rf.Get("json_schema")
if js.Exists() {
out, _ = sjson.Set(out, "text.format.type", "json_schema")
out, _ = sjson.SetBytes(out, "text.format.type", "json_schema")
if v := js.Get("name"); v.Exists() {
out, _ = sjson.Set(out, "text.format.name", v.Value())
out, _ = sjson.SetBytes(out, "text.format.name", v.Value())
}
if v := js.Get("strict"); v.Exists() {
out, _ = sjson.Set(out, "text.format.strict", v.Value())
out, _ = sjson.SetBytes(out, "text.format.strict", v.Value())
}
if v := js.Get("schema"); v.Exists() {
out, _ = sjson.SetRaw(out, "text.format.schema", v.Raw)
out, _ = sjson.SetRawBytes(out, "text.format.schema", []byte(v.Raw))
}
}
}
@@ -262,23 +267,23 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
// Map verbosity if provided
if text.Exists() {
if v := text.Get("verbosity"); v.Exists() {
out, _ = sjson.Set(out, "text.verbosity", v.Value())
out, _ = sjson.SetBytes(out, "text.verbosity", v.Value())
}
}
} else if text.Exists() {
// If only text.verbosity present (no response_format), map verbosity
if v := text.Get("verbosity"); v.Exists() {
if !gjson.Get(out, "text").Exists() {
out, _ = sjson.SetRaw(out, "text", `{}`)
if !gjson.GetBytes(out, "text").Exists() {
out, _ = sjson.SetRawBytes(out, "text", []byte(`{}`))
}
out, _ = sjson.Set(out, "text.verbosity", v.Value())
out, _ = sjson.SetBytes(out, "text.verbosity", v.Value())
}
}
// Map tools (flatten function fields)
tools := gjson.GetBytes(rawJSON, "tools")
if tools.IsArray() && len(tools.Array()) > 0 {
out, _ = sjson.SetRaw(out, "tools", `[]`)
out, _ = sjson.SetRawBytes(out, "tools", []byte(`[]`))
arr := tools.Array()
for i := 0; i < len(arr); i++ {
t := arr[i]
@@ -286,13 +291,13 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
// Pass through built-in tools (e.g. {"type":"web_search"}) directly for the Responses API.
// Only "function" needs structural conversion because Chat Completions nests details under "function".
if toolType != "" && toolType != "function" && t.IsObject() {
out, _ = sjson.SetRaw(out, "tools.-1", t.Raw)
out, _ = sjson.SetRawBytes(out, "tools.-1", []byte(t.Raw))
continue
}
if toolType == "function" {
item := `{}`
item, _ = sjson.Set(item, "type", "function")
item := []byte(`{}`)
item, _ = sjson.SetBytes(item, "type", "function")
fn := t.Get("function")
if fn.Exists() {
if v := fn.Get("name"); v.Exists() {
@@ -302,19 +307,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
} else {
name = shortenNameIfNeeded(name)
}
item, _ = sjson.Set(item, "name", name)
item, _ = sjson.SetBytes(item, "name", name)
}
if v := fn.Get("description"); v.Exists() {
item, _ = sjson.Set(item, "description", v.Value())
item, _ = sjson.SetBytes(item, "description", v.Value())
}
if v := fn.Get("parameters"); v.Exists() {
item, _ = sjson.SetRaw(item, "parameters", v.Raw)
item, _ = sjson.SetRawBytes(item, "parameters", []byte(v.Raw))
}
if v := fn.Get("strict"); v.Exists() {
item, _ = sjson.Set(item, "strict", v.Value())
item, _ = sjson.SetBytes(item, "strict", v.Value())
}
}
out, _ = sjson.SetRaw(out, "tools.-1", item)
out, _ = sjson.SetRawBytes(out, "tools.-1", item)
}
}
}
@@ -325,7 +330,7 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
if tc := gjson.GetBytes(rawJSON, "tool_choice"); tc.Exists() {
switch {
case tc.Type == gjson.String:
out, _ = sjson.Set(out, "tool_choice", tc.String())
out, _ = sjson.SetBytes(out, "tool_choice", tc.String())
case tc.IsObject():
tcType := tc.Get("type").String()
if tcType == "function" {
@@ -337,21 +342,21 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
name = shortenNameIfNeeded(name)
}
}
choice := `{}`
choice, _ = sjson.Set(choice, "type", "function")
choice := []byte(`{}`)
choice, _ = sjson.SetBytes(choice, "type", "function")
if name != "" {
choice, _ = sjson.Set(choice, "name", name)
choice, _ = sjson.SetBytes(choice, "name", name)
}
out, _ = sjson.SetRaw(out, "tool_choice", choice)
out, _ = sjson.SetRawBytes(out, "tool_choice", choice)
} else if tcType != "" {
// Built-in tool choices (e.g. {"type":"web_search"}) are already Responses-compatible.
out, _ = sjson.SetRaw(out, "tool_choice", tc.Raw)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(tc.Raw))
}
}
}
out, _ = sjson.Set(out, "store", false)
return []byte(out)
out, _ = sjson.SetBytes(out, "store", false)
return out
}
// shortenNameIfNeeded applies the simple shortening rule for a single name.

View File

@@ -0,0 +1,635 @@
package chat_completions
import (
"testing"
"github.com/tidwall/gjson"
)
// Basic tool-call: system + user + assistant(tool_calls, no content) + tool result.
// Expects developer msg + user msg + function_call + function_call_output.
// No empty assistant message should appear between user and function_call.
func TestToolCallSimple(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the weather in Paris?"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"Paris\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": "sunny, 22C"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather for a city",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
if len(items) != 4 {
t.Fatalf("expected 4 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
// system -> developer
if items[0].Get("type").String() != "message" {
t.Errorf("item 0: expected type 'message', got '%s'", items[0].Get("type").String())
}
if items[0].Get("role").String() != "developer" {
t.Errorf("item 0: expected role 'developer', got '%s'", items[0].Get("role").String())
}
// user
if items[1].Get("type").String() != "message" {
t.Errorf("item 1: expected type 'message', got '%s'", items[1].Get("type").String())
}
if items[1].Get("role").String() != "user" {
t.Errorf("item 1: expected role 'user', got '%s'", items[1].Get("role").String())
}
// function_call, not an empty assistant msg
if items[2].Get("type").String() != "function_call" {
t.Errorf("item 2: expected type 'function_call', got '%s'", items[2].Get("type").String())
}
if items[2].Get("call_id").String() != "call_1" {
t.Errorf("item 2: expected call_id 'call_1', got '%s'", items[2].Get("call_id").String())
}
if items[2].Get("name").String() != "get_weather" {
t.Errorf("item 2: expected name 'get_weather', got '%s'", items[2].Get("name").String())
}
if items[2].Get("arguments").String() != `{"city":"Paris"}` {
t.Errorf("item 2: unexpected arguments: %s", items[2].Get("arguments").String())
}
// function_call_output
if items[3].Get("type").String() != "function_call_output" {
t.Errorf("item 3: expected type 'function_call_output', got '%s'", items[3].Get("type").String())
}
if items[3].Get("call_id").String() != "call_1" {
t.Errorf("item 3: expected call_id 'call_1', got '%s'", items[3].Get("call_id").String())
}
if items[3].Get("output").String() != "sunny, 22C" {
t.Errorf("item 3: expected output 'sunny, 22C', got '%s'", items[3].Get("output").String())
}
}
// Assistant has both text content and tool_calls — the message should
// be emitted (non-empty content), followed by function_call items.
func TestToolCallWithContent(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "What is the weather?"},
{
"role": "assistant",
"content": "Let me check the weather for you.",
"tool_calls": [
{
"id": "call_abc",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_abc",
"content": "rainy, 15C"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// user + assistant(with content) + function_call + function_call_output
if len(items) != 4 {
t.Fatalf("expected 4 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
if items[0].Get("role").String() != "user" {
t.Errorf("item 0: expected role 'user', got '%s'", items[0].Get("role").String())
}
// assistant with content — should be kept
if items[1].Get("type").String() != "message" {
t.Errorf("item 1: expected type 'message', got '%s'", items[1].Get("type").String())
}
if items[1].Get("role").String() != "assistant" {
t.Errorf("item 1: expected role 'assistant', got '%s'", items[1].Get("role").String())
}
contentParts := items[1].Get("content").Array()
if len(contentParts) == 0 {
t.Errorf("item 1: assistant message should have content parts")
}
if items[2].Get("type").String() != "function_call" {
t.Errorf("item 2: expected type 'function_call', got '%s'", items[2].Get("type").String())
}
if items[2].Get("call_id").String() != "call_abc" {
t.Errorf("item 2: expected call_id 'call_abc', got '%s'", items[2].Get("call_id").String())
}
if items[3].Get("type").String() != "function_call_output" {
t.Errorf("item 3: expected type 'function_call_output', got '%s'", items[3].Get("type").String())
}
if items[3].Get("call_id").String() != "call_abc" {
t.Errorf("item 3: expected call_id 'call_abc', got '%s'", items[3].Get("call_id").String())
}
}
// Parallel tool calls: assistant invokes 3 tools at once, all call_ids
// and outputs must be translated and paired correctly.
func TestMultipleToolCalls(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Compare weather in Paris, London and Tokyo"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_paris",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"Paris\"}"
}
},
{
"id": "call_london",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"London\"}"
}
},
{
"id": "call_tokyo",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"Tokyo\"}"
}
}
]
},
{"role": "tool", "tool_call_id": "call_paris", "content": "sunny, 22C"},
{"role": "tool", "tool_call_id": "call_london", "content": "cloudy, 14C"},
{"role": "tool", "tool_call_id": "call_tokyo", "content": "humid, 28C"}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// user + 3 function_call + 3 function_call_output = 7
if len(items) != 7 {
t.Fatalf("expected 7 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
if items[0].Get("role").String() != "user" {
t.Errorf("item 0: expected role 'user', got '%s'", items[0].Get("role").String())
}
expectedCallIDs := []string{"call_paris", "call_london", "call_tokyo"}
for i, expectedID := range expectedCallIDs {
idx := i + 1
if items[idx].Get("type").String() != "function_call" {
t.Errorf("item %d: expected type 'function_call', got '%s'", idx, items[idx].Get("type").String())
}
if items[idx].Get("call_id").String() != expectedID {
t.Errorf("item %d: expected call_id '%s', got '%s'", idx, expectedID, items[idx].Get("call_id").String())
}
}
expectedOutputs := []string{"sunny, 22C", "cloudy, 14C", "humid, 28C"}
for i, expectedOutput := range expectedOutputs {
idx := i + 4
if items[idx].Get("type").String() != "function_call_output" {
t.Errorf("item %d: expected type 'function_call_output', got '%s'", idx, items[idx].Get("type").String())
}
if items[idx].Get("call_id").String() != expectedCallIDs[i] {
t.Errorf("item %d: expected call_id '%s', got '%s'", idx, expectedCallIDs[i], items[idx].Get("call_id").String())
}
if items[idx].Get("output").String() != expectedOutput {
t.Errorf("item %d: expected output '%s', got '%s'", idx, expectedOutput, items[idx].Get("output").String())
}
}
}
// Regression test for #2132: tool-call-only assistant messages (content:null)
// must not produce an empty message item in the translated output.
func TestNoSpuriousEmptyAssistantMessage(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Call a tool"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_x",
"type": "function",
"function": {"name": "do_thing", "arguments": "{}"}
}
]
},
{"role": "tool", "tool_call_id": "call_x", "content": "done"}
],
"tools": [
{
"type": "function",
"function": {
"name": "do_thing",
"description": "Do a thing",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
for i, item := range items {
typ := item.Get("type").String()
role := item.Get("role").String()
if typ == "message" && role == "assistant" {
contentArr := item.Get("content").Array()
if len(contentArr) == 0 {
t.Errorf("item %d: empty assistant message breaks call_id matching. item: %s", i, item.Raw)
}
}
}
// should be exactly: user + function_call + function_call_output
if len(items) != 3 {
t.Fatalf("expected 3 input items (user + function_call + function_call_output), got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
if items[0].Get("type").String() != "message" || items[0].Get("role").String() != "user" {
t.Errorf("item 0: expected user message")
}
if items[1].Get("type").String() != "function_call" {
t.Errorf("item 1: expected function_call, got %s", items[1].Get("type").String())
}
if items[2].Get("type").String() != "function_call_output" {
t.Errorf("item 2: expected function_call_output, got %s", items[2].Get("type").String())
}
}
// Two rounds of tool calling in one conversation, with a text reply in between.
func TestMultiTurnToolCalling(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Weather in Paris?"},
{
"role": "assistant",
"content": null,
"tool_calls": [{"id": "call_r1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}}]
},
{"role": "tool", "tool_call_id": "call_r1", "content": "sunny"},
{"role": "assistant", "content": "It is sunny in Paris."},
{"role": "user", "content": "And London?"},
{
"role": "assistant",
"content": null,
"tool_calls": [{"id": "call_r2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"London\"}"}}]
},
{"role": "tool", "tool_call_id": "call_r2", "content": "rainy"}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// user, func_call(r1), func_output(r1), assistant text, user, func_call(r2), func_output(r2)
if len(items) != 7 {
t.Fatalf("expected 7 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
for i, item := range items {
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
if len(item.Get("content").Array()) == 0 {
t.Errorf("item %d: unexpected empty assistant message", i)
}
}
}
// round 1
if items[1].Get("type").String() != "function_call" {
t.Errorf("item 1: expected function_call, got %s", items[1].Get("type").String())
}
if items[1].Get("call_id").String() != "call_r1" {
t.Errorf("item 1: expected call_id 'call_r1', got '%s'", items[1].Get("call_id").String())
}
if items[2].Get("type").String() != "function_call_output" {
t.Errorf("item 2: expected function_call_output, got %s", items[2].Get("type").String())
}
// text reply between rounds
if items[3].Get("type").String() != "message" || items[3].Get("role").String() != "assistant" {
t.Errorf("item 3: expected assistant message, got type=%s role=%s", items[3].Get("type").String(), items[3].Get("role").String())
}
// round 2
if items[5].Get("type").String() != "function_call" {
t.Errorf("item 5: expected function_call, got %s", items[5].Get("type").String())
}
if items[5].Get("call_id").String() != "call_r2" {
t.Errorf("item 5: expected call_id 'call_r2', got '%s'", items[5].Get("call_id").String())
}
if items[6].Get("type").String() != "function_call_output" {
t.Errorf("item 6: expected function_call_output, got %s", items[6].Get("type").String())
}
}
// Tool names over 64 chars get shortened, call_id stays the same.
func TestToolNameShortening(t *testing.T) {
longName := "a_very_long_tool_name_that_exceeds_sixty_four_characters_limit_here_test"
if len(longName) <= 64 {
t.Fatalf("test setup error: name must be > 64 chars, got %d", len(longName))
}
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Do it"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_long",
"type": "function",
"function": {
"name": "` + longName + `",
"arguments": "{}"
}
}
]
},
{"role": "tool", "tool_call_id": "call_long", "content": "ok"}
],
"tools": [
{
"type": "function",
"function": {
"name": "` + longName + `",
"description": "A tool with a very long name",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// find function_call
var funcCallItem gjson.Result
for _, item := range items {
if item.Get("type").String() == "function_call" {
funcCallItem = item
break
}
}
if !funcCallItem.Exists() {
t.Fatal("no function_call item found in output")
}
// call_id unchanged
if funcCallItem.Get("call_id").String() != "call_long" {
t.Errorf("call_id changed: expected 'call_long', got '%s'", funcCallItem.Get("call_id").String())
}
// name must be truncated
translatedName := funcCallItem.Get("name").String()
if translatedName == longName {
t.Errorf("tool name was NOT shortened: still '%s'", translatedName)
}
if len(translatedName) > 64 {
t.Errorf("shortened name still > 64 chars: len=%d name='%s'", len(translatedName), translatedName)
}
}
// content:"" (empty string, not null) should be treated the same as null.
func TestEmptyStringContent(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Do something"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_empty",
"type": "function",
"function": {"name": "action", "arguments": "{}"}
}
]
},
{"role": "tool", "tool_call_id": "call_empty", "content": "result"}
],
"tools": [
{
"type": "function",
"function": {
"name": "action",
"description": "An action",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
for i, item := range items {
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
if len(item.Get("content").Array()) == 0 {
t.Errorf("item %d: empty assistant message from content:\"\"", i)
}
}
}
// user + function_call + function_call_output
if len(items) != 3 {
t.Errorf("expected 3 input items, got %d", len(items))
}
}
// Every function_call_output must have a matching function_call by call_id.
func TestCallIDsMatchBetweenCallAndOutput(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Multi-tool"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{"id": "id_a", "type": "function", "function": {"name": "tool_a", "arguments": "{}"}},
{"id": "id_b", "type": "function", "function": {"name": "tool_b", "arguments": "{}"}}
]
},
{"role": "tool", "tool_call_id": "id_a", "content": "res_a"},
{"role": "tool", "tool_call_id": "id_b", "content": "res_b"}
],
"tools": [
{"type": "function", "function": {"name": "tool_a", "description": "A", "parameters": {"type": "object", "properties": {}}}},
{"type": "function", "function": {"name": "tool_b", "description": "B", "parameters": {"type": "object", "properties": {}}}}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// collect call_ids from function_call items
callIDs := make(map[string]bool)
for _, item := range items {
if item.Get("type").String() == "function_call" {
callIDs[item.Get("call_id").String()] = true
}
}
for i, item := range items {
if item.Get("type").String() == "function_call_output" {
outID := item.Get("call_id").String()
if !callIDs[outID] {
t.Errorf("item %d: function_call_output has call_id '%s' with no matching function_call", i, outID)
}
}
}
// 2 calls, 2 outputs
funcCallCount := 0
funcOutputCount := 0
for _, item := range items {
switch item.Get("type").String() {
case "function_call":
funcCallCount++
case "function_call_output":
funcOutputCount++
}
}
if funcCallCount != 2 {
t.Errorf("expected 2 function_calls, got %d", funcCallCount)
}
if funcOutputCount != 2 {
t.Errorf("expected 2 function_call_outputs, got %d", funcOutputCount)
}
}
// Tools array should carry over to the Responses format output.
func TestToolsDefinitionTranslated(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Hi"}
],
"tools": [
{
"type": "function",
"function": {
"name": "search",
"description": "Search the web",
"parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
tools := gjson.Get(result, "tools").Array()
if len(tools) == 0 {
t.Fatal("no tools found in output")
}
found := false
for _, tool := range tools {
if tool.Get("name").String() == "search" {
found = true
break
}
}
if !found {
t.Errorf("tool 'search' not found in output tools: %s", gjson.Get(result, "tools").Raw)
}
}

View File

@@ -41,8 +41,8 @@ type ConvertCliToOpenAIParams struct {
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of OpenAI-compatible JSON responses
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &ConvertCliToOpenAIParams{
Model: modelName,
@@ -55,12 +55,12 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
}
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
return [][]byte{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
// Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
rootResult := gjson.ParseBytes(rawJSON)
@@ -70,67 +70,67 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
(*param).(*ConvertCliToOpenAIParams).ResponseID = rootResult.Get("response.id").String()
(*param).(*ConvertCliToOpenAIParams).CreatedAt = rootResult.Get("response.created_at").Int()
(*param).(*ConvertCliToOpenAIParams).Model = rootResult.Get("response.model").String()
return []string{}
return [][]byte{}
}
// Extract and set the model version.
cachedModel := (*param).(*ConvertCliToOpenAIParams).Model
if modelResult := gjson.GetBytes(rawJSON, "model"); modelResult.Exists() {
template, _ = sjson.Set(template, "model", modelResult.String())
template, _ = sjson.SetBytes(template, "model", modelResult.String())
} else if cachedModel != "" {
template, _ = sjson.Set(template, "model", cachedModel)
template, _ = sjson.SetBytes(template, "model", cachedModel)
} else if modelName != "" {
template, _ = sjson.Set(template, "model", modelName)
template, _ = sjson.SetBytes(template, "model", modelName)
}
template, _ = sjson.Set(template, "created", (*param).(*ConvertCliToOpenAIParams).CreatedAt)
template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertCliToOpenAIParams).CreatedAt)
// Extract and set the response ID.
template, _ = sjson.Set(template, "id", (*param).(*ConvertCliToOpenAIParams).ResponseID)
template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertCliToOpenAIParams).ResponseID)
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "response.usage"); usageResult.Exists() {
if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokensResult.Int())
}
if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokensResult.Int())
}
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokensResult.Int())
}
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
}
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
}
}
if dataType == "response.reasoning_summary_text.delta" {
if deltaResult := rootResult.Get("delta"); deltaResult.Exists() {
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", deltaResult.String())
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", deltaResult.String())
}
} else if dataType == "response.reasoning_summary_text.done" {
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", "\n\n")
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", "\n\n")
} else if dataType == "response.output_text.delta" {
if deltaResult := rootResult.Get("delta"); deltaResult.Exists() {
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.Set(template, "choices.0.delta.content", deltaResult.String())
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.delta.content", deltaResult.String())
}
} else if dataType == "response.completed" {
finishReason := "stop"
if (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex != -1 {
finishReason = "tool_calls"
}
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
} else if dataType == "response.output_item.added" {
itemResult := rootResult.Get("item")
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
return []string{}
return [][]byte{}
}
// Increment index for this new function call item.
@@ -138,9 +138,9 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
functionCallItemTemplate := []byte(`{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened.
name := itemResult.Get("name").String()
@@ -148,59 +148,59 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", "")
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", "")
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.function_call_arguments.delta" {
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true
deltaValue := rootResult.Get("delta").String()
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", deltaValue)
functionCallItemTemplate := []byte(`{"index":0,"function":{"arguments":""}}`)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", deltaValue)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.function_call_arguments.done" {
if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta {
// Arguments were already streamed via delta events; nothing to emit.
return []string{}
return [][]byte{}
}
// Fallback: no delta events were received, emit the full arguments as a single chunk.
fullArgs := rootResult.Get("arguments").String()
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fullArgs)
functionCallItemTemplate := []byte(`{"index":0,"function":{"arguments":""}}`)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", fullArgs)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.output_item.done" {
itemResult := rootResult.Get("item")
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
return []string{}
return [][]byte{}
}
if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced {
// Tool call was already announced via output_item.added; skip emission.
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false
return []string{}
return [][]byte{}
}
// Fallback path: model skipped output_item.added, so emit complete tool call now.
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate := []byte(`{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened.
name := itemResult.Get("name").String()
@@ -208,17 +208,17 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else {
return []string{}
return [][]byte{}
}
return []string{template}
return [][]byte{template}
}
// ConvertCodexResponseToOpenAINonStream converts a non-streaming Codex response to a non-streaming OpenAI response.
@@ -233,53 +233,53 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
rootResult := gjson.ParseBytes(rawJSON)
// Verify this is a response.completed event
if rootResult.Get("type").String() != "response.completed" {
return ""
return []byte{}
}
unixTimestamp := time.Now().Unix()
responseResult := rootResult.Get("response")
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
template := []byte(`{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
// Extract and set the model version.
if modelResult := responseResult.Get("model"); modelResult.Exists() {
template, _ = sjson.Set(template, "model", modelResult.String())
template, _ = sjson.SetBytes(template, "model", modelResult.String())
}
// Extract and set the creation timestamp.
if createdAtResult := responseResult.Get("created_at"); createdAtResult.Exists() {
template, _ = sjson.Set(template, "created", createdAtResult.Int())
template, _ = sjson.SetBytes(template, "created", createdAtResult.Int())
} else {
template, _ = sjson.Set(template, "created", unixTimestamp)
template, _ = sjson.SetBytes(template, "created", unixTimestamp)
}
// Extract and set the response ID.
if idResult := responseResult.Get("id"); idResult.Exists() {
template, _ = sjson.Set(template, "id", idResult.String())
template, _ = sjson.SetBytes(template, "id", idResult.String())
}
// Extract and set usage metadata (token counts).
if usageResult := responseResult.Get("usage"); usageResult.Exists() {
if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokensResult.Int())
}
if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokensResult.Int())
}
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokensResult.Int())
}
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
}
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
}
}
@@ -289,7 +289,7 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
outputArray := outputResult.Array()
var contentText string
var reasoningText string
var toolCalls []string
var toolCalls [][]byte
for _, outputItem := range outputArray {
outputType := outputItem.Get("type").String()
@@ -319,10 +319,10 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
}
case "function_call":
// Handle function call content
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
functionCallTemplate := []byte(`{"id":"","type":"function","function":{"name":"","arguments":""}}`)
if callIdResult := outputItem.Get("call_id"); callIdResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", callIdResult.String())
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", callIdResult.String())
}
if nameResult := outputItem.Get("name"); nameResult.Exists() {
@@ -331,11 +331,11 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
if orig, ok := rev[n]; ok {
n = orig
}
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", n)
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", n)
}
if argsResult := outputItem.Get("arguments"); argsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", argsResult.String())
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.arguments", argsResult.String())
}
toolCalls = append(toolCalls, functionCallTemplate)
@@ -344,22 +344,22 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
// Set content and reasoning content if found
if contentText != "" {
template, _ = sjson.Set(template, "choices.0.message.content", contentText)
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.message.content", contentText)
template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
}
if reasoningText != "" {
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningText)
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.message.reasoning_content", reasoningText)
template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
}
// Add tool calls if any
if len(toolCalls) > 0 {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
template, _ = sjson.SetRawBytes(template, "choices.0.message.tool_calls", []byte(`[]`))
for _, toolCall := range toolCalls {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", toolCall)
template, _ = sjson.SetRawBytes(template, "choices.0.message.tool_calls.-1", toolCall)
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
}
}
@@ -367,8 +367,8 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
if statusResult := responseResult.Get("status"); statusResult.Exists() {
status := statusResult.String()
if status == "completed" {
template, _ = sjson.Set(template, "choices.0.finish_reason", "stop")
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "stop")
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "stop")
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "stop")
}
}

View File

@@ -23,7 +23,7 @@ func TestConvertCodexResponseToOpenAI_StreamSetsModelFromResponseCreated(t *test
t.Fatalf("expected 1 chunk, got %d", len(out))
}
gotModel := gjson.Get(out[0], "model").String()
gotModel := gjson.GetBytes(out[0], "model").String()
if gotModel != modelName {
t.Fatalf("expected model %q, got %q", modelName, gotModel)
}
@@ -40,7 +40,7 @@ func TestConvertCodexResponseToOpenAI_FirstChunkUsesRequestModelName(t *testing.
t.Fatalf("expected 1 chunk, got %d", len(out))
}
gotModel := gjson.Get(out[0], "model").String()
gotModel := gjson.GetBytes(out[0], "model").String()
if gotModel != modelName {
t.Fatalf("expected model %q, got %q", modelName, gotModel)
}

View File

@@ -12,8 +12,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
inputResult := gjson.GetBytes(rawJSON, "input")
if inputResult.Type == gjson.String {
input, _ := sjson.Set(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`, "0.content.0.text", inputResult.String())
rawJSON, _ = sjson.SetRawBytes(rawJSON, "input", []byte(input))
input, _ := sjson.SetBytes([]byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`), "0.content.0.text", inputResult.String())
rawJSON, _ = sjson.SetRawBytes(rawJSON, "input", input)
}
rawJSON, _ = sjson.SetBytes(rawJSON, "stream", true)

View File

@@ -3,7 +3,6 @@ package responses
import (
"bytes"
"context"
"fmt"
"github.com/tidwall/gjson"
)
@@ -11,23 +10,25 @@ import (
// ConvertCodexResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
// to OpenAI Responses SSE events (response.*).
func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []string {
func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) [][]byte {
if bytes.HasPrefix(rawJSON, []byte("data:")) {
rawJSON = bytes.TrimSpace(rawJSON[5:])
out := fmt.Sprintf("data: %s", string(rawJSON))
return []string{out}
out := make([]byte, 0, len(rawJSON)+len("data: "))
out = append(out, []byte("data: ")...)
out = append(out, rawJSON...)
return [][]byte{out}
}
return []string{string(rawJSON)}
return [][]byte{rawJSON}
}
// ConvertCodexResponseToOpenAIResponsesNonStream builds a single Responses JSON
// from a non-streaming OpenAI Chat Completions response.
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) string {
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []byte {
rootResult := gjson.ParseBytes(rawJSON)
// Verify this is a response.completed event
if rootResult.Get("type").String() != "response.completed" {
return ""
return []byte{}
}
responseResult := rootResult.Get("response")
return responseResult.Raw
return []byte(responseResult.Raw)
}

View File

@@ -0,0 +1,67 @@
package common
import (
"strconv"
"github.com/tidwall/sjson"
)
func WrapGeminiCLIResponse(response []byte) []byte {
out, err := sjson.SetRawBytes([]byte(`{"response":{}}`), "response", response)
if err != nil {
return response
}
return out
}
func GeminiTokenCountJSON(count int64) []byte {
out := make([]byte, 0, 96)
out = append(out, `{"totalTokens":`...)
out = strconv.AppendInt(out, count, 10)
out = append(out, `,"promptTokensDetails":[{"modality":"TEXT","tokenCount":`...)
out = strconv.AppendInt(out, count, 10)
out = append(out, `}]}`...)
return out
}
func ClaudeInputTokensJSON(count int64) []byte {
out := make([]byte, 0, 32)
out = append(out, `{"input_tokens":`...)
out = strconv.AppendInt(out, count, 10)
out = append(out, '}')
return out
}
func SSEEventData(event string, payload []byte) []byte {
out := make([]byte, 0, len(event)+len(payload)+14)
out = append(out, "event: "...)
out = append(out, event...)
out = append(out, '\n')
out = append(out, "data: "...)
out = append(out, payload...)
return out
}
func AppendSSEEventString(out []byte, event, payload string, trailingNewlines int) []byte {
out = append(out, "event: "...)
out = append(out, event...)
out = append(out, '\n')
out = append(out, "data: "...)
out = append(out, payload...)
for i := 0; i < trailingNewlines; i++ {
out = append(out, '\n')
}
return out
}
func AppendSSEEventBytes(out []byte, event string, payload []byte, trailingNewlines int) []byte {
out = append(out, "event: "...)
out = append(out, event...)
out = append(out, '\n')
out = append(out, "data: "...)
out = append(out, payload...)
for i := 0; i < trailingNewlines; i++ {
out = append(out, '\n')
}
return out
}

View File

@@ -6,10 +6,10 @@
package claude
import (
"bytes"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -36,33 +36,32 @@ const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
// - []byte: The transformed request data in Gemini CLI API format
func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}`
out, _ = sjson.Set(out, "model", modelName)
out := []byte(`{"model":"","request":{"contents":[]}}`)
out, _ = sjson.SetBytes(out, "model", modelName)
// system instruction
if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() {
systemInstruction := `{"role":"user","parts":[]}`
systemInstruction := []byte(`{"role":"user","parts":[]}`)
hasSystemParts := false
systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool {
if systemPromptResult.Get("type").String() == "text" {
textResult := systemPromptResult.Get("text")
if textResult.Type == gjson.String {
part := `{"text":""}`
part, _ = sjson.Set(part, "text", textResult.String())
systemInstruction, _ = sjson.SetRaw(systemInstruction, "parts.-1", part)
part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", textResult.String())
systemInstruction, _ = sjson.SetRawBytes(systemInstruction, "parts.-1", part)
hasSystemParts = true
}
}
return true
})
if hasSystemParts {
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstruction)
out, _ = sjson.SetRawBytes(out, "request.systemInstruction", systemInstruction)
}
} else if systemResult.Type == gjson.String {
out, _ = sjson.Set(out, "request.systemInstruction.parts.-1.text", systemResult.String())
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.-1.text", systemResult.String())
}
// contents
@@ -77,28 +76,28 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
role = "model"
}
contentJSON := `{"role":"","parts":[]}`
contentJSON, _ = sjson.Set(contentJSON, "role", role)
contentJSON := []byte(`{"role":"","parts":[]}`)
contentJSON, _ = sjson.SetBytes(contentJSON, "role", role)
contentsResult := messageResult.Get("content")
if contentsResult.IsArray() {
contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
switch contentResult.Get("type").String() {
case "text":
part := `{"text":""}`
part, _ = sjson.Set(part, "text", contentResult.Get("text").String())
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", contentResult.Get("text").String())
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
case "tool_use":
functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String()
argsResult := gjson.Parse(functionArgs)
if argsResult.IsObject() && gjson.Valid(functionArgs) {
part := `{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`
part, _ = sjson.Set(part, "thoughtSignature", geminiCLIClaudeThoughtSignature)
part, _ = sjson.Set(part, "functionCall.name", functionName)
part, _ = sjson.SetRaw(part, "functionCall.args", functionArgs)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
part := []byte(`{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`)
part, _ = sjson.SetBytes(part, "thoughtSignature", geminiCLIClaudeThoughtSignature)
part, _ = sjson.SetBytes(part, "functionCall.name", functionName)
part, _ = sjson.SetRawBytes(part, "functionCall.args", []byte(functionArgs))
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
}
case "tool_result":
@@ -112,10 +111,10 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").Raw
part := `{"functionResponse":{"name":"","response":{"result":""}}}`
part, _ = sjson.Set(part, "functionResponse.name", funcName)
part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
part := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
part, _ = sjson.SetBytes(part, "functionResponse.name", funcName)
part, _ = sjson.SetBytes(part, "functionResponse.response.result", responseData)
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
case "image":
source := contentResult.Get("source")
@@ -123,21 +122,21 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
mimeType := source.Get("media_type").String()
data := source.Get("data").String()
if mimeType != "" && data != "" {
part := `{"inlineData":{"mime_type":"","data":""}}`
part, _ = sjson.Set(part, "inlineData.mime_type", mimeType)
part, _ = sjson.Set(part, "inlineData.data", data)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
part := []byte(`{"inlineData":{"mime_type":"","data":""}}`)
part, _ = sjson.SetBytes(part, "inlineData.mime_type", mimeType)
part, _ = sjson.SetBytes(part, "inlineData.data", data)
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
}
}
}
return true
})
out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON)
out, _ = sjson.SetRawBytes(out, "request.contents.-1", contentJSON)
} else if contentsResult.Type == gjson.String {
part := `{"text":""}`
part, _ = sjson.Set(part, "text", contentsResult.String())
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON)
part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", contentsResult.String())
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
out, _ = sjson.SetRawBytes(out, "request.contents.-1", contentJSON)
}
return true
})
@@ -149,26 +148,27 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
tool, _ = sjson.Delete(tool, "strict")
tool, _ = sjson.Delete(tool, "input_examples")
tool, _ = sjson.Delete(tool, "type")
tool, _ = sjson.Delete(tool, "cache_control")
tool, _ = sjson.Delete(tool, "defer_loading")
if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw)
tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
tool, _ = sjson.DeleteBytes(tool, "strict")
tool, _ = sjson.DeleteBytes(tool, "input_examples")
tool, _ = sjson.DeleteBytes(tool, "type")
tool, _ = sjson.DeleteBytes(tool, "cache_control")
tool, _ = sjson.DeleteBytes(tool, "defer_loading")
tool, _ = sjson.DeleteBytes(tool, "eager_input_streaming")
if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() {
if !hasTools {
out, _ = sjson.SetRaw(out, "request.tools", `[{"functionDeclarations":[]}]`)
out, _ = sjson.SetRawBytes(out, "request.tools", []byte(`[{"functionDeclarations":[]}]`))
hasTools = true
}
out, _ = sjson.SetRaw(out, "request.tools.0.functionDeclarations.-1", tool)
out, _ = sjson.SetRawBytes(out, "request.tools.0.functionDeclarations.-1", tool)
}
}
return true
})
if !hasTools {
out, _ = sjson.Delete(out, "request.tools")
out, _ = sjson.DeleteBytes(out, "request.tools")
}
}
@@ -186,15 +186,15 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
switch toolChoiceType {
case "auto":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
case "none":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
case "any":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
case "tool":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
if toolChoiceName != "" {
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
}
}
}
@@ -206,8 +206,8 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
case "enabled":
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := int(b.Int())
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
}
case "adaptive", "auto":
// For adaptive thinking:
@@ -219,25 +219,23 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
effort = strings.ToLower(strings.TrimSpace(v.String()))
}
if effort != "" {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
} else {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
}
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
}
}
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.temperature", v.Num)
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", v.Num)
}
if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topP", v.Num)
out, _ = sjson.SetBytes(out, "request.generationConfig.topP", v.Num)
}
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num)
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", v.Num)
}
outBytes := []byte(out)
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")
return outBytes
out = common.AttachDefaultSafetySettings(out, "request.safetySettings")
return out
}

View File

@@ -14,6 +14,7 @@ import (
"sync/atomic"
"time"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -47,8 +48,8 @@ var toolUseIDCounter uint64
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of bytes, each containing a Claude Code-compatible SSE payload.
func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &Params{
HasFirstResponse: false,
@@ -60,34 +61,33 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
if bytes.Equal(rawJSON, []byte("[DONE]")) {
// Only send message_stop if we have actually output content
if (*param).(*Params).HasContent {
return []string{
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n\n",
}
return [][]byte{translatorcommon.AppendSSEEventString(nil, "message_stop", `{"type":"message_stop"}`, 3)}
}
return []string{}
return [][]byte{}
}
// Track whether tools are being used in this response chunk
usedTool := false
output := ""
output := make([]byte, 0, 1024)
appendEvent := func(event, payload string) {
output = translatorcommon.AppendSSEEventString(output, event, payload, 3)
}
// Initialize the streaming session with a message_start event
// This is only sent for the very first response chunk to establish the streaming session
if !(*param).(*Params).HasFirstResponse {
output = "event: message_start\n"
// Create the initial message structure with default values according to Claude Code API specification
// This follows the Claude Code API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
messageStartTemplate := []byte(`{"type":"message_start","message":{"id":"msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY","type":"message","role":"assistant","content":[],"model":"claude-3-5-sonnet-20241022","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`)
// Override default values with actual response metadata if available from the Gemini CLI response
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.model", modelVersionResult.String())
}
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.id", responseIDResult.String())
}
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
appendEvent("message_start", string(messageStartTemplate))
(*param).(*Params).HasFirstResponse = true
}
@@ -110,9 +110,8 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
if partResult.Get("thought").Bool() {
// Continue existing thinking block if already in thinking state
if (*param).(*Params).ResponseType == 2 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex)), "delta.thinking", partTextResult.String())
appendEvent("content_block_delta", string(data))
(*param).(*Params).HasContent = true
} else {
// Transition from another state to thinking
@@ -123,19 +122,14 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
(*param).(*Params).ResponseIndex++
}
// Start a new thinking content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex))
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex)), "delta.thinking", partTextResult.String())
appendEvent("content_block_delta", string(data))
(*param).(*Params).ResponseType = 2 // Set state to thinking
(*param).(*Params).HasContent = true
}
@@ -143,9 +137,8 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Process regular text content (user-visible output)
// Continue existing text block if already in content state
if (*param).(*Params).ResponseType == 1 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex)), "delta.text", partTextResult.String())
appendEvent("content_block_delta", string(data))
(*param).(*Params).HasContent = true
} else {
// Transition from another state to text content
@@ -156,19 +149,14 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
(*param).(*Params).ResponseIndex++
}
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex))
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex)), "delta.text", partTextResult.String())
appendEvent("content_block_delta", string(data))
(*param).(*Params).ResponseType = 1 // Set state to content
(*param).(*Params).HasContent = true
}
@@ -182,9 +170,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Handle state transitions when switching to function calls
// Close any existing function call block first
if (*param).(*Params).ResponseType == 3 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
(*param).(*Params).ResponseIndex++
(*param).(*Params).ResponseType = 0
}
@@ -198,26 +184,21 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Close any other existing content block
if (*param).(*Params).ResponseType != 0 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
(*param).(*Params).ResponseIndex++
}
// Start a new tool use content block
// This creates the structure for a function call in Claude Code format
output = output + "event: content_block_start\n"
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data := []byte(fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex))
data, _ = sjson.SetBytes(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.SetBytes(data, "content_block.name", fcName)
appendEvent("content_block_start", string(data))
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n"
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ = sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex)), "delta.partial_json", fcArgsResult.Raw)
appendEvent("content_block_delta", string(data))
}
(*param).(*Params).ResponseType = 3
(*param).(*Params).HasContent = true
@@ -232,34 +213,28 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Only send final events if we have actually output content
if (*param).(*Params).HasContent {
// Close the final content block
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
// Send the final message delta with usage information and stop reason
output = output + "event: message_delta\n"
output = output + `data: `
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
// Create the message delta template with appropriate stop reason
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
template := []byte(`{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
// Set tool_use stop reason if tools were used in this response
if usedTool {
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
} else if finish := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finish.Exists() && finish.String() == "MAX_TOKENS" {
template = `{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
}
// Include thinking tokens in output token count if present
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
template, _ = sjson.SetBytes(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
template, _ = sjson.SetBytes(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
output = output + template + "\n\n\n"
appendEvent("message_delta", string(template))
}
}
}
return []string{output}
return [][]byte{output}
}
// ConvertGeminiCLIResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response.
@@ -271,21 +246,21 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// - param: A pointer to a parameter object for the conversion.
//
// Returns:
// - string: A Claude-compatible JSON response.
func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// - []byte: A Claude-compatible JSON response.
func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
_ = originalRequestRawJSON
_ = requestRawJSON
root := gjson.ParseBytes(rawJSON)
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("response.responseId").String())
out, _ = sjson.Set(out, "model", root.Get("response.modelVersion").String())
out := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
out, _ = sjson.SetBytes(out, "id", root.Get("response.responseId").String())
out, _ = sjson.SetBytes(out, "model", root.Get("response.modelVersion").String())
inputTokens := root.Get("response.usageMetadata.promptTokenCount").Int()
outputTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int() + root.Get("response.usageMetadata.thoughtsTokenCount").Int()
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens)
out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens)
parts := root.Get("response.candidates.0.content.parts")
textBuilder := strings.Builder{}
@@ -297,9 +272,9 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
if textBuilder.Len() == 0 {
return
}
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", textBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
block := []byte(`{"type":"text","text":""}`)
block, _ = sjson.SetBytes(block, "text", textBuilder.String())
out, _ = sjson.SetRawBytes(out, "content.-1", block)
textBuilder.Reset()
}
@@ -307,9 +282,9 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
if thinkingBuilder.Len() == 0 {
return
}
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
block := []byte(`{"type":"thinking","thinking":""}`)
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRawBytes(out, "content.-1", block)
thinkingBuilder.Reset()
}
@@ -333,15 +308,15 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
name := functionCall.Get("name").String()
toolIDCounter++
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "name", name)
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
inputRaw := "{}"
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
inputRaw = args.Raw
}
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(inputRaw))
out, _ = sjson.SetRawBytes(out, "content.-1", toolBlock)
continue
}
}
@@ -365,15 +340,15 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
}
}
}
out, _ = sjson.Set(out, "stop_reason", stopReason)
out, _ = sjson.SetBytes(out, "stop_reason", stopReason)
if inputTokens == int64(0) && outputTokens == int64(0) && !root.Get("response.usageMetadata").Exists() {
out, _ = sjson.Delete(out, "usage")
out, _ = sjson.DeleteBytes(out, "usage")
}
return out
}
func ClaudeTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"input_tokens":%d}`, count)
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.ClaudeInputTokensJSON(count)
}

View File

@@ -7,6 +7,7 @@ package gemini
import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -33,23 +34,23 @@ import (
// - []byte: The transformed request data in Gemini API format
func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON
template := ""
template = `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
template, _ = sjson.Delete(template, "request.model")
template := []byte(`{"project":"","request":{},"model":""}`)
template, _ = sjson.SetRawBytes(template, "request", rawJSON)
template, _ = sjson.SetBytes(template, "model", gjson.GetBytes(template, "request.model").String())
template, _ = sjson.DeleteBytes(template, "request.model")
template, errFixCLIToolResponse := fixCLIToolResponse(template)
templateStr, errFixCLIToolResponse := fixCLIToolResponse(string(template))
if errFixCLIToolResponse != nil {
return []byte{}
}
template = []byte(templateStr)
systemInstructionResult := gjson.Get(template, "request.system_instruction")
systemInstructionResult := gjson.GetBytes(template, "request.system_instruction")
if systemInstructionResult.Exists() {
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
template, _ = sjson.Delete(template, "request.system_instruction")
template, _ = sjson.SetRawBytes(template, "request.systemInstruction", []byte(systemInstructionResult.Raw))
template, _ = sjson.DeleteBytes(template, "request.system_instruction")
}
rawJSON = []byte(template)
rawJSON = template
// Normalize roles in request.contents: default to valid values if missing/invalid
contents := gjson.GetBytes(rawJSON, "request.contents")
@@ -110,12 +111,41 @@ func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []by
return true
})
// Filter out contents with empty parts to avoid Gemini API error:
// "required oneof field 'data' must have one initialized field"
filteredContents := []byte(`[]`)
hasFiltered := false
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(_, content gjson.Result) bool {
parts := content.Get("parts")
if !parts.IsArray() || len(parts.Array()) == 0 {
hasFiltered = true
return true
}
filteredContents, _ = sjson.SetRawBytes(filteredContents, "-1", []byte(content.Raw))
return true
})
if hasFiltered {
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents", filteredContents)
}
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
}
// FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct {
ResponsesNeeded int
CallNames []string // ordered function call names for backfilling empty response names
}
// backfillFunctionResponseName ensures that a functionResponse JSON object has a non-empty name,
// falling back to fallbackName if the original is empty.
func backfillFunctionResponseName(raw string, fallbackName string) string {
name := gjson.Get(raw, "functionResponse.name").String()
if strings.TrimSpace(name) == "" && fallbackName != "" {
rawBytes, _ := sjson.SetBytes([]byte(raw), "functionResponse.name", fallbackName)
raw = string(rawBytes)
}
return raw
}
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
@@ -142,7 +172,7 @@ func fixCLIToolResponse(input string) (string, error) {
}
// Initialize data structures for processing and grouping
contentsWrapper := `{"contents":[]}`
contentsWrapper := []byte(`{"contents":[]}`)
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched
@@ -165,31 +195,28 @@ func fixCLIToolResponse(input string) (string, error) {
if len(responsePartsInThisContent) > 0 {
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
// Check if any pending groups can be satisfied
for i := len(pendingGroups) - 1; i >= 0; i-- {
group := pendingGroups[i]
if len(collectedResponses) >= group.ResponsesNeeded {
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Check if pending groups can be satisfied (FIFO: oldest group first)
for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded {
group := pendingGroups[0]
pendingGroups = pendingGroups[1:]
// Create merged function response content
functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses {
if !response.IsObject() {
log.Warnf("failed to parse function response")
continue
}
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
for ri, response := range groupResponses {
if !response.IsObject() {
log.Warnf("failed to parse function response")
continue
}
raw := backfillFunctionResponseName(response.Raw, group.CallNames[ri])
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(raw))
}
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
}
// Remove this group as it's been satisfied
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
break
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
}
}
@@ -198,25 +225,26 @@ func fixCLIToolResponse(input string) (string, error) {
// If this is a model with function calls, create a new group
if role == "model" {
functionCallsCount := 0
var callNames []string
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
functionCallsCount++
callNames = append(callNames, part.Get("functionCall.name").String())
}
return true
})
if functionCallsCount > 0 {
if len(callNames) > 0 {
// Add the model content
if !value.IsObject() {
log.Warnf("failed to parse model content")
return true
}
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
// Create a new group for tracking responses
group := &FunctionCallGroup{
ResponsesNeeded: functionCallsCount,
ResponsesNeeded: len(callNames),
CallNames: callNames,
}
pendingGroups = append(pendingGroups, group)
} else {
@@ -225,7 +253,7 @@ func fixCLIToolResponse(input string) (string, error) {
log.Warnf("failed to parse content")
return true
}
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
}
} else {
// Non-model content (user, etc.)
@@ -233,7 +261,7 @@ func fixCLIToolResponse(input string) (string, error) {
log.Warnf("failed to parse content")
return true
}
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
}
return true
@@ -245,24 +273,25 @@ func fixCLIToolResponse(input string) (string, error) {
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses {
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
for ri, response := range groupResponses {
if !response.IsObject() {
log.Warnf("failed to parse function response")
continue
}
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
raw := backfillFunctionResponseName(response.Raw, group.CallNames[ri])
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(raw))
}
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
}
}
}
// Update the original JSON with the new contents
result := input
result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
result := []byte(input)
result, _ = sjson.SetRawBytes(result, "request.contents", []byte(gjson.GetBytes(contentsWrapper, "contents").Raw))
return result, nil
return string(result), nil
}

View File

@@ -8,8 +8,8 @@ package gemini
import (
"bytes"
"context"
"fmt"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -29,8 +29,8 @@ import (
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - []string: The transformed request data in Gemini API format
func ConvertGeminiCliResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string {
// - [][]byte: The transformed request data in Gemini API format
func ConvertGeminiCliResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) [][]byte {
if bytes.HasPrefix(rawJSON, []byte("data:")) {
rawJSON = bytes.TrimSpace(rawJSON[5:])
}
@@ -43,22 +43,22 @@ func ConvertGeminiCliResponseToGemini(ctx context.Context, _ string, originalReq
chunk = []byte(responseResult.Raw)
}
} else {
chunkTemplate := "[]"
chunkTemplate := []byte(`[]`)
responseResult := gjson.ParseBytes(chunk)
if responseResult.IsArray() {
responseResultItems := responseResult.Array()
for i := 0; i < len(responseResultItems); i++ {
responseResultItem := responseResultItems[i]
if responseResultItem.Get("response").Exists() {
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
chunkTemplate, _ = sjson.SetRawBytes(chunkTemplate, "-1", []byte(responseResultItem.Get("response").Raw))
}
}
}
chunk = []byte(chunkTemplate)
chunk = chunkTemplate
}
return []string{string(chunk)}
return [][]byte{chunk}
}
return []string{}
return [][]byte{}
}
// ConvertGeminiCliResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response.
@@ -72,15 +72,15 @@ func ConvertGeminiCliResponseToGemini(ctx context.Context, _ string, originalReq
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: A Gemini-compatible JSON response containing the response data
func ConvertGeminiCliResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// - []byte: A Gemini-compatible JSON response containing the response data
func ConvertGeminiCliResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
return responseResult.Raw
return []byte(responseResult.Raw)
}
return string(rawJSON)
return rawJSON
}
func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
func GeminiTokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.GeminiTokenCountJSON(count)
}

View File

@@ -299,43 +299,43 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if t.Get("type").String() == "function" {
fn := t.Get("function")
if fn.Exists() && fn.IsObject() {
fnRaw := fn.Raw
fnRaw := []byte(fn.Raw)
if fn.Get("parameters").Exists() {
renamed, errRename := util.RenameKey(fnRaw, "parameters", "parametersJsonSchema")
renamed, errRename := util.RenameKey(fn.Raw, "parameters", "parametersJsonSchema")
if errRename != nil {
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
fnRaw, errSet = sjson.SetBytes(fnRaw, "parametersJsonSchema.type", "object")
if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
fnRaw, errSet = sjson.SetRawBytes(fnRaw, "parametersJsonSchema.properties", []byte(`{}`))
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
} else {
fnRaw = renamed
fnRaw = []byte(renamed)
}
} else {
var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
fnRaw, errSet = sjson.SetBytes(fnRaw, "parametersJsonSchema.type", "object")
if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
fnRaw, errSet = sjson.SetRawBytes(fnRaw, "parametersJsonSchema.properties", []byte(`{}`))
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
}
fnRaw, _ = sjson.Delete(fnRaw, "strict")
fnRaw, _ = sjson.DeleteBytes(fnRaw, "strict")
if !hasFunction {
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
}
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", fnRaw)
if errSet != nil {
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
continue

View File

@@ -41,8 +41,8 @@ var functionCallIDCounter uint64
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of OpenAI-compatible JSON responses
func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &convertCliResponseToOpenAIChatParams{
UnixTimestamp: 0,
@@ -51,15 +51,15 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
}
if bytes.Equal(rawJSON, []byte("[DONE]")) {
return []string{}
return [][]byte{}
}
// Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
// Extract and set the model version.
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
template, _ = sjson.SetBytes(template, "model", modelVersionResult.String())
}
// Extract and set the creation timestamp.
@@ -68,14 +68,14 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
if err == nil {
(*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
} else {
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
}
// Extract and set the response ID.
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
template, _ = sjson.Set(template, "id", responseIDResult.String())
template, _ = sjson.SetBytes(template, "id", responseIDResult.String())
}
finishReason := ""
@@ -93,21 +93,21 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
template, _ = sjson.SetBytes(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
// Include cached token count if present (indicates prompt caching is working)
if cachedTokenCount > 0 {
var err error
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
template, err = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
if err != nil {
log.Warnf("gemini-cli openai response: failed to set cached_tokens: %v", err)
}
@@ -145,33 +145,33 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", textContent)
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", textContent)
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", textContent)
template, _ = sjson.SetBytes(template, "choices.0.delta.content", textContent)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
hasFunctionCall = true
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
toolCallsResult := gjson.GetBytes(template, "choices.0.delta.tool_calls")
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
if toolCallsResult.Exists() && toolCallsResult.IsArray() {
functionCallIndex = len(toolCallsResult.Array())
} else {
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
}
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
functionCallTemplate := []byte(`{"id":"","index":0,"type":"function","function":{"name":"","arguments":""}}`)
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex)
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
} else if inlineDataResult.Exists() {
data := inlineDataResult.Get("data").String()
if data == "" {
@@ -185,32 +185,32 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagesResult := gjson.Get(template, "choices.0.delta.images")
imagesResult := gjson.GetBytes(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images", []byte(`[]`))
}
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
imageIndex := len(gjson.GetBytes(template, "choices.0.delta.images").Array())
imagePayload := []byte(`{"type":"image_url","image_url":{"url":""}}`)
imagePayload, _ = sjson.SetBytes(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.SetBytes(imagePayload, "image_url.url", imageURL)
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images.-1", imagePayload)
}
}
}
if hasFunctionCall {
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "tool_calls")
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "tool_calls")
} else if finishReason != "" && (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex == 0 {
// Only pass through specific finish reasons
if finishReason == "max_tokens" || finishReason == "stop" {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
}
}
return []string{template}
return [][]byte{template}
}
// ConvertCliResponseToOpenAINonStream converts a non-streaming Gemini CLI response to a non-streaming OpenAI response.
@@ -225,11 +225,11 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
// - param: A pointer to a parameter object for the conversion
//
// Returns:
// - string: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertCliResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertCliResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
return ConvertGeminiResponseToOpenAINonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, []byte(responseResult.Raw), param)
}
return ""
return []byte{}
}

View File

@@ -7,7 +7,7 @@ import (
"github.com/tidwall/gjson"
)
func ConvertGeminiCLIResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
func ConvertGeminiCLIResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
rawJSON = []byte(responseResult.Raw)
@@ -15,7 +15,7 @@ func ConvertGeminiCLIResponseToOpenAIResponses(ctx context.Context, modelName st
return ConvertGeminiResponseToOpenAIResponses(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
func ConvertGeminiCLIResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
func ConvertGeminiCLIResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
rawJSON = []byte(responseResult.Raw)

View File

@@ -33,30 +33,30 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// Build output Gemini CLI request JSON
out := `{"contents":[]}`
out, _ = sjson.Set(out, "model", modelName)
out := []byte(`{"contents":[]}`)
out, _ = sjson.SetBytes(out, "model", modelName)
// system instruction
if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() {
systemInstruction := `{"role":"user","parts":[]}`
systemInstruction := []byte(`{"role":"user","parts":[]}`)
hasSystemParts := false
systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool {
if systemPromptResult.Get("type").String() == "text" {
textResult := systemPromptResult.Get("text")
if textResult.Type == gjson.String {
part := `{"text":""}`
part, _ = sjson.Set(part, "text", textResult.String())
systemInstruction, _ = sjson.SetRaw(systemInstruction, "parts.-1", part)
part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", textResult.String())
systemInstruction, _ = sjson.SetRawBytes(systemInstruction, "parts.-1", part)
hasSystemParts = true
}
}
return true
})
if hasSystemParts {
out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction)
out, _ = sjson.SetRawBytes(out, "system_instruction", systemInstruction)
}
} else if systemResult.Type == gjson.String {
out, _ = sjson.Set(out, "system_instruction.parts.-1.text", systemResult.String())
out, _ = sjson.SetBytes(out, "system_instruction.parts.-1.text", systemResult.String())
}
// contents
@@ -71,17 +71,17 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
role = "model"
}
contentJSON := `{"role":"","parts":[]}`
contentJSON, _ = sjson.Set(contentJSON, "role", role)
contentJSON := []byte(`{"role":"","parts":[]}`)
contentJSON, _ = sjson.SetBytes(contentJSON, "role", role)
contentsResult := messageResult.Get("content")
if contentsResult.IsArray() {
contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
switch contentResult.Get("type").String() {
case "text":
part := `{"text":""}`
part, _ = sjson.Set(part, "text", contentResult.Get("text").String())
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", contentResult.Get("text").String())
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
case "tool_use":
functionName := contentResult.Get("name").String()
@@ -93,11 +93,11 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
functionArgs := contentResult.Get("input").String()
argsResult := gjson.Parse(functionArgs)
if argsResult.IsObject() && gjson.Valid(functionArgs) {
part := `{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`
part, _ = sjson.Set(part, "thoughtSignature", geminiClaudeThoughtSignature)
part, _ = sjson.Set(part, "functionCall.name", functionName)
part, _ = sjson.SetRaw(part, "functionCall.args", functionArgs)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
part := []byte(`{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`)
part, _ = sjson.SetBytes(part, "thoughtSignature", geminiClaudeThoughtSignature)
part, _ = sjson.SetBytes(part, "functionCall.name", functionName)
part, _ = sjson.SetRawBytes(part, "functionCall.args", []byte(functionArgs))
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
}
case "tool_result":
@@ -110,19 +110,34 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
funcName = toolCallID
}
responseData := contentResult.Get("content").Raw
part := `{"functionResponse":{"name":"","response":{"result":""}}}`
part, _ = sjson.Set(part, "functionResponse.name", funcName)
part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
part := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
part, _ = sjson.SetBytes(part, "functionResponse.name", funcName)
part, _ = sjson.SetBytes(part, "functionResponse.response.result", responseData)
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
case "image":
source := contentResult.Get("source")
if source.Get("type").String() != "base64" {
return true
}
mimeType := source.Get("media_type").String()
data := source.Get("data").String()
if mimeType == "" || data == "" {
return true
}
part := []byte(`{"inline_data":{"mime_type":"","data":""}}`)
part, _ = sjson.SetBytes(part, "inline_data.mime_type", mimeType)
part, _ = sjson.SetBytes(part, "inline_data.data", data)
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
}
return true
})
out, _ = sjson.SetRaw(out, "contents.-1", contentJSON)
out, _ = sjson.SetRawBytes(out, "contents.-1", contentJSON)
} else if contentsResult.Type == gjson.String {
part := `{"text":""}`
part, _ = sjson.Set(part, "text", contentsResult.String())
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
out, _ = sjson.SetRaw(out, "contents.-1", contentJSON)
part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", contentsResult.String())
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
out, _ = sjson.SetRawBytes(out, "contents.-1", contentJSON)
}
return true
})
@@ -135,25 +150,33 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
tool, _ = sjson.Delete(tool, "strict")
tool, _ = sjson.Delete(tool, "input_examples")
tool, _ = sjson.Delete(tool, "type")
tool, _ = sjson.Delete(tool, "cache_control")
tool, _ = sjson.Delete(tool, "defer_loading")
if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
tool := []byte(toolResult.Raw)
var err error
tool, err = sjson.DeleteBytes(tool, "input_schema")
if err != nil {
return true
}
tool, err = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
if err != nil {
return true
}
tool, _ = sjson.DeleteBytes(tool, "strict")
tool, _ = sjson.DeleteBytes(tool, "input_examples")
tool, _ = sjson.DeleteBytes(tool, "type")
tool, _ = sjson.DeleteBytes(tool, "cache_control")
tool, _ = sjson.DeleteBytes(tool, "defer_loading")
if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() {
if !hasTools {
out, _ = sjson.SetRaw(out, "tools", `[{"functionDeclarations":[]}]`)
out, _ = sjson.SetRawBytes(out, "tools", []byte(`[{"functionDeclarations":[]}]`))
hasTools = true
}
out, _ = sjson.SetRaw(out, "tools.0.functionDeclarations.-1", tool)
out, _ = sjson.SetRawBytes(out, "tools.0.functionDeclarations.-1", tool)
}
}
return true
})
if !hasTools {
out, _ = sjson.Delete(out, "tools")
out, _ = sjson.DeleteBytes(out, "tools")
}
}
@@ -171,15 +194,15 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
switch toolChoiceType {
case "auto":
out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "AUTO")
out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.mode", "AUTO")
case "none":
out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "NONE")
out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.mode", "NONE")
case "any":
out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "ANY")
out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.mode", "ANY")
case "tool":
out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "ANY")
out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.mode", "ANY")
if toolChoiceName != "" {
out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
}
}
}
@@ -191,8 +214,8 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
case "enabled":
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := int(b.Int())
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.includeThoughts", true)
}
case "adaptive", "auto":
// For adaptive thinking:
@@ -204,32 +227,32 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
effort = strings.ToLower(strings.TrimSpace(v.String()))
}
if effort != "" {
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", effort)
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingLevel", effort)
} else {
maxBudget := 0
if mi := registry.LookupModelInfo(modelName, "gemini"); mi != nil && mi.Thinking != nil {
maxBudget = mi.Thinking.Max
}
if maxBudget > 0 {
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", maxBudget)
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", maxBudget)
} else {
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", "high")
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingLevel", "high")
}
}
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.includeThoughts", true)
}
}
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "generationConfig.temperature", v.Num)
out, _ = sjson.SetBytes(out, "generationConfig.temperature", v.Num)
}
if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "generationConfig.topP", v.Num)
out, _ = sjson.SetBytes(out, "generationConfig.topP", v.Num)
}
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "generationConfig.topK", v.Num)
out, _ = sjson.SetBytes(out, "generationConfig.topK", v.Num)
}
result := []byte(out)
result := out
result = common.AttachDefaultSafetySettings(result, "safetySettings")
return result

View File

@@ -40,3 +40,41 @@ func TestConvertClaudeRequestToGemini_ToolChoice_SpecificTool(t *testing.T) {
t.Fatalf("Expected allowedFunctionNames ['json'], got %s", gjson.GetBytes(output, "toolConfig.functionCallingConfig.allowedFunctionNames").Raw)
}
}
func TestConvertClaudeRequestToGemini_ImageContent(t *testing.T) {
inputJSON := []byte(`{
"model": "gemini-3-flash-preview",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "describe this image"},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": "aGVsbG8="
}
}
]
}
]
}`)
output := ConvertClaudeRequestToGemini("gemini-3-flash-preview", inputJSON, false)
parts := gjson.GetBytes(output, "contents.0.parts").Array()
if len(parts) != 2 {
t.Fatalf("Expected 2 parts, got %d", len(parts))
}
if got := parts[0].Get("text").String(); got != "describe this image" {
t.Fatalf("Expected first part text 'describe this image', got '%s'", got)
}
if got := parts[1].Get("inline_data.mime_type").String(); got != "image/png" {
t.Fatalf("Expected image mime type 'image/png', got '%s'", got)
}
if got := parts[1].Get("inline_data.data").String(); got != "aGVsbG8=" {
t.Fatalf("Expected image data 'aGVsbG8=', got '%s'", got)
}
}

View File

@@ -13,6 +13,7 @@ import (
"strings"
"sync/atomic"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -47,8 +48,8 @@ var toolUseIDCounter uint64
// - param: A pointer to a parameter object for the conversion.
//
// Returns:
// - []string: A slice of strings, each containing a Claude-compatible JSON response.
func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of bytes, each containing a Claude-compatible SSE payload.
func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &Params{
IsGlAPIKey: false,
@@ -63,32 +64,31 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
if bytes.Equal(rawJSON, []byte("[DONE]")) {
// Only send message_stop if we have actually output content
if (*param).(*Params).HasContent {
return []string{
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n\n",
}
return [][]byte{translatorcommon.AppendSSEEventString(nil, "message_stop", `{"type":"message_stop"}`, 3)}
}
return []string{}
return [][]byte{}
}
output := ""
output := make([]byte, 0, 1024)
appendEvent := func(event, payload string) {
output = translatorcommon.AppendSSEEventString(output, event, payload, 3)
}
// Initialize the streaming session with a message_start event
// This is only sent for the very first response chunk
if !(*param).(*Params).HasFirstResponse {
output = "event: message_start\n"
// Create the initial message structure with default values
// This follows the Claude API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
messageStartTemplate := []byte(`{"type":"message_start","message":{"id":"msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY","type":"message","role":"assistant","content":[],"model":"claude-3-5-sonnet-20241022","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`)
// Override default values with actual response metadata if available
if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.model", modelVersionResult.String())
}
if responseIDResult := gjson.GetBytes(rawJSON, "responseId"); responseIDResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.id", responseIDResult.String())
}
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
appendEvent("message_start", string(messageStartTemplate))
(*param).(*Params).HasFirstResponse = true
}
@@ -111,9 +111,8 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
if partResult.Get("thought").Bool() {
// Continue existing thinking block
if (*param).(*Params).ResponseType == 2 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex)), "delta.thinking", partTextResult.String())
appendEvent("content_block_delta", string(data))
(*param).(*Params).HasContent = true
} else {
// Transition from another state to thinking
@@ -124,19 +123,14 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
(*param).(*Params).ResponseIndex++
}
// Start a new thinking content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex))
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex)), "delta.thinking", partTextResult.String())
appendEvent("content_block_delta", string(data))
(*param).(*Params).ResponseType = 2 // Set state to thinking
(*param).(*Params).HasContent = true
}
@@ -144,9 +138,8 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
// Process regular text content (user-visible output)
// Continue existing text block
if (*param).(*Params).ResponseType == 1 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex)), "delta.text", partTextResult.String())
appendEvent("content_block_delta", string(data))
(*param).(*Params).HasContent = true
} else {
// Transition from another state to text content
@@ -157,19 +150,14 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
(*param).(*Params).ResponseIndex++
}
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex))
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex)), "delta.text", partTextResult.String())
appendEvent("content_block_delta", string(data))
(*param).(*Params).ResponseType = 1 // Set state to content
(*param).(*Params).HasContent = true
}
@@ -185,9 +173,8 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
// If we are already in tool use mode and name is empty, treat as continuation (delta).
if (*param).(*Params).ResponseType == 3 && upstreamToolName == "" {
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex)), "delta.partial_json", fcArgsResult.Raw)
appendEvent("content_block_delta", string(data))
}
// Continue to next part without closing/opening logic
continue
@@ -196,9 +183,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
// Handle state transitions when switching to function calls
// Close any existing function call block first
if (*param).(*Params).ResponseType == 3 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
(*param).(*Params).ResponseIndex++
(*param).(*Params).ResponseType = 0
}
@@ -212,26 +197,21 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
// Close any other existing content block
if (*param).(*Params).ResponseType != 0 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
(*param).(*Params).ResponseIndex++
}
// Start a new tool use content block
// This creates the structure for a function call in Claude format
output = output + "event: content_block_start\n"
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", clientToolName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data := []byte(fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex))
data, _ = sjson.SetBytes(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.SetBytes(data, "content_block.name", clientToolName)
appendEvent("content_block_start", string(data))
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n"
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ = sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex)), "delta.partial_json", fcArgsResult.Raw)
appendEvent("content_block_delta", string(data))
}
(*param).(*Params).ResponseType = 3
(*param).(*Params).HasContent = true
@@ -244,30 +224,25 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
// Only send final events if we have actually output content
if (*param).(*Params).HasContent {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
output = output + "event: message_delta\n"
output = output + `data: `
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
template := []byte(`{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
if (*param).(*Params).SawToolCall {
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
} else if finish := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finish.Exists() && finish.String() == "MAX_TOKENS" {
template = `{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
}
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
template, _ = sjson.SetBytes(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
template, _ = sjson.SetBytes(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
output = output + template + "\n\n\n"
appendEvent("message_delta", string(template))
}
}
}
return []string{output}
return [][]byte{output}
}
// ConvertGeminiResponseToClaudeNonStream converts a non-streaming Gemini response to a non-streaming Claude response.
@@ -279,21 +254,21 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
// - param: A pointer to a parameter object for the conversion.
//
// Returns:
// - string: A Claude-compatible JSON response.
func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// - []byte: A Claude-compatible JSON response.
func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
_ = requestRawJSON
root := gjson.ParseBytes(rawJSON)
toolNameMap := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("responseId").String())
out, _ = sjson.Set(out, "model", root.Get("modelVersion").String())
out := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
out, _ = sjson.SetBytes(out, "id", root.Get("responseId").String())
out, _ = sjson.SetBytes(out, "model", root.Get("modelVersion").String())
inputTokens := root.Get("usageMetadata.promptTokenCount").Int()
outputTokens := root.Get("usageMetadata.candidatesTokenCount").Int() + root.Get("usageMetadata.thoughtsTokenCount").Int()
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens)
out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens)
parts := root.Get("candidates.0.content.parts")
textBuilder := strings.Builder{}
@@ -305,9 +280,9 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
if textBuilder.Len() == 0 {
return
}
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", textBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
block := []byte(`{"type":"text","text":""}`)
block, _ = sjson.SetBytes(block, "text", textBuilder.String())
out, _ = sjson.SetRawBytes(out, "content.-1", block)
textBuilder.Reset()
}
@@ -315,9 +290,9 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
if thinkingBuilder.Len() == 0 {
return
}
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
block := []byte(`{"type":"thinking","thinking":""}`)
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRawBytes(out, "content.-1", block)
thinkingBuilder.Reset()
}
@@ -342,15 +317,15 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
upstreamToolName := functionCall.Get("name").String()
clientToolName := util.MapToolName(toolNameMap, upstreamToolName)
toolIDCounter++
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter)))
toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName)
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolBlock, _ = sjson.SetBytes(toolBlock, "id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter)))
toolBlock, _ = sjson.SetBytes(toolBlock, "name", clientToolName)
inputRaw := "{}"
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
inputRaw = args.Raw
}
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(inputRaw))
out, _ = sjson.SetRawBytes(out, "content.-1", toolBlock)
continue
}
}
@@ -374,15 +349,15 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
}
}
}
out, _ = sjson.Set(out, "stop_reason", stopReason)
out, _ = sjson.SetBytes(out, "stop_reason", stopReason)
if inputTokens == int64(0) && outputTokens == int64(0) && !root.Get("usageMetadata").Exists() {
out, _ = sjson.Delete(out, "usage")
out, _ = sjson.DeleteBytes(out, "usage")
}
return out
}
func ClaudeTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"input_tokens":%d}`, count)
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.ClaudeInputTokensJSON(count)
}

View File

@@ -7,8 +7,8 @@ package geminiCLI
import (
"bytes"
"context"
"fmt"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/sjson"
)
@@ -26,19 +26,18 @@ var dataTag = []byte("data:")
// - param: A pointer to a parameter object for the conversion (unused).
//
// Returns:
// - []string: A slice of strings, each containing a Gemini CLI-compatible JSON response.
func ConvertGeminiResponseToGeminiCLI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string {
// - [][]byte: A slice of Gemini CLI-compatible JSON responses.
func ConvertGeminiResponseToGeminiCLI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) [][]byte {
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
return [][]byte{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
if bytes.Equal(rawJSON, []byte("[DONE]")) {
return []string{}
return [][]byte{}
}
json := `{"response": {}}`
rawJSON, _ = sjson.SetRawBytes([]byte(json), "response", rawJSON)
return []string{string(rawJSON)}
rawJSON, _ = sjson.SetRawBytes([]byte(`{"response":{}}`), "response", rawJSON)
return [][]byte{rawJSON}
}
// ConvertGeminiResponseToGeminiCLINonStream converts a non-streaming Gemini response to a non-streaming Gemini CLI response.
@@ -50,13 +49,12 @@ func ConvertGeminiResponseToGeminiCLI(_ context.Context, _ string, originalReque
// - param: A pointer to a parameter object for the conversion (unused).
//
// Returns:
// - string: A Gemini CLI-compatible JSON response.
func ConvertGeminiResponseToGeminiCLINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
json := `{"response": {}}`
rawJSON, _ = sjson.SetRawBytes([]byte(json), "response", rawJSON)
return string(rawJSON)
// - []byte: A Gemini CLI-compatible JSON response.
func ConvertGeminiResponseToGeminiCLINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
rawJSON, _ = sjson.SetRawBytes([]byte(`{"response":{}}`), "response", rawJSON)
return rawJSON
}
func GeminiCLITokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
func GeminiCLITokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.GeminiTokenCountJSON(count)
}

View File

@@ -5,9 +5,11 @@ package gemini
import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -95,6 +97,71 @@ func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte
out = []byte(strJson)
}
// Backfill empty functionResponse.name from the preceding functionCall.name.
// Amp may send function responses with empty names; the Gemini API rejects these.
out = backfillEmptyFunctionResponseNames(out)
out = common.AttachDefaultSafetySettings(out, "safetySettings")
return out
}
// backfillEmptyFunctionResponseNames walks the contents array and for each
// model turn containing functionCall parts, records the call names in order.
// For the immediately following user/function turn containing functionResponse
// parts, any empty name is replaced with the corresponding call name.
func backfillEmptyFunctionResponseNames(data []byte) []byte {
contents := gjson.GetBytes(data, "contents")
if !contents.Exists() {
return data
}
out := data
var pendingCallNames []string
contents.ForEach(func(contentIdx, content gjson.Result) bool {
role := content.Get("role").String()
// Collect functionCall names from model turns
if role == "model" {
var names []string
content.Get("parts").ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
names = append(names, part.Get("functionCall.name").String())
}
return true
})
if len(names) > 0 {
pendingCallNames = names
} else {
pendingCallNames = nil
}
return true
}
// Backfill empty functionResponse names from pending call names
if len(pendingCallNames) > 0 {
ri := 0
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
if part.Get("functionResponse").Exists() {
name := part.Get("functionResponse.name").String()
if strings.TrimSpace(name) == "" {
if ri < len(pendingCallNames) {
out, _ = sjson.SetBytes(out,
fmt.Sprintf("contents.%d.parts.%d.functionResponse.name", contentIdx.Int(), partIdx.Int()),
pendingCallNames[ri])
} else {
log.Debugf("more function responses than calls at contents[%d], skipping name backfill", contentIdx.Int())
}
}
ri++
}
return true
})
pendingCallNames = nil
}
return true
})
return out
}

Some files were not shown because too many files have changed in this diff Show More