Compare commits

...

136 Commits

Author SHA1 Message Date
Luis Pater
7fa527193c Merge pull request #453 from HeCHieh/fix/github-copilot-gpt54-responses
Fix GitHub Copilot gpt-5.4 endpoint routing
2026-03-25 09:45:23 +08:00
Luis Pater
ed0eb51b4d Merge pull request #450 from lwiles692/feature/add-codebuddy-support
feat(auth): add CodeBuddy-CN browser OAuth authentication support
2026-03-25 09:43:52 +08:00
Luis Pater
0e4f669c8b Merge branch 'router-for-me:main' into main 2026-03-25 09:38:34 +08:00
Luis Pater
76c064c729 Merge pull request #2335 from router-for-me/auth
Support batch upload and delete for auth files
2026-03-25 09:34:44 +08:00
Luis Pater
d2f652f436 Merge pull request #2333 from router-for-me/codex
feat(codex): pass through codex client identity headers
2026-03-25 09:34:09 +08:00
Luis Pater
6a452a54d5 Merge pull request #2316 from router-for-me/openai
Add per-model thinking support for OpenAI compatibility
2026-03-25 09:31:28 +08:00
hkfires
9e5693e74f feat(api): support batch auth file upload and delete 2026-03-25 09:20:17 +08:00
hkfires
528b1a2307 feat(codex): pass through codex client identity headers 2026-03-25 08:48:18 +08:00
Luis Pater
0cc978ec1d Merge pull request #2297 from router-for-me/readme
docs(readme): update japanese documentation links
2026-03-25 03:11:24 +08:00
hkfires
fee736933b feat(openai-compat): add per-model thinking support 2026-03-24 14:21:12 +08:00
hkfires
5c99846ecf docs(readme): update japanese documentation links 2026-03-24 09:47:01 +08:00
Luis Pater
d475aaba96 Fixed: #2274
fix(translator): omit null content fields in Codex OpenAI tool call responses
2026-03-24 01:00:57 +08:00
Luis Pater
1dc4ecb1b8 Merge pull request #456 from router-for-me/plus
v6.9.1
2026-03-24 00:43:35 +08:00
Luis Pater
1315f710f5 Merge branch 'main' into plus 2026-03-24 00:43:26 +08:00
Luis Pater
96f55570f7 Merge pull request #2282 from eltociear/add-ja-doc
docs: add Japanese README
2026-03-24 00:40:58 +08:00
Luis Pater
0906aeca87 Merge pull request #2254 from clcc2019/main
refactor: streamline usage reporting by consolidating record publishi…
2026-03-24 00:39:31 +08:00
Luis Pater
97c0487add Merge pull request #2223 from cnrpman/fix/codex-responses-web-search-preview-compat
fix: normalize web_search_preview for codex responses
2026-03-24 00:25:37 +08:00
Luis Pater
a576088d5f Merge pull request #2222 from kaitranntt/kai/fix/758-openai-proxy-alternating-model-support
fix: fall back on model support errors during auth rotation
2026-03-24 00:03:28 +08:00
Luis Pater
66ff916838 Merge pull request #2220 from xulongwu4/main
fix: normalize model name in TranslateRequest fallback to prevent prefix leak
2026-03-23 23:56:15 +08:00
Luis Pater
7b0453074e Merge pull request #2219 from beck-8/fix/context-done-race
fix: avoid data race when watching request cancellation
2026-03-23 22:57:21 +08:00
Luis Pater
a000eb523d Merge pull request #2213 from TTTPOB/ua-fix
feat(claude): stabilize device fingerprint across mixed Claude Code and cloaked clients
2026-03-23 22:53:51 +08:00
Luis Pater
18a4fedc7f Merge pull request #2126 from ailuntz/fix/watcher-auth-cache-memory
perf(watcher): reduce auth cache memory
2026-03-23 22:47:34 +08:00
Luis Pater
5d6cdccda0 Merge pull request #2268 from sususu98/fix/sanitize-tool-names
fix(translator): sanitize tool names for Gemini function_declarations compatibility
2026-03-23 21:42:22 +08:00
Luis Pater
1b7f4ac3e1 Merge pull request #2252 from sususu98/fix/antigravity-empty-thought-text
fix(antigravity): always include text field in thought parts to prevent Google 500
2026-03-23 21:41:25 +08:00
Luis Pater
afc1a5b814 Fixed: #2281
refactor(claude): centralize usage token calculation logic and add tests for cached token handling
2026-03-23 21:30:03 +08:00
Ikko Eltociear Ashimine
7ed38db54f docs: update README_JA.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-23 16:57:43 +09:00
Ikko Eltociear Ashimine
28c10f4e69 docs: update README_JA.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-23 16:57:32 +09:00
Ikko Eltociear Ashimine
6e12441a3b Update README_JA.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-23 16:57:19 +09:00
Ikko Ashimine
65c439c18d docs: add Japanese README 2026-03-23 15:23:18 +09:00
dslife2025
0ed2d16596 Merge branch 'router-for-me:main' into main 2026-03-23 09:50:43 +08:00
Supra4E8C
db335ac616 Merge pull request #2269 from router-for-me/auth-fix
fix(auth): ensure absolute paths for auth file handling
2026-03-22 22:53:44 +08:00
Luis Pater
f3c59165d7 Merge branch 'pr-454'
# Conflicts:
#	cmd/server/main.go
#	internal/translator/claude/openai/chat-completions/claude_openai_response.go
2026-03-22 22:52:46 +08:00
hechieh
e6690cb447 Refine GitHub Copilot endpoint selection
Amp-Thread-ID: https://ampcode.com/threads/T-019d14cd-bc90-70ce-b1ae-87bc97332650
Co-authored-by: Amp <amp@ampcode.com>
2026-03-22 19:43:35 +08:00
hechieh
35907416b8 Fix GitHub Copilot gpt-5.4 endpoint routing
Amp-Thread-ID: https://ampcode.com/threads/T-019d14cd-bc90-70ce-b1ae-87bc97332650
Co-authored-by: Amp <amp@ampcode.com>
2026-03-22 19:05:44 +08:00
sususu98
e8bb350467 fix: extend tool name sanitization to all remaining Gemini-bound translators
Apply SanitizeFunctionName on request and RestoreSanitizedToolName on
response for: gemini/claude, gemini/openai/chat-completions,
gemini/openai/responses, antigravity/openai/chat-completions,
gemini-cli/openai/chat-completions.

Also update SanitizedToolNameMap to handle OpenAI format
(tools[].function.name) in addition to Claude format (tools[].name).
2026-03-22 14:06:46 +08:00
Supra4E8C
5331d51f27 fix(auth): ensure absolute paths for auth file handling 2026-03-22 13:58:16 +08:00
sususu98
755ca75879 fix: address review feedback - init ToolNameMap eagerly, log collisions, add collision test 2026-03-22 13:24:03 +08:00
sususu98
2398ebad55 fix(translator): sanitize tool names for Gemini function_declarations compatibility
Claude Code and MCP clients may send tool names containing characters
invalid for Gemini's function_declarations (e.g. '/', '@', spaces).
Sanitize on request via SanitizeFunctionName and restore original names
on response for both antigravity/claude and gemini-cli/claude translators.
2026-03-22 13:10:53 +08:00
clcc2019
c1bf298216 refactor: streamline usage reporting by consolidating record publishing logic
- Introduced a new method `buildRecord` in `usageReporter` to encapsulate record creation, improving code readability and maintainability.
- Added latency tracking to usage records, ensuring accurate reporting of request latencies.
- Updated tests to validate the inclusion of latency in usage records and ensure proper functionality of the new reporting structure.
2026-03-20 19:44:26 +08:00
sususu
e005208d76 fix(antigravity): always include text field in thought parts to prevent Google 500
When Claude sends redacted thinking with empty text, the translator
was omitting the "text" field from thought parts. Google Antigravity
API requires this field, causing 500 "Unknown Error" responses.

Verified: 129/129 error logs with empty thought → 500, 0/97 success
logs had empty thought. After fix: 0 new "Unknown Error" 500s.
2026-03-20 18:59:25 +08:00
Junyi Du
d1df70d02f chore: add codex builtin tool normalization logging 2026-03-20 14:08:37 +08:00
Luis Pater
f81acd0760 Merge pull request #2243 from router-for-me/oauth
Improve OAuth callback handling with async prompts
2026-03-20 12:35:44 +08:00
hkfires
636da4c932 refactor(auth): replace manual input handling with AsyncPrompt for callback URLs 2026-03-20 12:24:27 +08:00
hkfires
cccb77b552 fix(auth): avoid blocking oauth callback wait on prompt 2026-03-20 11:48:30 +08:00
Luis Pater
2bd646ad70 refactor: replace sjson.Set usage with sjson.SetBytes to optimize mutable JSON transformations 2026-03-19 17:58:54 +08:00
tpob
52c1fa025e fix(claude): learn official fingerprints after custom baselines 2026-03-19 13:59:41 +08:00
tpob
680105f84d fix(claude): refresh cached fingerprint after baseline upgrades 2026-03-19 13:28:58 +08:00
tpob
f7069e9548 fix(claude): pin stabilized OS arch to baseline 2026-03-19 13:07:16 +08:00
lwiles692
7275e99b41 Update internal/auth/codebuddy/codebuddy_auth.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-19 09:46:59 +08:00
lwiles692
c28b65f849 Update internal/auth/codebuddy/codebuddy_auth.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-19 09:46:40 +08:00
Junyi Du
793840cdb4 fix: cover dated and nested codex web search aliases 2026-03-19 03:41:12 +08:00
Junyi Du
8f421de532 fix: handle sjson errors in codex tool normalization 2026-03-19 03:36:06 +08:00
Junyi Du
be2dd60ee7 fix: normalize web_search_preview for codex responses 2026-03-19 03:23:14 +08:00
Tam Nhu Tran
ea3e0b713e fix: harden pooled model-support fallback state 2026-03-18 13:19:20 -04:00
tpob
8179d5a8a4 fix(claude): avoid racy fingerprint downgrades 2026-03-19 01:03:41 +08:00
tpob
6fa7abe434 fix(claude): keep configured baseline above older fingerprints 2026-03-19 01:02:04 +08:00
Tam Nhu Tran
5135c22cd6 fix: fall back on model support errors during auth rotation 2026-03-18 12:43:45 -04:00
Longwu Ou
1e27990561 address PR review: log sjson error and add unit tests
- Log a warning instead of silently ignoring sjson.SetBytes errors in the TranslateRequest fallback path
  - Add registry_test.go with tests covering the fallback model normalization and verifying registered transforms take precedence
2026-03-18 12:43:40 -04:00
Longwu Ou
e1e9fc43c1 fix: normalize model name in TranslateRequest fallback to prevent prefix leak
When no request translator is registered for a format pair (e.g.
        openai-response → openai-response), TranslateRequest returned the raw
        payload unchanged. This caused client-side model prefixes (e.g.
        "copilot/gpt-5-mini") to leak into upstream requests, resulting in
        "The requested model is not supported" errors from providers.

        The fallback path now updates the "model" field in the payload to
        match the resolved model name before returning.
2026-03-18 12:30:22 -04:00
beck-8
b2921518ac fix: avoid data race when watching request cancellation 2026-03-19 00:15:52 +08:00
tpob
dd64adbeeb fix(claude): preserve legacy user agent overrides 2026-03-19 00:03:09 +08:00
tpob
616d41c06a fix(claude): restore legacy runtime OS arch fallback 2026-03-19 00:01:50 +08:00
tpob
e0e337aeb9 feat(claude): add switch for device profile stabilization 2026-03-18 19:31:59 +08:00
tpob
d52839fced fix: stabilize claude device fingerprint 2026-03-18 18:46:54 +08:00
Wei Lee
4022e69651 feat(auth): add CodeBuddy-CN browser OAuth authentication support 2026-03-18 17:50:12 +08:00
Luis Pater
56073ded69 Merge pull request #2200 from sususu98/feat/local-model-flag
feat: add -local-model flag to skip remote model catalog fetching
2026-03-18 10:58:07 +08:00
sususu98
9738a53f49 feat: add -local-model flag to skip remote model catalog fetching
When enabled, the server uses only the embedded models.json loaded at
init() and skips registry.StartModelsUpdater(), disabling the initial
remote fetch and 3-hour periodic refresh. The management panel
auto-updater (managementasset.StartAutoUpdater) is unaffected.
2026-03-18 10:48:03 +08:00
Luis Pater
be3f8dbf7e Merge pull request #2187 from Darley-Wey/fix/claude-disable-parallel-tool-calls
fix(claude): honor disable_parallel_tool_use
2026-03-17 21:06:08 +08:00
Darley
9c6c3612a8 fix(claude): read disable_parallel_tool_use from tool_choice 2026-03-17 19:35:41 +08:00
Darley
19e1a4447a fix(claude): honor disable_parallel_tool_use 2026-03-17 19:17:41 +08:00
Luis Pater
7c2ad4cda2 Merge branch 'router-for-me:main' into main 2026-03-17 00:09:43 +08:00
Luis Pater
fb95813fbf Merge pull request #2142 from Muran-prog/fix/strip-uniqueItems-gemini-2123
fix: strip uniqueItems from Gemini function_declarations (#2123)
2026-03-16 20:34:28 +08:00
Luis Pater
db63f9b5d6 Merge pull request #2162 from enieuwy/fix/responses-api-json-valid-check
fix: validate JSON before raw-embedding function call outputs in Responses API
2026-03-16 18:42:31 +08:00
Luis Pater
25f6c4a250 Merge pull request #2158 from sususu98/fix/antigravity-functionresponse-name
fix(antigravity): resolve empty functionResponse.name for toolu_* tool_use_id format
2026-03-16 18:39:40 +08:00
enieuwy
b24ae74216 fix: validate JSON before raw-embedding function call outputs in Responses API
gjson.Parse() marks any string starting with { or [ as gjson.JSON type,
even when the content is not valid JSON (e.g. macOS plist format, truncated
tool results). This caused sjson.SetRaw to embed non-JSON content directly
into the Gemini API request payload, producing 400 errors.

Add json.Valid() check before using SetRaw to ensure only actually valid
JSON is embedded raw. Non-JSON content now falls through to sjson.Set
which properly escapes it as a JSON string.

Fixes #2161
2026-03-16 15:29:18 +08:00
Luis Pater
59ad8f40dc Merge pull request #2124 from RGBadmin/feat/auth-list-priority-note
feat(api): expose priority and note in GET /auth-files response
2026-03-16 12:31:11 +08:00
sususu98
ff03dc6a2c fix(antigravity): resolve empty functionResponse.name for toolu_* tool_use_id format
The Claude-to-Gemini translator derived function names by splitting
tool_use_id on "-", which produced empty strings for IDs with exactly
2 segments (e.g. toolu_tool-<uuid>). Replace the string-splitting
heuristic with a lookup map built from tool_use blocks during the
main processing loop, with fallback to the raw ID on miss.
2026-03-16 11:18:29 +08:00
Luis Pater
dc7187ca5b fix(websocket): pin only websocket-capable auth IDs and add corresponding test 2026-03-16 09:57:38 +08:00
Luis Pater
b1dcff778c Merge pull request #2141 from Muran-prog/fix/tool-calling-translation-2132
fix: skip empty assistant message in tool call translation (#2132)
2026-03-16 01:42:27 +08:00
Luis Pater
cef2aeeb08 Merge pull request #448 from router-for-me/plus
v6.8.54
2026-03-16 00:37:06 +08:00
Luis Pater
bcd1e8cc34 Merge branch 'main' into plus 2026-03-16 00:36:19 +08:00
Luis Pater
198b3f4a40 chore(ci): update build metadata to use GITHUB_REF_NAME in workflows 2026-03-16 00:30:44 +08:00
Luis Pater
9fee7f488e chore(ci): update GoReleaser config and release workflow to skip validation step 2026-03-16 00:16:25 +08:00
Luis Pater
1b46d39b8b Merge branch 'router-for-me:main' into main 2026-03-15 23:57:47 +08:00
RGBadmin
c1241a98e2 fix(api): restrict fallback note to string-typed JSON values
Only emit note in listAuthFilesFromDisk when the JSON value is actually
a string (gjson.String), matching the synthesizer/buildAuthFileEntry
behavior. Non-string values like numbers or booleans are now ignored
instead of being coerced via gjson.String().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 23:00:17 +08:00
RGBadmin
8d8f5970ee fix(api): fallback to Metadata for priority/note on uploaded auths
buildAuthFileEntry now falls back to reading priority/note from
auth.Metadata when Attributes lacks them. This covers auths registered
via UploadAuthFile which bypass the synthesizer and only populate
Metadata from the raw JSON.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 17:36:11 +08:00
RGBadmin
f90120f846 fix(api): propagate note to Gemini virtual auths and align priority parsing
- Read note from Attributes (consistent with priority) in buildAuthFileEntry,
  fixing missing note on Gemini multi-project virtual auth cards.
- Propagate note from primary to virtual auths in SynthesizeGeminiVirtualAuths,
  mirroring existing priority propagation.
- Sync note/priority writes to both Metadata and Attributes in PatchAuthFileFields,
  with refactored nil-check to reduce duplication (review feedback).
- Validate priority type in fallback disk-read path instead of coercing all values
  to 0 via gjson.Int(), aligning with the auth-manager code path.
- Add regression tests for note synthesis, virtual-auth note propagation, and
  end-to-end multi-project Gemini note inheritance.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 16:47:01 +08:00
Muran-prog
0b94d36c4a test: use exact match for tool name assertion
Address review feedback - drop function.name fallback and
strings.Contains in favor of direct == comparison.
2026-03-14 21:45:28 +02:00
Muran-prog
152c310bb7 test: add uniqueItems stripping test
Covers the fix from the previous commit — verifies uniqueItems is
removed from the schema and moved to the description hint.
2026-03-14 21:22:14 +02:00
Muran-prog
f6bbca35ab fix: strip uniqueItems from Gemini function_declarations (#2123)
Gemini API rejects uniqueItems in tool schemas with 400. Add it to
unsupportedConstraints alongside minItems/maxItems where it belongs.

Same class of fix as #1424 and #1531.
2026-03-14 21:18:06 +02:00
Muran-prog
c8cee6a209 fix: skip empty assistant message in tool call translation (#2132)
When assistant has tool_calls but no text content, the translator
emitted an empty message into the Responses API input array before
function_call items. The API then couldn't match function_call_output
to its function_call by call_id, returning:

  No tool output found for function call ...

Only emit assistant messages that have content parts. Tool-call-only
messages now produce function_call items directly.

Added 9 tests for tool calling translation covering single/parallel
calls, multi-turn conversations, name shortening, empty content
edge cases, and call_id integrity.
2026-03-14 21:01:01 +02:00
Luis Pater
b5701f416b Fixed: #2102
fix(auth): ensure unique auth index for shared API keys across providers and credential identities
2026-03-15 02:48:54 +08:00
Luis Pater
4b1a404fcb Fixed: #1936
feat(translator): add image type handling in ConvertClaudeRequestToGemini
2026-03-15 02:18:28 +08:00
Luis Pater
b93cce5412 Merge pull request #444 from router-for-me/plus
v6.8.53
2026-03-15 01:50:42 +08:00
Luis Pater
c6cb24039d Merge branch 'main' into plus 2026-03-15 01:50:32 +08:00
Luis Pater
5382408489 Merge pull request #441 from GrothKeiran/fix/copilot-token-metadata
fix: persist copilot token metadata
2026-03-15 01:47:41 +08:00
Luis Pater
67669196ed Merge pull request #2131 from HEUDavid/docs/add-who-is-with-us
docs: Add Shadow AI to 'Who is with us?' section
2026-03-15 01:44:46 +08:00
hkfires
58fd9bf964 fix(codex): add 'go' plan_type in registerModelsForAuth 2026-03-14 22:09:14 +08:00
HEUDavid
7b3dfc67bc docs: Add Shadow AI to 'Who is with us?' section 2026-03-14 21:01:07 +08:00
HEUDavid
cdd24052d3 docs: Add Shadow AI to 'Who is with us?' section 2026-03-14 20:53:43 +08:00
Luis Pater
733fd8edab Merge pull request #2111 from qzydustin/main
Fix missing streaming usage tracking for OpenAI-compatible providers
2026-03-14 18:17:08 +08:00
Luis Pater
af27f2b8bc Merge pull request #2110 from router-for-me/codex
feat(service): extend model registration for team and business types
2026-03-14 18:10:01 +08:00
Luis Pater
2e1925d762 Merge pull request #2108 from sususu98/fix/gemini-cli-tool-schema-and-empty-parts
fix(gemini-cli): sanitize tool schemas and filter empty parts
2026-03-14 18:02:52 +08:00
Luis Pater
77254bd074 Merge pull request #2116 from router-for-me/vertex
fix(config): allow vertex keys without base-url
2026-03-14 17:27:48 +08:00
RGBadmin
5b6342e6ac feat(api): expose priority and note fields in GET /auth-files list response
The list endpoint previously omitted priority and note, which are stored
inside each auth file's JSON content. This adds them to both the normal
(auth-manager) and fallback (disk-read) code paths, and extends
PATCH /auth-files/fields to support writing the note field.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-14 14:47:31 +08:00
GrothKeiran
3960c93d51 refactor: derive copilot metadata from storage 2026-03-13 13:10:01 +00:00
GrothKeiran
339a81b650 fix: persist copilot token metadata 2026-03-13 12:54:43 +00:00
hkfires
560c020477 fix(config): allow vertex keys without base-url 2026-03-13 19:09:26 +08:00
Zhenyu Qi
aec65e3be3 fix(openai_compat): add stream_options.include_usage for streaming usage tracking 2026-03-13 00:48:17 -07:00
hkfires
f44f0702f8 feat(service): extend model registration for team and business types 2026-03-13 14:12:19 +08:00
sususu98
b76b79068f fix(gemini-cli): sanitize tool schemas and filter empty parts
1. Claude translator: add CleanJSONSchemaForGemini() to sanitize tool
   input schemas (removes $schema, anyOf, const, format, etc.) and
   delete eager_input_streaming from tool declarations. Remove fragile
   bytes.Replace for format:"uri" now covered by schema cleaner.

2. Gemini native translator: filter out content entries with empty or
   missing parts arrays to prevent Gemini API 400 error "required
   oneof field 'data' must have one initialized field".

Both fixes align gemini-cli with protections already present in the
antigravity translator.
2026-03-13 12:37:37 +08:00
Luis Pater
34c8ccb961 Fixed: #437
feat(runtime): strip `service_tier` in GitHub Copilot response normalization
2026-03-13 11:50:21 +08:00
Luis Pater
d08e164af3 chore(runtime): remove unused FetchAntigravityModels function from executor 2026-03-13 11:38:44 +08:00
Luis Pater
8178efaeda Merge pull request #439 from router-for-me/plus
v6.8.52
2026-03-13 11:30:25 +08:00
Luis Pater
86d5db472a Merge branch 'main' into plus 2026-03-13 11:28:52 +08:00
Luis Pater
020d36f6e8 Merge pull request #433 from LuxVTZ/feat/gitlab-duo-auth-plus
Add GitLab Duo management OAuth and PAT endpoints
2026-03-13 11:25:19 +08:00
Luis Pater
1db23979e8 Merge pull request #2106 from router-for-me/model
feat(model_registry): enhance model registration and refresh mechanisms
2026-03-13 11:18:51 +08:00
hkfires
c3d5dbe96f feat(model_registry): enhance model registration and refresh mechanisms 2026-03-13 10:56:39 +08:00
Luis Pater
5484489406 chore(ci): update model catalog fetch method in workflows 2026-03-12 11:19:24 +08:00
Luis Pater
0ac52da460 chore(ci): update model catalog fetch method in release workflow 2026-03-12 10:50:46 +08:00
Luis Pater
817cebb321 Merge pull request #2082 from router-for-me/antigravity
Refactor Antigravity model handling and improve logging
2026-03-12 10:39:13 +08:00
Luis Pater
683f3709d6 Merge pull request #2076 from aikins01/fix/backfill-empty-function-response-names
fix: backfill empty functionResponse.name from preceding functionCall
2026-03-12 10:35:44 +08:00
hkfires
dbd42a42b2 fix(model_updater): clarify log message for model refresh failure 2026-03-12 10:32:04 +08:00
hkfires
ec24baf757 feat(fetch_antigravity_models): add command to fetch and save Antigravity model list 2026-03-12 10:21:09 +08:00
hkfires
dea3e74d35 feat(antigravity): refactor model handling and remove unused code 2026-03-12 09:24:45 +08:00
Aikins Laryea
a6c3042e34 refactor: remove redundant bounds checks per code review 2026-03-12 00:12:43 +00:00
Aikins Laryea
861537c9bd fix: backfill empty functionResponse.name from preceding functionCall
when Amp or Claude Code sends functionResponse with an empty name in Gemini
conversation history, the Gemini API rejects the request with 400
"Name cannot be empty". this fix backfills empty names from the
corresponding preceding functionCall parts using positional matching.

covers all three Gemini translator paths:
- gemini/gemini (direct API key)
- antigravity/gemini (OAuth)
- gemini-cli/gemini (Gemini CLI)

also switches fixCLIToolResponse pending group matching from LIFO to
FIFO to correctly handle multiple sequential tool call groups.

fixes #1903
2026-03-12 00:00:38 +00:00
Luis Pater
8c92cb0883 Merge pull request #2056 from lang-911/codex/custom-useragent-request
feat(config/codex): Add Codex header defaults (`user-agent`: override; `beta-features`: default)
2026-03-11 22:56:36 +08:00
Luis Pater
89d7be9525 Merge branch 'dev' into codex/custom-useragent-request 2026-03-11 22:55:50 +08:00
lang-911
2b79d7f22f fix: restore double quotes style in config.example.yaml for consistency and readability 2026-03-11 06:59:26 -07:00
LuxVTZ
2bb686f594 Add GitLab Duo management OAuth and PAT endpoints 2026-03-11 17:58:34 +04:00
lang-911
163fe287ce fix: codex header defaults example 2026-03-11 06:55:03 -07:00
lang-911
70988d387b Add Codex websocket header defaults 2026-03-11 00:34:57 -07:00
Luis Pater
ddaa9d2436 Fixed: #2034
feat(proxy): centralize proxy handling with `proxyutil` package and enhance test coverage

- Added `proxyutil` package to simplify proxy handling across the codebase.
- Refactored various components (`executor`, `cliproxy`, `auth`, etc.) to use `proxyutil` for consistent and reusable proxy logic.
- Introduced support for "direct" proxy mode to explicitly bypass all proxies.
- Updated tests to validate proxy behavior (e.g., `direct`, HTTP/HTTPS, and SOCKS5).
- Enhanced YAML configuration documentation for proxy options.
2026-03-11 11:08:02 +08:00
Luis Pater
7b7b258c38 Fixed: #2022
test(translator): add tests for handling Claude system messages as string and array
2026-03-11 10:47:33 +08:00
ailuntz
c3762328a5 perf(watcher): reduce auth cache memory 2026-03-10 16:27:10 +08:00
164 changed files with 12369 additions and 4541 deletions

View File

@@ -17,7 +17,9 @@ jobs:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Refresh models catalog - name: Refresh models catalog
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
- name: Login to DockerHub - name: Login to DockerHub
@@ -27,7 +29,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Generate Build Metadata - name: Generate Build Metadata
run: | run: |
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- name: Build and push (amd64) - name: Build and push (amd64)
@@ -50,7 +52,9 @@ jobs:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Refresh models catalog - name: Refresh models catalog
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
- name: Login to DockerHub - name: Login to DockerHub
@@ -60,7 +64,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Generate Build Metadata - name: Generate Build Metadata
run: | run: |
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- name: Build and push (arm64) - name: Build and push (arm64)
@@ -94,7 +98,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Generate Build Metadata - name: Generate Build Metadata
run: | run: |
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- name: Create and push multi-arch manifests - name: Create and push multi-arch manifests

View File

@@ -13,7 +13,9 @@ jobs:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Refresh models catalog - name: Refresh models catalog
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:

View File

@@ -17,7 +17,9 @@ jobs:
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Refresh models catalog - name: Refresh models catalog
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- run: git fetch --force --tags - run: git fetch --force --tags
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
@@ -25,15 +27,14 @@ jobs:
cache: true cache: true
- name: Generate Build Metadata - name: Generate Build Metadata
run: | run: |
VERSION=$(git describe --tags --always --dirty) echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
echo "VERSION=${VERSION}" >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- uses: goreleaser/goreleaser-action@v4 - uses: goreleaser/goreleaser-action@v4
with: with:
distribution: goreleaser distribution: goreleaser
version: latest version: latest
args: release --clean args: release --clean --skip=validate
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
VERSION: ${{ env.VERSION }} VERSION: ${{ env.VERSION }}

View File

@@ -1,3 +1,5 @@
version: 2
builds: builds:
- id: "cli-proxy-api-plus" - id: "cli-proxy-api-plus"
env: env:

View File

@@ -1,6 +1,6 @@
# CLIProxyAPI Plus # CLIProxyAPI Plus
[English](README.md) | 中文 [English](README.md) | 中文 | [日本語](README_JA.md)
这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。 这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。

183
README_JA.md Normal file
View File

@@ -0,0 +1,183 @@
# CLI Proxy API
[English](README.md) | [中文](README_CN.md) | 日本語
CLI向けのOpenAI/Gemini/Claude/Codex互換APIインターフェースを提供するプロキシサーバーです。
OAuth経由でOpenAI CodexGPTモデルおよびClaude Codeもサポートしています。
ローカルまたはマルチアカウントのCLIアクセスを、OpenAIResponses含む/Gemini/Claude互換のクライアントやSDKで利用できます。
## スポンサー
[![z.ai](https://assets.router-for.me/english-5-0.jpg)](https://z.ai/subscribe?ic=8JVLJQFSKB)
本プロジェクトはZ.aiにスポンサーされており、GLM CODING PLANの提供を受けています。
GLM CODING PLANはAIコーディング向けに設計されたサブスクリプションサービスで、月額わずか$10から利用可能です。フラッグシップのGLM-4.7およびGLM-5はProユーザーのみ利用可能モデルを10以上の人気AIコーディングツールClaude Code、Cline、Roo Codeなどで利用でき、開発者にトップクラスの高速かつ安定したコーディング体験を提供します。
GLM CODING PLANを10%割引で取得https://z.ai/subscribe?ic=8JVLJQFSKB
---
<table>
<tbody>
<tr>
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
<td>PackyCodeのスポンサーシップに感謝しますPackyCodeは信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどのリレーサービスを提供しています。PackyCodeは当ソフトウェアのユーザーに特別割引を提供しています<a href="https://www.packyapi.com/register?aff=cliproxyapi">こちらのリンク</a>から登録し、チャージ時にプロモーションコード「cliproxyapi」を入力すると10%割引になります。</td>
</tr>
<tr>
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>AICodeMirrorのスポンサーシップに感謝しますAICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引がありますCLIProxyAPIユーザー向けの特別特典<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">こちらのリンク</a>から登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます!</td>
</tr>
</tbody>
</table>
## 概要
- CLIモデル向けのOpenAI/Gemini/Claude互換APIエンドポイント
- OAuthログインによるOpenAI CodexサポートGPTモデル
- OAuthログインによるClaude Codeサポート
- OAuthログインによるQwen Codeサポート
- OAuthログインによるiFlowサポート
- プロバイダールーティングによるAmp CLIおよびIDE拡張機能のサポート
- ストリーミングおよび非ストリーミングレスポンス
- 関数呼び出し/ツールのサポート
- マルチモーダル入力サポート(テキストと画像)
- ラウンドロビン負荷分散による複数アカウント対応Gemini、OpenAI、Claude、QwenおよびiFlow
- シンプルなCLI認証フローGemini、OpenAI、Claude、QwenおよびiFlow
- Generative Language APIキーのサポート
- AI Studioビルドのマルチアカウント負荷分散
- Gemini CLIのマルチアカウント負荷分散
- Claude Codeのマルチアカウント負荷分散
- Qwen Codeのマルチアカウント負荷分散
- iFlowのマルチアカウント負荷分散
- OpenAI Codexのマルチアカウント負荷分散
- 設定によるOpenAI互換アップストリームプロバイダーOpenRouter
- プロキシ埋め込み用の再利用可能なGo SDK`docs/sdk-usage.md`を参照)
## はじめに
CLIProxyAPIガイド[https://help.router-for.me/](https://help.router-for.me/)
## 管理API
[MANAGEMENT_API.md](https://help.router-for.me/management/api)を参照
## Amp CLIサポート
CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統合サポートを含んでおり、Google/ChatGPT/ClaudeのOAuthサブスクリプションをAmpのコーディングツールで使用できます
- Ampの APIパターン用のプロバイダールートエイリアス`/api/provider/{provider}/v1...`
- OAuth認証およびアカウント機能用の管理プロキシ
- 自動ルーティングによるスマートモデルフォールバック
- 利用できないモデルを代替モデルにルーティングする**モデルマッピング**(例:`claude-opus-4.5``claude-sonnet-4`
- localhostのみの管理エンドポイントによるセキュリティファーストの設計
**→ [Amp CLI統合ガイドの完全版](https://help.router-for.me/agent-client/amp-cli.html)**
## SDKドキュメント
- 使い方:[docs/sdk-usage.md](docs/sdk-usage.md)
- 上級(エグゼキューターとトランスレーター):[docs/sdk-advanced.md](docs/sdk-advanced.md)
- アクセス:[docs/sdk-access.md](docs/sdk-access.md)
- ウォッチャー:[docs/sdk-watcher.md](docs/sdk-watcher.md)
- カスタムプロバイダーの例:`examples/custom-provider`
## コントリビューション
コントリビューションを歓迎しますお気軽にPull Requestを送ってください。
1. リポジトリをフォーク
2. フィーチャーブランチを作成(`git checkout -b feature/amazing-feature`
3. 変更をコミット(`git commit -m 'Add some amazing feature'`
4. ブランチにプッシュ(`git push origin feature/amazing-feature`
5. Pull Requestを作成
## 関連プロジェクト
CLIProxyAPIをベースにした以下のプロジェクトがあります
### [vibeproxy](https://github.com/automazeio/vibeproxy)
macOSネイティブのメニューバーアプリで、Claude CodeとChatGPTのサブスクリプションをAIコーディングツールで使用可能 - APIキー不要
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
CLIProxyAPI経由でGeminiサブスクリプションを使用してSRT字幕を翻訳するブラウザベースのツール。自動検証/エラー修正機能付き - APIキー不要
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
CLIProxyAPI OAuthを使用して複数のClaudeアカウントや代替モデルGemini、Codex、Antigravityを即座に切り替えるCLIラッパー - APIキー不要
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
CLIProxyAPI管理用のmacOSネイティブGUIOAuth経由でプロバイダー、モデルマッピング、エンドポイントを設定 - APIキー不要
### [Quotio](https://github.com/nguyenphutrong/quotio)
Claude、Gemini、OpenAI、Qwen、Antigravityのサブスクリプションを統合し、リアルタイムのクォータ追跡とスマート自動フェイルオーバーを備えたmacOSネイティブのメニューバーアプリ。Claude Code、OpenCode、Droidなどのコーディングツール向け - APIキー不要
### [CodMate](https://github.com/loocor/CodMate)
CLI AIセッションCodex、Claude Code、Gemini CLIを管理するmacOS SwiftUIネイティブアプリ。統合プロバイダー管理、Gitレビュー、プロジェクト整理、グローバル検索、ターミナル統合機能を搭載。CLIProxyAPIと統合し、Codex、Claude、Gemini、Antigravity、Qwen CodeのOAuth認証を提供。単一のプロキシエンドポイントを通じた組み込みおよびサードパーティプロバイダーの再ルーティングに対応 - OAuthプロバイダーではAPIキー不要
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
TUI、システムトレイ、マルチプロバイダーOAuthを備えたWindows向けCLIProxyAPIフォーク - AIコーディングツール用、APIキー不要
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
Claude Codeモデルを素早く切り替えるVSCode拡張機能。バックエンドとしてCLIProxyAPIを統合し、バックグラウンドでの自動ライフサイクル管理を搭載
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
CLIProxyAPIを使用してAIコーディングアシスタントのクォータを監視するTauri + React製のWindowsデスクトップアプリ。Gemini、Claude、OpenAI Codex、Antigravityアカウントの使用量をリアルタイムダッシュボード、システムトレイ統合、ワンクリックプロキシコントロールで追跡 - APIキー不要
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
CLIProxyAPI向けの軽量Web管理パネル。ヘルスチェック、リソース監視、リアルタイムログ、自動更新、リクエスト統計、料金表示機能を搭載。ワンクリックインストールとsystemdサービスに対応
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
PowerShellスクリプトで実装されたWindowsトレイアプリケーション。サードパーティライブラリに依存せず、ショートカットの自動作成、サイレント実行、パスワード管理、チャネル切り替えMain / Plus、自動ダウンロードおよび自動更新に対応
### [霖君](https://github.com/wangdabaoqq/LinJun)
霖君はAIプログラミングアシスタントを管理するクロスプラットフォームデスクトップアプリケーションで、macOS、Windows、Linuxシステムに対応。Claude Code、Gemini CLI、OpenAI Codex、Qwen Codeなどのコーディングツールを統合管理し、ローカルプロキシによるマルチアカウントクォータ追跡とワンクリック設定が可能
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
Next.js、React、PostgreSQLで構築されたCLIProxyAPI用のモダンなWebベース管理ダッシュボード。リアルタイムログストリーミング、構造化された設定編集、APIキー管理、Claude/Gemini/Codex向けOAuthプロバイダー統合、使用量分析、コンテナ管理、コンパニオンプラグインによるOpenCodeとの設定同期機能を搭載 - 手動でのYAML編集は不要
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
New API互換リレーサイトアカウントをワンストップで管理するブラウザ拡張機能。残高と使用量のダッシュボード、自動チェックイン、一般的なアプリへのワンクリックキーエクスポート、ページ内API可用性テスト、チャネル/モデルの同期とリダイレクト機能を搭載。Management APIを通じてCLIProxyAPIと統合し、ワンクリックでプロバイダーのインポートと設定同期が可能
### [Shadow AI](https://github.com/HEUDavid/shadow-ai)
Shadow AIは制限された環境向けに特別に設計されたAIアシスタントツールです。ウィンドウや痕跡のないステルス動作モードを提供し、LANローカルエリアネットワークを介したクロスデバイスAI質疑応答のインタラクションと制御を可能にします。本質的には「画面/音声キャプチャ + AI推論 + 低摩擦デリバリー」の自動化コラボレーションレイヤーであり、制御されたデバイスや制限された環境でアプリケーション横断的にAIアシスタントを没入的に使用できるようユーザーを支援します。
> [!NOTE]
> CLIProxyAPIをベースにプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
## その他の選択肢
以下のプロジェクトはCLIProxyAPIの移植版またはそれに触発されたものです
### [9Router](https://github.com/decolua/9router)
CLIProxyAPIに触発されたNext.js実装。インストールと使用が簡単で、フォーマット変換OpenAI/Claude/Gemini/Ollama、自動フォールバック付きコンボシステム、指数バックオフ付きマルチアカウント管理、Next.js Webダッシュボード、CLIツールCursor、Claude Code、Cline、RooCodeのサポートをゼロから構築 - APIキー不要
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
コーディングを止めない。無料および低コストのAIモデルへのスマートルーティングと自動フォールバック。
OmniRouteはマルチプロバイダーLLM向けのAIゲートウェイですスマートルーティング、負荷分散、リトライ、フォールバックを備えたOpenAI互換エンドポイント。ポリシー、レート制限、キャッシュ、可観測性を追加して、信頼性が高くコストを意識した推論を実現します。
> [!NOTE]
> CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
## ライセンス
本プロジェクトはMITライセンスの下でライセンスされています - 詳細は[LICENSE](LICENSE)ファイルを参照してください。

View File

@@ -0,0 +1,275 @@
// Command fetch_antigravity_models connects to the Antigravity API using the
// stored auth credentials and saves the dynamically fetched model list to a
// JSON file for inspection or offline use.
//
// Usage:
//
// go run ./cmd/fetch_antigravity_models [flags]
//
// Flags:
//
// --auths-dir <path> Directory containing auth JSON files (default: "auths")
// --output <path> Output JSON file path (default: "antigravity_models.json")
// --pretty Pretty-print the output JSON (default: true)
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
const (
antigravityBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com"
antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
antigravityModelsPath = "/v1internal:fetchAvailableModels"
)
func init() {
logging.SetupBaseLogger()
log.SetLevel(log.InfoLevel)
}
// modelOutput wraps the fetched model list with fetch metadata.
type modelOutput struct {
Models []modelEntry `json:"models"`
}
// modelEntry contains only the fields we want to keep for static model definitions.
type modelEntry struct {
ID string `json:"id"`
Object string `json:"object"`
OwnedBy string `json:"owned_by"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
Name string `json:"name"`
Description string `json:"description"`
ContextLength int `json:"context_length,omitempty"`
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
}
func main() {
var authsDir string
var outputPath string
var pretty bool
flag.StringVar(&authsDir, "auths-dir", "auths", "Directory containing auth JSON files")
flag.StringVar(&outputPath, "output", "antigravity_models.json", "Output JSON file path")
flag.BoolVar(&pretty, "pretty", true, "Pretty-print the output JSON")
flag.Parse()
// Resolve relative paths against the working directory.
wd, err := os.Getwd()
if err != nil {
fmt.Fprintf(os.Stderr, "error: cannot get working directory: %v\n", err)
os.Exit(1)
}
if !filepath.IsAbs(authsDir) {
authsDir = filepath.Join(wd, authsDir)
}
if !filepath.IsAbs(outputPath) {
outputPath = filepath.Join(wd, outputPath)
}
fmt.Printf("Scanning auth files in: %s\n", authsDir)
// Load all auth records from the directory.
fileStore := sdkauth.NewFileTokenStore()
fileStore.SetBaseDir(authsDir)
ctx := context.Background()
auths, err := fileStore.List(ctx)
if err != nil {
fmt.Fprintf(os.Stderr, "error: failed to list auth files: %v\n", err)
os.Exit(1)
}
if len(auths) == 0 {
fmt.Fprintf(os.Stderr, "error: no auth files found in %s\n", authsDir)
os.Exit(1)
}
// Find the first enabled antigravity auth.
var chosen *coreauth.Auth
for _, a := range auths {
if a == nil || a.Disabled {
continue
}
if strings.EqualFold(strings.TrimSpace(a.Provider), "antigravity") {
chosen = a
break
}
}
if chosen == nil {
fmt.Fprintf(os.Stderr, "error: no enabled antigravity auth found in %s\n", authsDir)
os.Exit(1)
}
fmt.Printf("Using auth: id=%s label=%s\n", chosen.ID, chosen.Label)
// Fetch models from the upstream Antigravity API.
fmt.Println("Fetching Antigravity model list from upstream...")
fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
models := fetchModels(fetchCtx, chosen)
if len(models) == 0 {
fmt.Fprintln(os.Stderr, "warning: no models returned (API may be unavailable or token expired)")
} else {
fmt.Printf("Fetched %d models.\n", len(models))
}
// Build the output payload.
out := modelOutput{
Models: models,
}
// Marshal to JSON.
var raw []byte
if pretty {
raw, err = json.MarshalIndent(out, "", " ")
} else {
raw, err = json.Marshal(out)
}
if err != nil {
fmt.Fprintf(os.Stderr, "error: failed to marshal JSON: %v\n", err)
os.Exit(1)
}
if err = os.WriteFile(outputPath, raw, 0o644); err != nil {
fmt.Fprintf(os.Stderr, "error: failed to write output file %s: %v\n", outputPath, err)
os.Exit(1)
}
fmt.Printf("Model list saved to: %s\n", outputPath)
}
func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry {
accessToken := metaStringValue(auth.Metadata, "access_token")
if accessToken == "" {
fmt.Fprintln(os.Stderr, "error: no access token found in auth")
return nil
}
baseURLs := []string{antigravityBaseURLProd, antigravityBaseURLDaily, antigravitySandboxBaseURLDaily}
for _, baseURL := range baseURLs {
modelsURL := baseURL + antigravityModelsPath
var payload []byte
if auth != nil && auth.Metadata != nil {
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
}
}
if len(payload) == 0 {
payload = []byte(`{}`)
}
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, strings.NewReader(string(payload)))
if errReq != nil {
continue
}
httpReq.Close = true
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
httpReq.Header.Set("User-Agent", "antigravity/1.19.6 darwin/arm64")
httpClient := &http.Client{Timeout: 30 * time.Second}
if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil {
httpClient.Transport = transport
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
continue
}
bodyBytes, errRead := io.ReadAll(httpResp.Body)
httpResp.Body.Close()
if errRead != nil {
continue
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
continue
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
continue
}
var models []modelEntry
for originalName, modelData := range result.Map() {
modelID := strings.TrimSpace(originalName)
if modelID == "" {
continue
}
// Skip internal/experimental models
switch modelID {
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
continue
}
displayName := modelData.Get("displayName").String()
if displayName == "" {
displayName = modelID
}
entry := modelEntry{
ID: modelID,
Object: "model",
OwnedBy: "antigravity",
Type: "antigravity",
DisplayName: displayName,
Name: modelID,
Description: displayName,
}
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
entry.ContextLength = int(maxTok)
}
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
entry.MaxCompletionTokens = int(maxOut)
}
models = append(models, entry)
}
return models
}
return nil
}
func metaStringValue(m map[string]interface{}, key string) string {
if m == nil {
return ""
}
v, ok := m[key]
if !ok {
return ""
}
switch val := v.(type) {
case string:
return val
default:
return ""
}
}

View File

@@ -95,6 +95,7 @@ func main() {
var kiroIDCRegion string var kiroIDCRegion string
var kiroIDCFlow string var kiroIDCFlow string
var githubCopilotLogin bool var githubCopilotLogin bool
var codeBuddyLogin bool
var projectID string var projectID string
var vertexImport string var vertexImport string
var configPath string var configPath string
@@ -103,6 +104,7 @@ func main() {
var standalone bool var standalone bool
var noIncognito bool var noIncognito bool
var useIncognito bool var useIncognito bool
var localModel bool
// Define command-line flags for different operation modes. // Define command-line flags for different operation modes.
flag.BoolVar(&login, "login", false, "Login Google Account") flag.BoolVar(&login, "login", false, "Login Google Account")
@@ -131,12 +133,14 @@ func main() {
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)") flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device") flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow") flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
flag.BoolVar(&codeBuddyLogin, "codebuddy-login", false, "Login to CodeBuddy using browser OAuth flow")
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path") flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
flag.StringVar(&password, "password", "", "") flag.StringVar(&password, "password", "", "")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI") flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server") flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
flag.CommandLine.Usage = func() { flag.CommandLine.Usage = func() {
out := flag.CommandLine.Output() out := flag.CommandLine.Output()
@@ -514,6 +518,9 @@ func main() {
} else if githubCopilotLogin { } else if githubCopilotLogin {
// Handle GitHub Copilot login // Handle GitHub Copilot login
cmd.DoGitHubCopilotLogin(cfg, options) cmd.DoGitHubCopilotLogin(cfg, options)
} else if codeBuddyLogin {
// Handle CodeBuddy login
cmd.DoCodeBuddyLogin(cfg, options)
} else if codexLogin { } else if codexLogin {
// Handle Codex login // Handle Codex login
cmd.DoCodexLogin(cfg, options) cmd.DoCodexLogin(cfg, options)
@@ -578,11 +585,16 @@ func main() {
cmd.WaitForCloudDeploy() cmd.WaitForCloudDeploy()
return return
} }
if localModel && (!tuiMode || standalone) {
log.Info("Local model mode: using embedded model catalog, remote model updates disabled")
}
if tuiMode { if tuiMode {
if standalone { if standalone {
// Standalone mode: start an embedded local server and connect TUI client to it. // Standalone mode: start an embedded local server and connect TUI client to it.
managementasset.StartAutoUpdater(context.Background(), configFilePath) managementasset.StartAutoUpdater(context.Background(), configFilePath)
registry.StartModelsUpdater(context.Background()) if !localModel {
registry.StartModelsUpdater(context.Background())
}
hook := tui.NewLogHook(2000) hook := tui.NewLogHook(2000)
hook.SetFormatter(&logging.LogFormatter{}) hook.SetFormatter(&logging.LogFormatter{})
log.AddHook(hook) log.AddHook(hook)
@@ -655,7 +667,9 @@ func main() {
} else { } else {
// Start the main proxy service // Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath) managementasset.StartAutoUpdater(context.Background(), configFilePath)
registry.StartModelsUpdater(context.Background()) if !localModel {
registry.StartModelsUpdater(context.Background())
}
if cfg.AuthDir != "" { if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg) kiro.InitializeAndStart(cfg.AuthDir, cfg)

View File

@@ -68,7 +68,8 @@ error-logs-max-files: 10
usage-statistics-enabled: false usage-statistics-enabled: false
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/ # Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
proxy-url: '' # Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly.
proxy-url: ""
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name). # When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
force-model-prefix: false force-model-prefix: false
@@ -115,6 +116,7 @@ nonstream-keepalive-interval: 0
# headers: # headers:
# X-Custom-Header: "custom-value" # X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # proxy-url: "socks5://proxy.example.com:1080"
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models: # models:
# - name: "gemini-2.5-flash" # upstream model name # - name: "gemini-2.5-flash" # upstream model name
# alias: "gemini-flash" # client alias mapped to the upstream model # alias: "gemini-flash" # client alias mapped to the upstream model
@@ -133,6 +135,7 @@ nonstream-keepalive-interval: 0
# headers: # headers:
# X-Custom-Header: "custom-value" # X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models: # models:
# - name: "gpt-5-codex" # upstream model name # - name: "gpt-5-codex" # upstream model name
# alias: "codex-latest" # client alias mapped to the upstream model # alias: "codex-latest" # client alias mapped to the upstream model
@@ -151,6 +154,7 @@ nonstream-keepalive-interval: 0
# headers: # headers:
# X-Custom-Header: "custom-value" # X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models: # models:
# - name: "claude-3-5-sonnet-20241022" # upstream model name # - name: "claude-3-5-sonnet-20241022" # upstream model name
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model # alias: "claude-sonnet-latest" # client alias mapped to the upstream model
@@ -171,12 +175,27 @@ nonstream-keepalive-interval: 0
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request # cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
# Default headers for Claude API requests. Update when Claude Code releases new versions. # Default headers for Claude API requests. Update when Claude Code releases new versions.
# These are used as fallbacks when the client does not send its own headers. # In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks
# when the client omits them, while OS/arch remain runtime-derived. When
# stabilize-device-profile is enabled, OS/arch stay pinned to the baseline values below,
# while user-agent/package-version/runtime-version seed a software fingerprint that can
# still upgrade to newer official Claude client versions.
# claude-header-defaults: # claude-header-defaults:
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)" # user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
# package-version: "0.74.0" # package-version: "0.74.0"
# runtime-version: "v24.3.0" # runtime-version: "v24.3.0"
# os: "MacOS"
# arch: "arm64"
# timeout: "600" # timeout: "600"
# stabilize-device-profile: false # optional, default false; set true to enable per-auth/API-key fingerprint pinning
# Default headers for Codex OAuth model requests.
# These are used only for file-backed/OAuth Codex requests when the client
# does not send the header. `user-agent` applies to HTTP and websocket requests;
# `beta-features` only applies to websocket requests. They do not apply to codex-api-key entries.
# codex-header-defaults:
# user-agent: "codex_cli_rs/0.114.0 (Mac OS 14.2.0; x86_64) vscode/1.111.0"
# beta-features: "multi_agent"
# Kiro (AWS CodeWhisperer) configuration # Kiro (AWS CodeWhisperer) configuration
# Note: Kiro API currently only operates in us-east-1 region # Note: Kiro API currently only operates in us-east-1 region
@@ -215,10 +234,13 @@ nonstream-keepalive-interval: 0
# api-key-entries: # api-key-entries:
# - api-key: "sk-or-v1-...b780" # - api-key: "sk-or-v1-...b780"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# - api-key: "sk-or-v1-...b781" # without proxy-url # - api-key: "sk-or-v1-...b781" # without proxy-url
# models: # The models supported by the provider. # models: # The models supported by the provider.
# - name: "moonshotai/kimi-k2:free" # The actual model name. # - name: "moonshotai/kimi-k2:free" # The actual model name.
# alias: "kimi-k2" # The alias used in the API. # alias: "kimi-k2" # The alias used in the API.
# thinking: # optional: omit to default to levels ["low","medium","high"]
# levels: ["low", "medium", "high"]
# # You may repeat the same alias to build an internal model pool. # # You may repeat the same alias to build an internal model pool.
# # The client still sees only one alias in the model list. # # The client still sees only one alias in the model list.
# # Requests to that alias will round-robin across the upstream names below, # # Requests to that alias will round-robin across the upstream names below,
@@ -231,12 +253,13 @@ nonstream-keepalive-interval: 0
# - name: "kimi-k2.5" # - name: "kimi-k2.5"
# alias: "claude-opus-4.66" # alias: "claude-opus-4.66"
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL) # Vertex API keys (Vertex-compatible endpoints, base-url is optional)
# vertex-api-key: # vertex-api-key:
# - api-key: "vk-123..." # x-goog-api-key header # - api-key: "vk-123..." # x-goog-api-key header
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential # prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api # base-url: "https://example.com/api" # optional, e.g. https://zenmux.ai/api; falls back to Google Vertex when omitted
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override # proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# headers: # headers:
# X-Custom-Header: "custom-value" # X-Custom-Header: "custom-value"
# models: # optional: map aliases to upstream model names # models: # optional: map aliases to upstream model names

View File

@@ -52,11 +52,11 @@ func init() {
sdktr.Register(fOpenAI, fMyProv, sdktr.Register(fOpenAI, fMyProv,
func(model string, raw []byte, stream bool) []byte { return raw }, func(model string, raw []byte, stream bool) []byte { return raw },
sdktr.ResponseTransform{ sdktr.ResponseTransform{
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []string { Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) [][]byte {
return []string{string(raw)} return [][]byte{raw}
}, },
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) string { NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []byte {
return string(raw) return raw
}, },
}, },
) )

2
go.mod
View File

@@ -91,8 +91,8 @@ require (
github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/x448/float16 v0.8.4 // indirect github.com/x448/float16 v0.8.4 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/sys v0.38.0 // indirect golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect golang.org/x/text v0.31.0 // indirect

View File

@@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@@ -14,13 +13,12 @@ import (
"github.com/fxamacker/cbor/v2" "github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
) )
const defaultAPICallTimeout = 60 * time.Second const defaultAPICallTimeout = 60 * time.Second
@@ -725,47 +723,12 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
} }
func buildProxyTransport(proxyStr string) *http.Transport { func buildProxyTransport(proxyStr string) *http.Transport {
proxyStr = strings.TrimSpace(proxyStr) transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
if proxyStr == "" { if errBuild != nil {
log.WithError(errBuild).Debug("build proxy transport failed")
return nil return nil
} }
return transport
proxyURL, errParse := url.Parse(proxyStr)
if errParse != nil {
log.WithError(errParse).Debug("parse proxy URL failed")
return nil
}
if proxyURL.Scheme == "" || proxyURL.Host == "" {
log.Debug("proxy URL missing scheme/host")
return nil
}
if proxyURL.Scheme == "socks5" {
var proxyAuth *proxy.Auth
if proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
return nil
}
return &http.Transport{
Proxy: nil,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
}
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil
} }
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value). // headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).

View File

@@ -2,172 +2,112 @@ package management
import ( import (
"context" "context"
"encoding/json"
"io"
"net/http" "net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing" "testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
) )
type memoryAuthStore struct { func TestAPICallTransportDirectBypassesGlobalProxy(t *testing.T) {
mu sync.Mutex t.Parallel()
items map[string]*coreauth.Auth
}
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) { h := &Handler{
_ = ctx cfg: &config.Config{
s.mu.Lock() SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
defer s.mu.Unlock()
out := make([]*coreauth.Auth, 0, len(s.items))
for _, a := range s.items {
out = append(out, a.Clone())
}
return out, nil
}
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
_ = ctx
if auth == nil {
return "", nil
}
s.mu.Lock()
if s.items == nil {
s.items = make(map[string]*coreauth.Auth)
}
s.items[auth.ID] = auth.Clone()
s.mu.Unlock()
return auth.ID, nil
}
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
_ = ctx
s.mu.Lock()
delete(s.items, id)
s.mu.Unlock()
return nil
}
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
t.Fatalf("unexpected content-type: %s", ct)
}
bodyBytes, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
values, err := url.ParseQuery(string(bodyBytes))
if err != nil {
t.Fatalf("parse form: %v", err)
}
if values.Get("grant_type") != "refresh_token" {
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
}
if values.Get("refresh_token") != "rt" {
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
}
if values.Get("client_id") != antigravityOAuthClientID {
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
}
if values.Get("client_secret") != antigravityOAuthClientSecret {
t.Fatalf("unexpected client_secret")
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "new-token",
"refresh_token": "rt2",
"expires_in": int64(3600),
"token_type": "Bearer",
})
}))
t.Cleanup(srv.Close)
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
store := &memoryAuthStore{}
manager := coreauth.NewManager(store, nil, nil)
auth := &coreauth.Auth{
ID: "antigravity-test.json",
FileName: "antigravity-test.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "old-token",
"refresh_token": "rt",
"expires_in": int64(3600),
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
}, },
} }
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("register auth: %v", err) transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "direct"})
httpTransport, ok := transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", transport)
}
if httpTransport.Proxy != nil {
t.Fatal("expected direct transport to disable proxy function")
}
}
func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
t.Parallel()
h := &Handler{
cfg: &config.Config{
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
},
}
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "bad-value"})
httpTransport, ok := transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", transport)
}
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
if errRequest != nil {
t.Fatalf("http.NewRequest returned error: %v", errRequest)
}
proxyURL, errProxy := httpTransport.Proxy(req)
if errProxy != nil {
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
}
if proxyURL == nil || proxyURL.String() != "http://global-proxy.example.com:8080" {
t.Fatalf("proxy URL = %v, want http://global-proxy.example.com:8080", proxyURL)
}
}
func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) {
t.Parallel()
manager := coreauth.NewManager(nil, nil, nil)
geminiAuth := &coreauth.Auth{
ID: "gemini:apikey:123",
Provider: "gemini",
Attributes: map[string]string{
"api_key": "shared-key",
},
}
compatAuth := &coreauth.Auth{
ID: "openai-compatibility:bohe:456",
Provider: "bohe",
Label: "bohe",
Attributes: map[string]string{
"api_key": "shared-key",
"compat_name": "bohe",
"provider_key": "bohe",
},
}
if _, errRegister := manager.Register(context.Background(), geminiAuth); errRegister != nil {
t.Fatalf("register gemini auth: %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), compatAuth); errRegister != nil {
t.Fatalf("register compat auth: %v", errRegister)
}
geminiIndex := geminiAuth.EnsureIndex()
compatIndex := compatAuth.EnsureIndex()
if geminiIndex == compatIndex {
t.Fatalf("shared api key produced duplicate auth_index %q", geminiIndex)
} }
h := &Handler{authManager: manager} h := &Handler{authManager: manager}
token, err := h.resolveTokenForAuth(context.Background(), auth)
if err != nil { gotGemini := h.authByIndex(geminiIndex)
t.Fatalf("resolveTokenForAuth: %v", err) if gotGemini == nil {
t.Fatal("expected gemini auth by index")
} }
if token != "new-token" { if gotGemini.ID != geminiAuth.ID {
t.Fatalf("expected refreshed token, got %q", token) t.Fatalf("authByIndex(gemini) returned %q, want %q", gotGemini.ID, geminiAuth.ID)
}
if callCount != 1 {
t.Fatalf("expected 1 refresh call, got %d", callCount)
} }
updated, ok := manager.GetByID(auth.ID) gotCompat := h.authByIndex(compatIndex)
if !ok || updated == nil { if gotCompat == nil {
t.Fatalf("expected auth in manager after update") t.Fatal("expected compat auth by index")
} }
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" { if gotCompat.ID != compatAuth.ID {
t.Fatalf("expected manager metadata updated, got %q", got) t.Fatalf("authByIndex(compat) returned %q, want %q", gotCompat.ID, compatAuth.ID)
}
}
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusInternalServerError)
}))
t.Cleanup(srv.Close)
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
auth := &coreauth.Auth{
ID: "antigravity-valid.json",
FileName: "antigravity-valid.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "ok-token",
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
h := &Handler{}
token, err := h.resolveTokenForAuth(context.Background(), auth)
if err != nil {
t.Fatalf("resolveTokenForAuth: %v", err)
}
if token != "ok-token" {
t.Fatalf("expected existing token, got %q", token)
}
if callCount != 0 {
t.Fatalf("expected no refresh calls, got %d", callCount)
} }
} }

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,197 @@
package management
import (
"bytes"
"encoding/json"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestUploadAuthFile_BatchMultipart(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
manager := coreauth.NewManager(nil, nil, nil)
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
files := []struct {
name string
content string
}{
{name: "alpha.json", content: `{"type":"codex","email":"alpha@example.com"}`},
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
}
var body bytes.Buffer
writer := multipart.NewWriter(&body)
for _, file := range files {
part, err := writer.CreateFormFile("file", file.name)
if err != nil {
t.Fatalf("failed to create multipart file: %v", err)
}
if _, err = part.Write([]byte(file.content)); err != nil {
t.Fatalf("failed to write multipart content: %v", err)
}
}
if err := writer.Close(); err != nil {
t.Fatalf("failed to close multipart writer: %v", err)
}
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
req.Header.Set("Content-Type", writer.FormDataContentType())
ctx.Request = req
h.UploadAuthFile(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
var payload map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got, ok := payload["uploaded"].(float64); !ok || int(got) != len(files) {
t.Fatalf("expected uploaded=%d, got %#v", len(files), payload["uploaded"])
}
for _, file := range files {
fullPath := filepath.Join(authDir, file.name)
data, err := os.ReadFile(fullPath)
if err != nil {
t.Fatalf("expected uploaded file %s to exist: %v", file.name, err)
}
if string(data) != file.content {
t.Fatalf("expected file %s content %q, got %q", file.name, file.content, string(data))
}
}
auths := manager.List()
if len(auths) != len(files) {
t.Fatalf("expected %d auth entries, got %d", len(files), len(auths))
}
}
func TestUploadAuthFile_BatchMultipart_InvalidJSONDoesNotOverwriteExistingFile(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
manager := coreauth.NewManager(nil, nil, nil)
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
existingName := "alpha.json"
existingContent := `{"type":"codex","email":"alpha@example.com"}`
if err := os.WriteFile(filepath.Join(authDir, existingName), []byte(existingContent), 0o600); err != nil {
t.Fatalf("failed to seed existing auth file: %v", err)
}
files := []struct {
name string
content string
}{
{name: existingName, content: `{"type":"codex"`},
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
}
var body bytes.Buffer
writer := multipart.NewWriter(&body)
for _, file := range files {
part, err := writer.CreateFormFile("file", file.name)
if err != nil {
t.Fatalf("failed to create multipart file: %v", err)
}
if _, err = part.Write([]byte(file.content)); err != nil {
t.Fatalf("failed to write multipart content: %v", err)
}
}
if err := writer.Close(); err != nil {
t.Fatalf("failed to close multipart writer: %v", err)
}
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
req.Header.Set("Content-Type", writer.FormDataContentType())
ctx.Request = req
h.UploadAuthFile(ctx)
if rec.Code != http.StatusMultiStatus {
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusMultiStatus, rec.Code, rec.Body.String())
}
data, err := os.ReadFile(filepath.Join(authDir, existingName))
if err != nil {
t.Fatalf("expected existing auth file to remain readable: %v", err)
}
if string(data) != existingContent {
t.Fatalf("expected existing auth file to remain %q, got %q", existingContent, string(data))
}
betaData, err := os.ReadFile(filepath.Join(authDir, "beta.json"))
if err != nil {
t.Fatalf("expected valid auth file to be created: %v", err)
}
if string(betaData) != files[1].content {
t.Fatalf("expected beta auth file content %q, got %q", files[1].content, string(betaData))
}
}
func TestDeleteAuthFile_BatchQuery(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
files := []string{"alpha.json", "beta.json"}
for _, name := range files {
if err := os.WriteFile(filepath.Join(authDir, name), []byte(`{"type":"codex"}`), 0o600); err != nil {
t.Fatalf("failed to write auth file %s: %v", name, err)
}
}
manager := coreauth.NewManager(nil, nil, nil)
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
h.tokenStore = &memoryAuthStore{}
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(
http.MethodDelete,
"/v0/management/auth-files?name="+url.QueryEscape(files[0])+"&name="+url.QueryEscape(files[1]),
nil,
)
ctx.Request = req
h.DeleteAuthFile(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
var payload map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got, ok := payload["deleted"].(float64); !ok || int(got) != len(files) {
t.Fatalf("expected deleted=%d, got %#v", len(files), payload["deleted"])
}
for _, name := range files {
if _, err := os.Stat(filepath.Join(authDir, name)); !os.IsNotExist(err) {
t.Fatalf("expected auth file %s to be removed, stat err: %v", name, err)
}
}
}

View File

@@ -0,0 +1,164 @@
package management
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestRequestGitLabPATToken_SavesAuthRecord(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer glpat-test-token" {
t.Fatalf("authorization header = %q, want Bearer glpat-test-token", got)
}
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/api/v4/user":
_ = json.NewEncoder(w).Encode(map[string]any{
"id": 42,
"username": "gitlab-user",
"name": "GitLab User",
"email": "gitlab@example.com",
})
case "/api/v4/personal_access_tokens/self":
_ = json.NewEncoder(w).Encode(map[string]any{
"id": 7,
"name": "management-center",
"scopes": []string{"api", "read_user"},
"user_id": 42,
})
case "/api/v4/code_suggestions/direct_access":
_ = json.NewEncoder(w).Encode(map[string]any{
"base_url": "https://cloud.gitlab.example.com",
"token": "gateway-token",
"expires_at": 1893456000,
"headers": map[string]string{
"X-Gitlab-Realm": "saas",
},
"model_details": map[string]any{
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
})
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
store := &memoryAuthStore{}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, coreauth.NewManager(nil, nil, nil))
h.tokenStore = store
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/gitlab-auth-url", strings.NewReader(`{"base_url":"`+upstream.URL+`","personal_access_token":"glpat-test-token"}`))
ctx.Request.Header.Set("Content-Type", "application/json")
h.RequestGitLabPATToken(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
var resp map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("decode response: %v", err)
}
if got := resp["status"]; got != "ok" {
t.Fatalf("status = %#v, want ok", got)
}
if got := resp["model_provider"]; got != "anthropic" {
t.Fatalf("model_provider = %#v, want anthropic", got)
}
if got := resp["model_name"]; got != "claude-sonnet-4-5" {
t.Fatalf("model_name = %#v, want claude-sonnet-4-5", got)
}
store.mu.Lock()
defer store.mu.Unlock()
if len(store.items) != 1 {
t.Fatalf("expected 1 saved auth record, got %d", len(store.items))
}
var saved *coreauth.Auth
for _, item := range store.items {
saved = item
}
if saved == nil {
t.Fatal("expected saved auth record")
}
if saved.Provider != "gitlab" {
t.Fatalf("provider = %q, want gitlab", saved.Provider)
}
if got := saved.Metadata["auth_kind"]; got != "personal_access_token" {
t.Fatalf("auth_kind = %#v, want personal_access_token", got)
}
if got := saved.Metadata["model_provider"]; got != "anthropic" {
t.Fatalf("saved model_provider = %#v, want anthropic", got)
}
if got := saved.Metadata["duo_gateway_token"]; got != "gateway-token" {
t.Fatalf("saved duo_gateway_token = %#v, want gateway-token", got)
}
}
func TestPostOAuthCallback_GitLabWritesPendingCallbackFile(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
state := "gitlab-state-123"
RegisterOAuthSession(state, "gitlab")
t.Cleanup(func() { CompleteOAuthSession(state) })
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, coreauth.NewManager(nil, nil, nil))
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/oauth-callback", strings.NewReader(`{"provider":"gitlab","redirect_url":"http://localhost:17171/auth/callback?code=test-code&state=`+state+`"}`))
ctx.Request.Header.Set("Content-Type", "application/json")
h.PostOAuthCallback(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
filePath := filepath.Join(authDir, ".oauth-gitlab-"+state+".oauth")
data, err := os.ReadFile(filePath)
if err != nil {
t.Fatalf("read callback file: %v", err)
}
var payload map[string]string
if err := json.Unmarshal(data, &payload); err != nil {
t.Fatalf("decode callback payload: %v", err)
}
if got := payload["code"]; got != "test-code" {
t.Fatalf("callback code = %q, want test-code", got)
}
if got := payload["state"]; got != state {
t.Fatalf("callback state = %q, want %q", got, state)
}
}
func TestNormalizeOAuthProvider_GitLab(t *testing.T) {
provider, err := NormalizeOAuthProvider("gitlab")
if err != nil {
t.Fatalf("NormalizeOAuthProvider returned error: %v", err)
}
if provider != "gitlab" {
t.Fatalf("provider = %q, want gitlab", provider)
}
}

View File

@@ -509,8 +509,12 @@ func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
} }
for i := range arr { for i := range arr {
normalizeVertexCompatKey(&arr[i]) normalizeVertexCompatKey(&arr[i])
if arr[i].APIKey == "" {
c.JSON(400, gin.H{"error": fmt.Sprintf("vertex-api-key[%d].api-key is required", i)})
return
}
} }
h.cfg.VertexCompatAPIKey = arr h.cfg.VertexCompatAPIKey = append([]config.VertexCompatKey(nil), arr...)
h.cfg.SanitizeVertexCompatKeys() h.cfg.SanitizeVertexCompatKeys()
h.persist(c) h.persist(c)
} }

View File

@@ -228,6 +228,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
return "anthropic", nil return "anthropic", nil
case "codex", "openai": case "codex", "openai":
return "codex", nil return "codex", nil
case "gitlab":
return "gitlab", nil
case "gemini", "google": case "gemini", "google":
return "gemini", nil return "gemini", nil
case "iflow", "i-flow": case "iflow", "i-flow":

View File

@@ -0,0 +1,49 @@
package management
import (
"context"
"sync"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
type memoryAuthStore struct {
mu sync.Mutex
items map[string]*coreauth.Auth
}
func (s *memoryAuthStore) List(_ context.Context) ([]*coreauth.Auth, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*coreauth.Auth, 0, len(s.items))
for _, item := range s.items {
out = append(out, item)
}
return out, nil
}
func (s *memoryAuthStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) {
if auth == nil {
return "", nil
}
s.mu.Lock()
defer s.mu.Unlock()
if s.items == nil {
s.items = make(map[string]*coreauth.Auth)
}
s.items[auth.ID] = auth
return auth.ID, nil
}
func (s *memoryAuthStore) Delete(_ context.Context, id string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.items, id)
return nil
}
func (s *memoryAuthStore) SetBaseDir(string) {}

View File

@@ -403,6 +403,20 @@ func (s *Server) setupRoutes() {
c.String(http.StatusOK, oauthCallbackSuccessHTML) c.String(http.StatusOK, oauthCallbackSuccessHTML)
}) })
s.engine.GET("/gitlab/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
errStr := c.Query("error")
if errStr == "" {
errStr = c.Query("error_description")
}
if state != "" {
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gitlab", state, code, errStr)
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
s.engine.GET("/google/callback", func(c *gin.Context) { s.engine.GET("/google/callback", func(c *gin.Context) {
code := c.Query("code") code := c.Query("code")
state := c.Query("state") state := c.Query("state")
@@ -658,6 +672,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken) mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken) mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
mgmt.GET("/gitlab-auth-url", s.mgmt.RequestGitLabToken)
mgmt.POST("/gitlab-auth-url", s.mgmt.RequestGitLabPATToken)
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken) mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken) mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken) mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)

View File

@@ -4,12 +4,12 @@ package claude
import ( import (
"net/http" "net/http"
"net/url"
"strings" "strings"
"sync" "sync"
tls "github.com/refraction-networking/utls" tls "github.com/refraction-networking/utls"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config" "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
@@ -31,17 +31,12 @@ type utlsRoundTripper struct {
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support // newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper { func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
var dialer proxy.Dialer = proxy.Direct var dialer proxy.Dialer = proxy.Direct
if cfg != nil && cfg.ProxyURL != "" { if cfg != nil {
proxyURL, err := url.Parse(cfg.ProxyURL) proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
if err != nil { if errBuild != nil {
log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err) log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
} else { } else if mode != proxyutil.ModeInherit && proxyDialer != nil {
pDialer, err := proxy.FromURL(proxyURL, proxy.Direct) dialer = proxyDialer
if err != nil {
log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err)
} else {
dialer = pDialer
}
} }
} }

View File

@@ -0,0 +1,335 @@
package codebuddy
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
const (
BaseURL = "https://copilot.tencent.com"
DefaultDomain = "www.codebuddy.cn"
UserAgent = "CLI/2.63.2 CodeBuddy/2.63.2"
codeBuddyStatePath = "/v2/plugin/auth/state"
codeBuddyTokenPath = "/v2/plugin/auth/token"
codeBuddyRefreshPath = "/v2/plugin/auth/token/refresh"
pollInterval = 5 * time.Second
maxPollDuration = 5 * time.Minute
codeLoginPending = 11217
codeSuccess = 0
)
type CodeBuddyAuth struct {
httpClient *http.Client
cfg *config.Config
baseURL string
}
func NewCodeBuddyAuth(cfg *config.Config) *CodeBuddyAuth {
httpClient := &http.Client{Timeout: 30 * time.Second}
if cfg != nil {
httpClient = util.SetProxy(&cfg.SDKConfig, httpClient)
}
return &CodeBuddyAuth{httpClient: httpClient, cfg: cfg, baseURL: BaseURL}
}
// AuthState holds the state and auth URL returned by the auth state API.
type AuthState struct {
State string
AuthURL string
}
// FetchAuthState calls POST /v2/plugin/auth/state?platform=CLI to get the state and login URL.
func (a *CodeBuddyAuth) FetchAuthState(ctx context.Context) (*AuthState, error) {
stateURL := fmt.Sprintf("%s%s?platform=CLI", a.baseURL, codeBuddyStatePath)
body := []byte("{}")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, stateURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("codebuddy: failed to create auth state request: %w", err)
}
requestID := uuid.NewString()
req.Header.Set("Accept", "application/json, text/plain, */*")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Requested-With", "XMLHttpRequest")
req.Header.Set("X-Domain", "copilot.tencent.com")
req.Header.Set("X-No-Authorization", "true")
req.Header.Set("X-No-User-Id", "true")
req.Header.Set("X-No-Enterprise-Id", "true")
req.Header.Set("X-No-Department-Info", "true")
req.Header.Set("X-Product", "SaaS")
req.Header.Set("User-Agent", UserAgent)
req.Header.Set("X-Request-ID", requestID)
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("codebuddy: auth state request failed: %w", err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("codebuddy auth state: close body error: %v", errClose)
}
}()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("codebuddy: failed to read auth state response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("codebuddy: auth state request returned status %d: %s", resp.StatusCode, string(bodyBytes))
}
var result struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data *struct {
State string `json:"state"`
AuthURL string `json:"authUrl"`
} `json:"data"`
}
if err = json.Unmarshal(bodyBytes, &result); err != nil {
return nil, fmt.Errorf("codebuddy: failed to parse auth state response: %w", err)
}
if result.Code != codeSuccess {
return nil, fmt.Errorf("codebuddy: auth state request failed with code %d: %s", result.Code, result.Msg)
}
if result.Data == nil || result.Data.State == "" || result.Data.AuthURL == "" {
return nil, fmt.Errorf("codebuddy: auth state response missing state or authUrl")
}
return &AuthState{
State: result.Data.State,
AuthURL: result.Data.AuthURL,
}, nil
}
type pollResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
RequestID string `json:"requestId"`
Data *struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ExpiresIn int64 `json:"expiresIn"`
TokenType string `json:"tokenType"`
Domain string `json:"domain"`
} `json:"data"`
}
// doPollRequest performs a single polling request, safely reading and closing the response body
func (a *CodeBuddyAuth) doPollRequest(ctx context.Context, pollURL string) ([]byte, int, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pollURL, nil)
if err != nil {
return nil, 0, fmt.Errorf("%w: %v", ErrTokenFetchFailed, err)
}
a.applyPollHeaders(req)
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, 0, err
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("codebuddy poll: close body error: %v", errClose)
}
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, resp.StatusCode, fmt.Errorf("codebuddy poll: failed to read response body: %w", err)
}
return body, resp.StatusCode, nil
}
// PollForToken polls until the user completes browser authorization and returns auth data.
func (a *CodeBuddyAuth) PollForToken(ctx context.Context, state string) (*CodeBuddyTokenStorage, error) {
deadline := time.Now().Add(maxPollDuration)
pollURL := fmt.Sprintf("%s%s?state=%s", a.baseURL, codeBuddyTokenPath, url.QueryEscape(state))
for time.Now().Before(deadline) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(pollInterval):
}
body, statusCode, err := a.doPollRequest(ctx, pollURL)
if err != nil {
log.Debugf("codebuddy poll: request error: %v", err)
continue
}
if statusCode != http.StatusOK {
log.Debugf("codebuddy poll: unexpected status %d", statusCode)
continue
}
var result pollResponse
if err := json.Unmarshal(body, &result); err != nil {
continue
}
switch result.Code {
case codeSuccess:
if result.Data == nil {
return nil, fmt.Errorf("%w: empty data in response", ErrTokenFetchFailed)
}
userID, _ := a.DecodeUserID(result.Data.AccessToken)
return &CodeBuddyTokenStorage{
AccessToken: result.Data.AccessToken,
RefreshToken: result.Data.RefreshToken,
ExpiresIn: result.Data.ExpiresIn,
TokenType: result.Data.TokenType,
Domain: result.Data.Domain,
UserID: userID,
Type: "codebuddy",
}, nil
case codeLoginPending:
// continue polling
default:
// TODO: when the CodeBuddy API error code for user denial is known,
// return ErrAccessDenied here instead of ErrTokenFetchFailed.
return nil, fmt.Errorf("%w: server returned code %d: %s", ErrTokenFetchFailed, result.Code, result.Msg)
}
}
return nil, ErrPollingTimeout
}
// DecodeUserID decodes the sub field from a JWT access token as the user ID.
func (a *CodeBuddyAuth) DecodeUserID(accessToken string) (string, error) {
parts := strings.Split(accessToken, ".")
if len(parts) < 2 {
return "", ErrJWTDecodeFailed
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
}
var claims struct {
Sub string `json:"sub"`
}
if err := json.Unmarshal(payload, &claims); err != nil {
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
}
if claims.Sub == "" {
return "", fmt.Errorf("%w: sub claim is empty", ErrJWTDecodeFailed)
}
return claims.Sub, nil
}
// RefreshToken exchanges a refresh token for a new access token.
// It calls POST /v2/plugin/auth/token/refresh with the required headers.
func (a *CodeBuddyAuth) RefreshToken(ctx context.Context, accessToken, refreshToken, userID, domain string) (*CodeBuddyTokenStorage, error) {
if domain == "" {
domain = DefaultDomain
}
refreshURL := fmt.Sprintf("%s%s", a.baseURL, codeBuddyRefreshPath)
body := []byte("{}")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("codebuddy: failed to create refresh request: %w", err)
}
requestID := strings.ReplaceAll(uuid.New().String(), "-", "")
req.Header.Set("Accept", "application/json, text/plain, */*")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Requested-With", "XMLHttpRequest")
req.Header.Set("X-Domain", domain)
req.Header.Set("X-Refresh-Token", refreshToken)
req.Header.Set("X-Auth-Refresh-Source", "plugin")
req.Header.Set("X-Request-ID", requestID)
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("X-User-Id", userID)
req.Header.Set("X-Product", "SaaS")
req.Header.Set("User-Agent", UserAgent)
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("codebuddy: refresh request failed: %w", err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("codebuddy refresh: close body error: %v", errClose)
}
}()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("codebuddy: failed to read refresh response: %w", err)
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
return nil, fmt.Errorf("codebuddy: refresh token rejected (status %d)", resp.StatusCode)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("codebuddy: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
}
var result struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data *struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ExpiresIn int64 `json:"expiresIn"`
RefreshExpiresIn int64 `json:"refreshExpiresIn"`
TokenType string `json:"tokenType"`
Domain string `json:"domain"`
} `json:"data"`
}
if err = json.Unmarshal(bodyBytes, &result); err != nil {
return nil, fmt.Errorf("codebuddy: failed to parse refresh response: %w", err)
}
if result.Code != codeSuccess {
return nil, fmt.Errorf("codebuddy: refresh failed with code %d: %s", result.Code, result.Msg)
}
if result.Data == nil {
return nil, fmt.Errorf("codebuddy: empty data in refresh response")
}
newUserID, _ := a.DecodeUserID(result.Data.AccessToken)
if newUserID == "" {
newUserID = userID
}
tokenDomain := result.Data.Domain
if tokenDomain == "" {
tokenDomain = domain
}
return &CodeBuddyTokenStorage{
AccessToken: result.Data.AccessToken,
RefreshToken: result.Data.RefreshToken,
ExpiresIn: result.Data.ExpiresIn,
RefreshExpiresIn: result.Data.RefreshExpiresIn,
TokenType: result.Data.TokenType,
Domain: tokenDomain,
UserID: newUserID,
Type: "codebuddy",
}, nil
}
func (a *CodeBuddyAuth) applyPollHeaders(req *http.Request) {
req.Header.Set("Accept", "application/json, text/plain, */*")
req.Header.Set("User-Agent", UserAgent)
req.Header.Set("X-Requested-With", "XMLHttpRequest")
req.Header.Set("X-No-Authorization", "true")
req.Header.Set("X-No-User-Id", "true")
req.Header.Set("X-No-Enterprise-Id", "true")
req.Header.Set("X-No-Department-Info", "true")
req.Header.Set("X-Product", "SaaS")
}

View File

@@ -0,0 +1,285 @@
package codebuddy
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
// newTestAuth creates a CodeBuddyAuth pointing at the given test server.
func newTestAuth(serverURL string) *CodeBuddyAuth {
return &CodeBuddyAuth{
httpClient: http.DefaultClient,
baseURL: serverURL,
}
}
// fakeJWT builds a minimal JWT with the given sub claim for testing.
func fakeJWT(sub string) string {
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`))
payload, _ := json.Marshal(map[string]any{"sub": sub, "iat": 1234567890})
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
return header + "." + encodedPayload + ".sig"
}
// --- FetchAuthState tests ---
func TestFetchAuthState_Success(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST, got %s", r.Method)
}
if got := r.URL.Path; got != codeBuddyStatePath {
t.Errorf("expected path %s, got %s", codeBuddyStatePath, got)
}
if got := r.URL.Query().Get("platform"); got != "CLI" {
t.Errorf("expected platform=CLI, got %s", got)
}
if got := r.Header.Get("User-Agent"); got != UserAgent {
t.Errorf("expected User-Agent %s, got %s", UserAgent, got)
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"state": "test-state-abc",
"authUrl": "https://example.com/login?state=test-state-abc",
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
result, err := auth.FetchAuthState(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.State != "test-state-abc" {
t.Errorf("expected state 'test-state-abc', got '%s'", result.State)
}
if result.AuthURL != "https://example.com/login?state=test-state-abc" {
t.Errorf("unexpected authURL: %s", result.AuthURL)
}
}
func TestFetchAuthState_NonOKStatus(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("internal error"))
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.FetchAuthState(context.Background())
if err == nil {
t.Fatal("expected error for non-200 status")
}
}
func TestFetchAuthState_APIErrorCode(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 10001,
"msg": "rate limited",
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.FetchAuthState(context.Background())
if err == nil {
t.Fatal("expected error for non-zero code")
}
}
func TestFetchAuthState_MissingData(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"state": "",
"authUrl": "",
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.FetchAuthState(context.Background())
if err == nil {
t.Fatal("expected error for empty state/authUrl")
}
}
// --- RefreshToken tests ---
func TestRefreshToken_Success(t *testing.T) {
newAccessToken := fakeJWT("refreshed-user-456")
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST, got %s", r.Method)
}
if got := r.URL.Path; got != codeBuddyRefreshPath {
t.Errorf("expected path %s, got %s", codeBuddyRefreshPath, got)
}
if got := r.Header.Get("X-Refresh-Token"); got != "old-refresh-token" {
t.Errorf("expected X-Refresh-Token 'old-refresh-token', got '%s'", got)
}
if got := r.Header.Get("Authorization"); got != "Bearer old-access-token" {
t.Errorf("expected Authorization 'Bearer old-access-token', got '%s'", got)
}
if got := r.Header.Get("X-User-Id"); got != "user-123" {
t.Errorf("expected X-User-Id 'user-123', got '%s'", got)
}
if got := r.Header.Get("X-Domain"); got != "custom.domain.com" {
t.Errorf("expected X-Domain 'custom.domain.com', got '%s'", got)
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"accessToken": newAccessToken,
"refreshToken": "new-refresh-token",
"expiresIn": 3600,
"refreshExpiresIn": 86400,
"tokenType": "bearer",
"domain": "custom.domain.com",
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
storage, err := auth.RefreshToken(context.Background(), "old-access-token", "old-refresh-token", "user-123", "custom.domain.com")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if storage.AccessToken != newAccessToken {
t.Errorf("expected new access token, got '%s'", storage.AccessToken)
}
if storage.RefreshToken != "new-refresh-token" {
t.Errorf("expected 'new-refresh-token', got '%s'", storage.RefreshToken)
}
if storage.UserID != "refreshed-user-456" {
t.Errorf("expected userID 'refreshed-user-456', got '%s'", storage.UserID)
}
if storage.ExpiresIn != 3600 {
t.Errorf("expected expiresIn 3600, got %d", storage.ExpiresIn)
}
if storage.RefreshExpiresIn != 86400 {
t.Errorf("expected refreshExpiresIn 86400, got %d", storage.RefreshExpiresIn)
}
if storage.Domain != "custom.domain.com" {
t.Errorf("expected domain 'custom.domain.com', got '%s'", storage.Domain)
}
if storage.Type != "codebuddy" {
t.Errorf("expected type 'codebuddy', got '%s'", storage.Type)
}
}
func TestRefreshToken_DefaultDomain(t *testing.T) {
var receivedDomain string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedDomain = r.Header.Get("X-Domain")
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"accessToken": fakeJWT("user-1"),
"refreshToken": "rt",
"expiresIn": 3600,
"tokenType": "bearer",
"domain": DefaultDomain,
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if receivedDomain != DefaultDomain {
t.Errorf("expected default domain '%s', got '%s'", DefaultDomain, receivedDomain)
}
}
func TestRefreshToken_Unauthorized(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
if err == nil {
t.Fatal("expected error for 401 response")
}
}
func TestRefreshToken_Forbidden(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusForbidden)
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
if err == nil {
t.Fatal("expected error for 403 response")
}
}
func TestRefreshToken_APIErrorCode(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 40001,
"msg": "invalid refresh token",
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
if err == nil {
t.Fatal("expected error for non-zero API code")
}
}
func TestRefreshToken_FallbackUserIDAndDomain(t *testing.T) {
// When the new access token cannot be decoded for userID, it should fall back to the provided one.
// When the response domain is empty, it should fall back to the request domain.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"accessToken": "not-a-valid-jwt",
"refreshToken": "new-rt",
"expiresIn": 7200,
"tokenType": "bearer",
"domain": "",
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
storage, err := auth.RefreshToken(context.Background(), "at", "rt", "original-uid", "original.domain.com")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if storage.UserID != "original-uid" {
t.Errorf("expected fallback userID 'original-uid', got '%s'", storage.UserID)
}
if storage.Domain != "original.domain.com" {
t.Errorf("expected fallback domain 'original.domain.com', got '%s'", storage.Domain)
}
}

View File

@@ -0,0 +1,22 @@
package codebuddy_test
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
)
func TestDecodeUserID_ValidJWT(t *testing.T) {
// JWT payload: {"sub":"test-user-id-123","iat":1234567890}
// base64url encode: eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ
token := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ.sig"
auth := codebuddy.NewCodeBuddyAuth(nil)
userID, err := auth.DecodeUserID(token)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if userID != "test-user-id-123" {
t.Errorf("expected 'test-user-id-123', got '%s'", userID)
}
}

View File

@@ -0,0 +1,25 @@
package codebuddy
import "errors"
var (
ErrPollingTimeout = errors.New("codebuddy: polling timeout, user did not authorize in time")
ErrAccessDenied = errors.New("codebuddy: access denied by user")
ErrTokenFetchFailed = errors.New("codebuddy: failed to fetch token from server")
ErrJWTDecodeFailed = errors.New("codebuddy: failed to decode JWT token")
)
func GetUserFriendlyMessage(err error) string {
switch {
case errors.Is(err, ErrPollingTimeout):
return "Authentication timed out. Please try again."
case errors.Is(err, ErrAccessDenied):
return "Access denied. Please try again and approve the login request."
case errors.Is(err, ErrJWTDecodeFailed):
return "Failed to decode token. Please try logging in again."
case errors.Is(err, ErrTokenFetchFailed):
return "Failed to fetch token from server. Please try again."
default:
return "Authentication failed: " + err.Error()
}
}

View File

@@ -0,0 +1,65 @@
// Package codebuddy provides authentication and token management functionality
// for CodeBuddy AI services. It handles OAuth2 token storage, serialization,
// and retrieval for maintaining authenticated sessions with the CodeBuddy API.
package codebuddy
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
)
// CodeBuddyTokenStorage stores OAuth token information for CodeBuddy API authentication.
// It maintains compatibility with the existing auth system while adding CodeBuddy-specific fields
// for managing access tokens and user account information.
type CodeBuddyTokenStorage struct {
// AccessToken is the OAuth2 access token used for authenticating API requests.
AccessToken string `json:"access_token"`
// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
RefreshToken string `json:"refresh_token"`
// ExpiresIn is the number of seconds until the access token expires.
ExpiresIn int64 `json:"expires_in"`
// RefreshExpiresIn is the number of seconds until the refresh token expires.
RefreshExpiresIn int64 `json:"refresh_expires_in,omitempty"`
// TokenType is the type of token, typically "bearer".
TokenType string `json:"token_type"`
// Domain is the CodeBuddy service domain/region.
Domain string `json:"domain"`
// UserID is the user ID associated with this token.
UserID string `json:"user_id"`
// Type indicates the authentication provider type, always "codebuddy" for this storage.
Type string `json:"type"`
}
// SaveTokenToFile serializes the CodeBuddy token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
//
// Returns:
// - error: An error if the operation fails, nil otherwise
func (s *CodeBuddyTokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
s.Type = "codebuddy"
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
f, err := os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to create token file: %w", err)
}
defer func() {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(s); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil
}

View File

@@ -10,9 +10,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
@@ -20,9 +18,9 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"golang.org/x/net/proxy"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
@@ -80,36 +78,16 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
} }
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort) callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
// Configure proxy settings for the HTTP client if a proxy URL is provided. transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
proxyURL, err := url.Parse(cfg.ProxyURL) if errBuild != nil {
if err == nil { log.Errorf("%v", errBuild)
var transport *http.Transport } else if transport != nil {
if proxyURL.Scheme == "socks5" { proxyClient := &http.Client{Transport: transport}
// Handle SOCKS5 proxy. ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
auth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Handle HTTP/HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
if transport != nil {
proxyClient := &http.Client{Transport: transport}
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
}
} }
var err error
// Configure the OAuth2 client. // Configure the OAuth2 client.
conf := &oauth2.Config{ conf := &oauth2.Config{
ClientID: ClientID, ClientID: ClientID,
@@ -327,6 +305,9 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
defer manualPromptTimer.Stop() defer manualPromptTimer.Stop()
} }
var manualInputCh <-chan string
var manualInputErrCh <-chan error
waitForCallback: waitForCallback:
for { for {
select { select {
@@ -348,13 +329,14 @@ waitForCallback:
return nil, err return nil, err
default: default:
} }
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ") manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the Gemini callback URL (or press Enter to keep waiting): ")
if err != nil { continue
return nil, err case input := <-manualInputCh:
} manualInputCh = nil
parsed, err := misc.ParseOAuthCallback(input) manualInputErrCh = nil
if err != nil { parsed, errParse := misc.ParseOAuthCallback(input)
return nil, err if errParse != nil {
return nil, errParse
} }
if parsed == nil { if parsed == nil {
continue continue
@@ -367,6 +349,8 @@ waitForCallback:
} }
authCode = parsed.Code authCode = parsed.Code
break waitForCallback break waitForCallback
case errManual := <-manualInputErrCh:
return nil, errManual
case <-timeoutTimer.C: case <-timeoutTimer.C:
return nil, fmt.Errorf("oauth flow timed out") return nil, fmt.Errorf("oauth flow timed out")
} }

View File

@@ -5,8 +5,7 @@ import (
) )
// newAuthManager creates a new authentication manager instance with all supported // newAuthManager creates a new authentication manager instance with all supported
// authenticators and a file-based token store. It initializes authenticators for // authenticators and a file-based token store.
// Gemini, Codex, Claude, Qwen, IFlow, Antigravity, and GitHub Copilot providers.
// //
// Returns: // Returns:
// - *sdkAuth.Manager: A configured authentication manager instance // - *sdkAuth.Manager: A configured authentication manager instance
@@ -24,6 +23,7 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewGitHubCopilotAuthenticator(), sdkAuth.NewGitHubCopilotAuthenticator(),
sdkAuth.NewKiloAuthenticator(), sdkAuth.NewKiloAuthenticator(),
sdkAuth.NewGitLabAuthenticator(), sdkAuth.NewGitLabAuthenticator(),
sdkAuth.NewCodeBuddyAuthenticator(),
) )
return manager return manager
} }

View File

@@ -0,0 +1,43 @@
package cmd
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoCodeBuddyLogin triggers the browser OAuth polling flow for CodeBuddy and saves tokens.
// It initiates the OAuth authentication, displays the user code for the user to enter
// at the CodeBuddy verification URL, and waits for authorization before saving the tokens.
//
// Parameters:
// - cfg: The application configuration containing proxy and auth directory settings
// - options: Login options including browser behavior settings
func DoCodeBuddyLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
}
record, savedPath, err := manager.Login(context.Background(), "codebuddy", cfg, authOpts)
if err != nil {
log.Errorf("CodeBuddy authentication failed: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("CodeBuddy authentication successful!")
}

View File

@@ -0,0 +1,55 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadConfigOptional_ClaudeHeaderDefaults(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.yaml")
configYAML := []byte(`
claude-header-defaults:
user-agent: " claude-cli/2.1.70 (external, cli) "
package-version: " 0.80.0 "
runtime-version: " v24.5.0 "
os: " MacOS "
arch: " arm64 "
timeout: " 900 "
stabilize-device-profile: false
`)
if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
t.Fatalf("failed to write config: %v", err)
}
cfg, err := LoadConfigOptional(configPath, false)
if err != nil {
t.Fatalf("LoadConfigOptional() error = %v", err)
}
if got := cfg.ClaudeHeaderDefaults.UserAgent; got != "claude-cli/2.1.70 (external, cli)" {
t.Fatalf("UserAgent = %q, want %q", got, "claude-cli/2.1.70 (external, cli)")
}
if got := cfg.ClaudeHeaderDefaults.PackageVersion; got != "0.80.0" {
t.Fatalf("PackageVersion = %q, want %q", got, "0.80.0")
}
if got := cfg.ClaudeHeaderDefaults.RuntimeVersion; got != "v24.5.0" {
t.Fatalf("RuntimeVersion = %q, want %q", got, "v24.5.0")
}
if got := cfg.ClaudeHeaderDefaults.OS; got != "MacOS" {
t.Fatalf("OS = %q, want %q", got, "MacOS")
}
if got := cfg.ClaudeHeaderDefaults.Arch; got != "arm64" {
t.Fatalf("Arch = %q, want %q", got, "arm64")
}
if got := cfg.ClaudeHeaderDefaults.Timeout; got != "900" {
t.Fatalf("Timeout = %q, want %q", got, "900")
}
if cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
t.Fatal("StabilizeDeviceProfile = nil, want non-nil")
}
if got := *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile; got {
t.Fatalf("StabilizeDeviceProfile = %v, want false", got)
}
}

View File

@@ -0,0 +1,32 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadConfigOptional_CodexHeaderDefaults(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.yaml")
configYAML := []byte(`
codex-header-defaults:
user-agent: " my-codex-client/1.0 "
beta-features: " feature-a,feature-b "
`)
if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
t.Fatalf("failed to write config: %v", err)
}
cfg, err := LoadConfigOptional(configPath, false)
if err != nil {
t.Fatalf("LoadConfigOptional() error = %v", err)
}
if got := cfg.CodexHeaderDefaults.UserAgent; got != "my-codex-client/1.0" {
t.Fatalf("UserAgent = %q, want %q", got, "my-codex-client/1.0")
}
if got := cfg.CodexHeaderDefaults.BetaFeatures; got != "feature-a,feature-b" {
t.Fatalf("BetaFeatures = %q, want %q", got, "feature-a,feature-b")
}
}

View File

@@ -13,6 +13,7 @@ import (
"strings" "strings"
"syscall" "syscall"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
@@ -101,6 +102,10 @@ type Config struct {
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file. // Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"` CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
// CodexHeaderDefaults configures fallback headers for Codex OAuth model requests.
// These are used only when the client does not send its own headers.
CodexHeaderDefaults CodexHeaderDefaults `yaml:"codex-header-defaults" json:"codex-header-defaults"`
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file. // ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"` ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
@@ -141,13 +146,27 @@ type Config struct {
legacyMigrationPending bool `yaml:"-" json:"-"` legacyMigrationPending bool `yaml:"-" json:"-"`
} }
// ClaudeHeaderDefaults configures default header values injected into Claude API requests // ClaudeHeaderDefaults configures default header values injected into Claude API requests.
// when the client does not send them. Update these when Claude Code releases a new version. // In legacy mode, UserAgent/PackageVersion/RuntimeVersion/Timeout act as fallbacks when
// the client omits them, while OS/Arch remain runtime-derived. When stabilized device
// profiles are enabled, OS/Arch become the pinned platform baseline, while
// UserAgent/PackageVersion/RuntimeVersion seed the upgradeable software fingerprint.
type ClaudeHeaderDefaults struct { type ClaudeHeaderDefaults struct {
UserAgent string `yaml:"user-agent" json:"user-agent"` UserAgent string `yaml:"user-agent" json:"user-agent"`
PackageVersion string `yaml:"package-version" json:"package-version"` PackageVersion string `yaml:"package-version" json:"package-version"`
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"` RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
Timeout string `yaml:"timeout" json:"timeout"` OS string `yaml:"os" json:"os"`
Arch string `yaml:"arch" json:"arch"`
Timeout string `yaml:"timeout" json:"timeout"`
StabilizeDeviceProfile *bool `yaml:"stabilize-device-profile,omitempty" json:"stabilize-device-profile,omitempty"`
}
// CodexHeaderDefaults configures fallback header values injected into Codex
// model requests for OAuth/file-backed auth when the client omits them.
// UserAgent applies to HTTP and websocket requests; BetaFeatures only applies to websockets.
type CodexHeaderDefaults struct {
UserAgent string `yaml:"user-agent" json:"user-agent"`
BetaFeatures string `yaml:"beta-features" json:"beta-features"`
} }
// TLSConfig holds HTTPS server settings. // TLSConfig holds HTTPS server settings.
@@ -556,6 +575,10 @@ type OpenAICompatibilityModel struct {
// Alias is the model name alias that clients will use to reference this model. // Alias is the model name alias that clients will use to reference this model.
Alias string `yaml:"alias" json:"alias"` Alias string `yaml:"alias" json:"alias"`
// Thinking configures the thinking/reasoning capability for this model.
// If nil, the model defaults to level-based reasoning with levels ["low", "medium", "high"].
Thinking *registry.ThinkingSupport `yaml:"thinking,omitempty" json:"thinking,omitempty"`
} }
func (m OpenAICompatibilityModel) GetName() string { return m.Name } func (m OpenAICompatibilityModel) GetName() string { return m.Name }
@@ -673,12 +696,18 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Sanitize Gemini API key configuration and migrate legacy entries. // Sanitize Gemini API key configuration and migrate legacy entries.
cfg.SanitizeGeminiKeys() cfg.SanitizeGeminiKeys()
// Sanitize Vertex-compatible API keys: drop entries without base-url // Sanitize Vertex-compatible API keys.
cfg.SanitizeVertexCompatKeys() cfg.SanitizeVertexCompatKeys()
// Sanitize Codex keys: drop entries without base-url // Sanitize Codex keys: drop entries without base-url
cfg.SanitizeCodexKeys() cfg.SanitizeCodexKeys()
// Sanitize Codex header defaults.
cfg.SanitizeCodexHeaderDefaults()
// Sanitize Claude header defaults.
cfg.SanitizeClaudeHeaderDefaults()
// Sanitize Claude key headers // Sanitize Claude key headers
cfg.SanitizeClaudeKeys() cfg.SanitizeClaudeKeys()
@@ -771,6 +800,30 @@ func payloadRawString(value any) ([]byte, bool) {
} }
} }
// SanitizeCodexHeaderDefaults trims surrounding whitespace from the
// configured Codex header fallback values.
func (cfg *Config) SanitizeCodexHeaderDefaults() {
if cfg == nil {
return
}
cfg.CodexHeaderDefaults.UserAgent = strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent)
cfg.CodexHeaderDefaults.BetaFeatures = strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
}
// SanitizeClaudeHeaderDefaults trims surrounding whitespace from the
// configured Claude fingerprint baseline values.
func (cfg *Config) SanitizeClaudeHeaderDefaults() {
if cfg == nil {
return
}
cfg.ClaudeHeaderDefaults.UserAgent = strings.TrimSpace(cfg.ClaudeHeaderDefaults.UserAgent)
cfg.ClaudeHeaderDefaults.PackageVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.PackageVersion)
cfg.ClaudeHeaderDefaults.RuntimeVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.RuntimeVersion)
cfg.ClaudeHeaderDefaults.OS = strings.TrimSpace(cfg.ClaudeHeaderDefaults.OS)
cfg.ClaudeHeaderDefaults.Arch = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Arch)
cfg.ClaudeHeaderDefaults.Timeout = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Timeout)
}
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases. // SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries, // It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel. // allows multiple aliases per upstream name, and ensures aliases are unique within each channel.

View File

@@ -20,9 +20,9 @@ type VertexCompatKey struct {
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro"). // Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
// BaseURL is the base URL for the Vertex-compatible API endpoint. // BaseURL optionally overrides the Vertex-compatible API endpoint.
// The executor will append "/v1/publishers/google/models/{model}:action" to this. // The executor will append "/v1/publishers/google/models/{model}:action" to this.
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..." // When empty, requests fall back to the default Vertex API base URL.
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"` BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
// ProxyURL optionally overrides the global proxy for this API key. // ProxyURL optionally overrides the global proxy for this API key.
@@ -71,10 +71,6 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
} }
entry.Prefix = normalizeModelPrefix(entry.Prefix) entry.Prefix = normalizeModelPrefix(entry.Prefix)
entry.BaseURL = strings.TrimSpace(entry.BaseURL) entry.BaseURL = strings.TrimSpace(entry.BaseURL)
if entry.BaseURL == "" {
// BaseURL is required for Vertex API key entries
continue
}
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = NormalizeHeaders(entry.Headers) entry.Headers = NormalizeHeaders(entry.Headers)
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels) entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)

View File

@@ -30,6 +30,23 @@ type OAuthCallback struct {
ErrorDescription string ErrorDescription string
} }
// AsyncPrompt runs a prompt function in a goroutine and returns channels for
// the result. The returned channels are buffered (size 1) so the goroutine can
// complete even if the caller abandons the channels.
func AsyncPrompt(promptFn func(string) (string, error), message string) (<-chan string, <-chan error) {
inputCh := make(chan string, 1)
errCh := make(chan error, 1)
go func() {
input, err := promptFn(message)
if err != nil {
errCh <- err
return
}
inputCh <- input
}()
return inputCh, errCh
}
// ParseOAuthCallback extracts OAuth parameters from a callback URL. // ParseOAuthCallback extracts OAuth parameters from a callback URL.
// It returns nil when the input is empty. // It returns nil when the input is empty.
func ParseOAuthCallback(input string) (*OAuthCallback, error) { func ParseOAuthCallback(input string) (*OAuthCallback, error) {

View File

@@ -3,32 +3,24 @@
package registry package registry
import ( import (
"sort"
"strings" "strings"
) )
// AntigravityModelConfig captures static antigravity model overrides, including
// Thinking budget limits and provider max completion tokens.
type AntigravityModelConfig struct {
Thinking *ThinkingSupport `json:"thinking,omitempty"`
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
}
// staticModelsJSON mirrors the top-level structure of models.json. // staticModelsJSON mirrors the top-level structure of models.json.
type staticModelsJSON struct { type staticModelsJSON struct {
Claude []*ModelInfo `json:"claude"` Claude []*ModelInfo `json:"claude"`
Gemini []*ModelInfo `json:"gemini"` Gemini []*ModelInfo `json:"gemini"`
Vertex []*ModelInfo `json:"vertex"` Vertex []*ModelInfo `json:"vertex"`
GeminiCLI []*ModelInfo `json:"gemini-cli"` GeminiCLI []*ModelInfo `json:"gemini-cli"`
AIStudio []*ModelInfo `json:"aistudio"` AIStudio []*ModelInfo `json:"aistudio"`
CodexFree []*ModelInfo `json:"codex-free"` CodexFree []*ModelInfo `json:"codex-free"`
CodexTeam []*ModelInfo `json:"codex-team"` CodexTeam []*ModelInfo `json:"codex-team"`
CodexPlus []*ModelInfo `json:"codex-plus"` CodexPlus []*ModelInfo `json:"codex-plus"`
CodexPro []*ModelInfo `json:"codex-pro"` CodexPro []*ModelInfo `json:"codex-pro"`
Qwen []*ModelInfo `json:"qwen"` Qwen []*ModelInfo `json:"qwen"`
IFlow []*ModelInfo `json:"iflow"` IFlow []*ModelInfo `json:"iflow"`
Kimi []*ModelInfo `json:"kimi"` Kimi []*ModelInfo `json:"kimi"`
Antigravity map[string]*AntigravityModelConfig `json:"antigravity"` Antigravity []*ModelInfo `json:"antigravity"`
} }
// GetClaudeModels returns the standard Claude model definitions. // GetClaudeModels returns the standard Claude model definitions.
@@ -91,33 +83,90 @@ func GetKimiModels() []*ModelInfo {
return cloneModelInfos(getModels().Kimi) return cloneModelInfos(getModels().Kimi)
} }
// GetAntigravityModelConfig returns static configuration for antigravity models. // GetAntigravityModels returns the standard Antigravity model definitions.
// Keys use upstream model names returned by the Antigravity models endpoint. func GetAntigravityModels() []*ModelInfo {
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { return cloneModelInfos(getModels().Antigravity)
data := getModels()
if len(data.Antigravity) == 0 {
return nil
}
out := make(map[string]*AntigravityModelConfig, len(data.Antigravity))
for k, v := range data.Antigravity {
out[k] = cloneAntigravityModelConfig(v)
}
return out
} }
func cloneAntigravityModelConfig(cfg *AntigravityModelConfig) *AntigravityModelConfig { // GetCodeBuddyModels returns the available models for CodeBuddy (Tencent).
if cfg == nil { // These models are served through the copilot.tencent.com API.
return nil func GetCodeBuddyModels() []*ModelInfo {
now := int64(1748044800) // 2025-05-24
return []*ModelInfo{
{
ID: "glm-5.0",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "GLM-5.0",
Description: "GLM-5.0 via CodeBuddy",
ContextLength: 128000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "glm-4.7",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "GLM-4.7",
Description: "GLM-4.7 via CodeBuddy",
ContextLength: 128000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "minimax-m2.5",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "MiniMax M2.5",
Description: "MiniMax M2.5 via CodeBuddy",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "kimi-k2.5",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "Kimi K2.5",
Description: "Kimi K2.5 via CodeBuddy",
ContextLength: 128000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "deepseek-v3-2-volc",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "DeepSeek V3.2 (Volc)",
Description: "DeepSeek V3.2 via CodeBuddy (Volcano Engine)",
ContextLength: 128000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "hunyuan-2.0-thinking",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "Hunyuan 2.0 Thinking",
Description: "Tencent Hunyuan 2.0 Thinking via CodeBuddy",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{ZeroAllowed: true},
SupportedEndpoints: []string{"/chat/completions"},
},
} }
copyConfig := *cfg
if cfg.Thinking != nil {
copyThinking := *cfg.Thinking
if len(cfg.Thinking.Levels) > 0 {
copyThinking.Levels = append([]string(nil), cfg.Thinking.Levels...)
}
copyConfig.Thinking = &copyThinking
}
return &copyConfig
} }
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned. // cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
@@ -145,7 +194,6 @@ func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
// - qwen // - qwen
// - iflow // - iflow
// - kimi // - kimi
// - kiro
// - kilo // - kilo
// - github-copilot // - github-copilot
// - amazonq // - amazonq
@@ -180,28 +228,9 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
case "amazonq": case "amazonq":
return GetAmazonQModels() return GetAmazonQModels()
case "antigravity": case "antigravity":
cfg := GetAntigravityModelConfig() return GetAntigravityModels()
if len(cfg) == 0 { case "codebuddy":
return nil return GetCodeBuddyModels()
}
models := make([]*ModelInfo, 0, len(cfg))
for modelID, entry := range cfg {
if modelID == "" || entry == nil {
continue
}
models = append(models, &ModelInfo{
ID: modelID,
Object: "model",
OwnedBy: "antigravity",
Type: "antigravity",
Thinking: entry.Thinking,
MaxCompletionTokens: entry.MaxCompletionTokens,
})
}
sort.Slice(models, func(i, j int) bool {
return strings.ToLower(models[i].ID) < strings.ToLower(models[j].ID)
})
return models
default: default:
return nil return nil
} }
@@ -225,10 +254,12 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
data.Qwen, data.Qwen,
data.IFlow, data.IFlow,
data.Kimi, data.Kimi,
data.Antigravity,
GetGitHubCopilotModels(), GetGitHubCopilotModels(),
GetKiroModels(), GetKiroModels(),
GetKiloModels(), GetKiloModels(),
GetAmazonQModels(), GetAmazonQModels(),
GetCodeBuddyModels(),
} }
for _, models := range allModels { for _, models := range allModels {
for _, m := range models { for _, m := range models {
@@ -238,15 +269,6 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
} }
} }
// Check Antigravity static config
if cfg := cloneAntigravityModelConfig(data.Antigravity[modelID]); cfg != nil {
return &ModelInfo{
ID: modelID,
Thinking: cfg.Thinking,
MaxCompletionTokens: cfg.MaxCompletionTokens,
}
}
return nil return nil
} }
@@ -427,6 +449,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
SupportedEndpoints: []string{"/responses"}, SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}}, Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
}, },
{
ID: "gpt-5.4",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.4",
Description: "OpenAI GPT-5.4 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{ {
ID: "claude-haiku-4.5", ID: "claude-haiku-4.5",
Object: "model", Object: "model",

View File

@@ -73,16 +73,16 @@ type availableModelsCacheEntry struct {
// Values are interpreted in provider-native token units. // Values are interpreted in provider-native token units.
type ThinkingSupport struct { type ThinkingSupport struct {
// Min is the minimum allowed thinking budget (inclusive). // Min is the minimum allowed thinking budget (inclusive).
Min int `json:"min,omitempty"` Min int `json:"min,omitempty" yaml:"min,omitempty"`
// Max is the maximum allowed thinking budget (inclusive). // Max is the maximum allowed thinking budget (inclusive).
Max int `json:"max,omitempty"` Max int `json:"max,omitempty" yaml:"max,omitempty"`
// ZeroAllowed indicates whether 0 is a valid value (to disable thinking). // ZeroAllowed indicates whether 0 is a valid value (to disable thinking).
ZeroAllowed bool `json:"zero_allowed,omitempty"` ZeroAllowed bool `json:"zero_allowed,omitempty" yaml:"zero-allowed,omitempty"`
// DynamicAllowed indicates whether -1 is a valid value (dynamic thinking budget). // DynamicAllowed indicates whether -1 is a valid value (dynamic thinking budget).
DynamicAllowed bool `json:"dynamic_allowed,omitempty"` DynamicAllowed bool `json:"dynamic_allowed,omitempty" yaml:"dynamic-allowed,omitempty"`
// Levels defines discrete reasoning effort levels (e.g., "low", "medium", "high"). // Levels defines discrete reasoning effort levels (e.g., "low", "medium", "high").
// When set, the model uses level-based reasoning instead of token budgets. // When set, the model uses level-based reasoning instead of token budgets.
Levels []string `json:"levels,omitempty"` Levels []string `json:"levels,omitempty" yaml:"levels,omitempty"`
} }
// ModelRegistration tracks a model's availability // ModelRegistration tracks a model's availability
@@ -189,6 +189,7 @@ func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
} }
const defaultModelRegistryHookTimeout = 5 * time.Second const defaultModelRegistryHookTimeout = 5 * time.Second
const modelQuotaExceededWindow = 5 * time.Minute
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) { func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
hook := r.hook hook := r.hook
@@ -390,6 +391,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
reg.InfoByProvider[provider] = cloneModelInfo(model) reg.InfoByProvider[provider] = cloneModelInfo(model)
} }
reg.LastUpdated = now reg.LastUpdated = now
// Re-registering an existing client/model binding starts a fresh registry
// snapshot for that binding. Cooldown and suspension are transient
// scheduling state and must not survive this reconciliation step.
if reg.QuotaExceededClients != nil { if reg.QuotaExceededClients != nil {
delete(reg.QuotaExceededClients, clientID) delete(reg.QuotaExceededClients, clientID)
} }
@@ -783,7 +787,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) { func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
models := make([]map[string]any, 0, len(r.models)) models := make([]map[string]any, 0, len(r.models))
quotaExpiredDuration := 5 * time.Minute
var expiresAt time.Time var expiresAt time.Time
for _, registration := range r.models { for _, registration := range r.models {
@@ -794,7 +797,7 @@ func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.
if quotaTime == nil { if quotaTime == nil {
continue continue
} }
recoveryAt := quotaTime.Add(quotaExpiredDuration) recoveryAt := quotaTime.Add(modelQuotaExceededWindow)
if now.Before(recoveryAt) { if now.Before(recoveryAt) {
expiredClients++ expiredClients++
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) { if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
@@ -929,7 +932,6 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
return nil return nil
} }
quotaExpiredDuration := 5 * time.Minute
now := time.Now() now := time.Now()
result := make([]*ModelInfo, 0, len(providerModels)) result := make([]*ModelInfo, 0, len(providerModels))
@@ -951,7 +953,7 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider { if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
continue continue
} }
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration { if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
expiredClients++ expiredClients++
} }
} }
@@ -1005,12 +1007,11 @@ func (r *ModelRegistry) GetModelCount(modelID string) int {
if registration, exists := r.models[modelID]; exists { if registration, exists := r.models[modelID]; exists {
now := time.Now() now := time.Now()
quotaExpiredDuration := 5 * time.Minute
// Count clients that have exceeded quota but haven't recovered yet // Count clients that have exceeded quota but haven't recovered yet
expiredClients := 0 expiredClients := 0
for _, quotaTime := range registration.QuotaExceededClients { for _, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration { if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
expiredClients++ expiredClients++
} }
} }
@@ -1236,12 +1237,11 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
defer r.mutex.Unlock() defer r.mutex.Unlock()
now := time.Now() now := time.Now()
quotaExpiredDuration := 5 * time.Minute
invalidated := false invalidated := false
for modelID, registration := range r.models { for modelID, registration := range r.models {
for clientID, quotaTime := range registration.QuotaExceededClients { for clientID, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration { if quotaTime != nil && now.Sub(*quotaTime) >= modelQuotaExceededWindow {
delete(registration.QuotaExceededClients, clientID) delete(registration.QuotaExceededClients, clientID)
invalidated = true invalidated = true
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID) log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)

View File

@@ -15,7 +15,8 @@ import (
) )
const ( const (
modelsFetchTimeout = 30 * time.Second modelsFetchTimeout = 30 * time.Second
modelsRefreshInterval = 3 * time.Hour
) )
var modelsURLs = []string{ var modelsURLs = []string{
@@ -35,6 +36,34 @@ var modelsCatalogStore = &modelStore{}
var updaterOnce sync.Once var updaterOnce sync.Once
// ModelRefreshCallback is invoked when startup or periodic model refresh detects changes.
// changedProviders contains the provider names whose model definitions changed.
type ModelRefreshCallback func(changedProviders []string)
var (
refreshCallbackMu sync.Mutex
refreshCallback ModelRefreshCallback
pendingRefreshChanges []string
)
// SetModelRefreshCallback registers a callback that is invoked when startup or
// periodic model refresh detects changes. Only one callback is supported;
// subsequent calls replace the previous callback.
func SetModelRefreshCallback(cb ModelRefreshCallback) {
refreshCallbackMu.Lock()
refreshCallback = cb
var pending []string
if cb != nil && len(pendingRefreshChanges) > 0 {
pending = append([]string(nil), pendingRefreshChanges...)
pendingRefreshChanges = nil
}
refreshCallbackMu.Unlock()
if cb != nil && len(pending) > 0 {
cb(pending)
}
}
func init() { func init() {
// Load embedded data as fallback on startup. // Load embedded data as fallback on startup.
if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil { if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil {
@@ -42,23 +71,76 @@ func init() {
} }
} }
// StartModelsUpdater runs a one-time models refresh on startup. // StartModelsUpdater starts a background updater that fetches models
// It blocks until the startup fetch attempt finishes so service initialization // immediately on startup and then refreshes the model catalog every 3 hours.
// can wait for the refreshed catalog before registering auth-backed models. // Safe to call multiple times; only one updater will run.
// Safe to call multiple times; only one refresh will run.
func StartModelsUpdater(ctx context.Context) { func StartModelsUpdater(ctx context.Context) {
updaterOnce.Do(func() { updaterOnce.Do(func() {
runModelsUpdater(ctx) go runModelsUpdater(ctx)
}) })
} }
func runModelsUpdater(ctx context.Context) { func runModelsUpdater(ctx context.Context) {
// Try network fetch once on startup, then stop. tryStartupRefresh(ctx)
// Periodic refresh is disabled - models are only refreshed at startup. periodicRefresh(ctx)
tryRefreshModels(ctx)
} }
func tryRefreshModels(ctx context.Context) { func periodicRefresh(ctx context.Context) {
ticker := time.NewTicker(modelsRefreshInterval)
defer ticker.Stop()
log.Infof("periodic model refresh started (interval=%s)", modelsRefreshInterval)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
tryPeriodicRefresh(ctx)
}
}
}
// tryPeriodicRefresh fetches models from remote, compares with the current
// catalog, and notifies the registered callback if any provider changed.
func tryPeriodicRefresh(ctx context.Context) {
tryRefreshModels(ctx, "periodic model refresh")
}
// tryStartupRefresh fetches models from remote in the background during
// process startup. It uses the same change detection as periodic refresh so
// existing auth registrations can be updated after the callback is registered.
func tryStartupRefresh(ctx context.Context) {
tryRefreshModels(ctx, "startup model refresh")
}
func tryRefreshModels(ctx context.Context, label string) {
oldData := getModels()
parsed, url := fetchModelsFromRemote(ctx)
if parsed == nil {
log.Warnf("%s: fetch failed from all URLs, keeping current data", label)
return
}
// Detect changes before updating store.
changed := detectChangedProviders(oldData, parsed)
// Update store with new data regardless.
modelsCatalogStore.mu.Lock()
modelsCatalogStore.data = parsed
modelsCatalogStore.mu.Unlock()
if len(changed) == 0 {
log.Infof("%s completed from %s, no changes detected", label, url)
return
}
log.Infof("%s completed from %s, changes detected for providers: %v", label, url, changed)
notifyModelRefresh(changed)
}
// fetchModelsFromRemote tries all remote URLs and returns the parsed model catalog
// along with the URL it was fetched from. Returns (nil, "") if all fetches fail.
func fetchModelsFromRemote(ctx context.Context) (*staticModelsJSON, string) {
client := &http.Client{Timeout: modelsFetchTimeout} client := &http.Client{Timeout: modelsFetchTimeout}
for _, url := range modelsURLs { for _, url := range modelsURLs {
reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout) reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
@@ -92,15 +174,126 @@ func tryRefreshModels(ctx context.Context) {
continue continue
} }
if err := loadModelsFromBytes(data, url); err != nil { var parsed staticModelsJSON
if err := json.Unmarshal(data, &parsed); err != nil {
log.Warnf("models parse failed from %s: %v", url, err) log.Warnf("models parse failed from %s: %v", url, err)
continue continue
} }
if err := validateModelsCatalog(&parsed); err != nil {
log.Warnf("models validate failed from %s: %v", url, err)
continue
}
log.Infof("models updated from %s", url) return &parsed, url
}
return nil, ""
}
// detectChangedProviders compares two model catalogs and returns provider names
// whose model definitions differ. Codex tiers (free/team/plus/pro) are grouped
// under a single "codex" provider.
func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
if oldData == nil || newData == nil {
return nil
}
type section struct {
provider string
oldList []*ModelInfo
newList []*ModelInfo
}
sections := []section{
{"claude", oldData.Claude, newData.Claude},
{"gemini", oldData.Gemini, newData.Gemini},
{"vertex", oldData.Vertex, newData.Vertex},
{"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI},
{"aistudio", oldData.AIStudio, newData.AIStudio},
{"codex", oldData.CodexFree, newData.CodexFree},
{"codex", oldData.CodexTeam, newData.CodexTeam},
{"codex", oldData.CodexPlus, newData.CodexPlus},
{"codex", oldData.CodexPro, newData.CodexPro},
{"qwen", oldData.Qwen, newData.Qwen},
{"iflow", oldData.IFlow, newData.IFlow},
{"kimi", oldData.Kimi, newData.Kimi},
{"antigravity", oldData.Antigravity, newData.Antigravity},
}
seen := make(map[string]bool, len(sections))
var changed []string
for _, s := range sections {
if seen[s.provider] {
continue
}
if modelSectionChanged(s.oldList, s.newList) {
changed = append(changed, s.provider)
seen[s.provider] = true
}
}
return changed
}
// modelSectionChanged reports whether two model slices differ.
func modelSectionChanged(a, b []*ModelInfo) bool {
if len(a) != len(b) {
return true
}
if len(a) == 0 {
return false
}
aj, err1 := json.Marshal(a)
bj, err2 := json.Marshal(b)
if err1 != nil || err2 != nil {
return true
}
return string(aj) != string(bj)
}
func notifyModelRefresh(changedProviders []string) {
if len(changedProviders) == 0 {
return return
} }
log.Warn("models refresh failed from all URLs, using current data")
refreshCallbackMu.Lock()
cb := refreshCallback
if cb == nil {
pendingRefreshChanges = mergeProviderNames(pendingRefreshChanges, changedProviders)
refreshCallbackMu.Unlock()
return
}
refreshCallbackMu.Unlock()
cb(changedProviders)
}
func mergeProviderNames(existing, incoming []string) []string {
if len(incoming) == 0 {
return existing
}
seen := make(map[string]struct{}, len(existing)+len(incoming))
merged := make([]string, 0, len(existing)+len(incoming))
for _, provider := range existing {
name := strings.ToLower(strings.TrimSpace(provider))
if name == "" {
continue
}
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
merged = append(merged, name)
}
for _, provider := range incoming {
name := strings.ToLower(strings.TrimSpace(provider))
if name == "" {
continue
}
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
merged = append(merged, name)
}
return merged
} }
func loadModelsFromBytes(data []byte, source string) error { func loadModelsFromBytes(data []byte, source string) error {
@@ -145,6 +338,7 @@ func validateModelsCatalog(data *staticModelsJSON) error {
{name: "qwen", models: data.Qwen}, {name: "qwen", models: data.Qwen},
{name: "iflow", models: data.IFlow}, {name: "iflow", models: data.IFlow},
{name: "kimi", models: data.Kimi}, {name: "kimi", models: data.Kimi},
{name: "antigravity", models: data.Antigravity},
} }
for _, section := range requiredSections { for _, section := range requiredSections {
@@ -152,9 +346,6 @@ func validateModelsCatalog(data *staticModelsJSON) error {
return err return err
} }
} }
if err := validateAntigravitySection(data.Antigravity); err != nil {
return err
}
return nil return nil
} }
@@ -179,20 +370,3 @@ func validateModelSection(section string, models []*ModelInfo) error {
} }
return nil return nil
} }
func validateAntigravitySection(configs map[string]*AntigravityModelConfig) error {
if len(configs) == 0 {
return fmt.Errorf("antigravity section is empty")
}
for modelID, cfg := range configs {
trimmedID := strings.TrimSpace(modelID)
if trimmedID == "" {
return fmt.Errorf("antigravity contains empty model id")
}
if cfg == nil {
return fmt.Errorf("antigravity[%q] is null", trimmedID)
}
}
return nil
}

View File

@@ -2481,40 +2481,83 @@
} }
} }
], ],
"antigravity": { "antigravity": [
"claude-opus-4-6-thinking": { {
"id": "claude-opus-4-6-thinking",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Claude Opus 4.6 (Thinking)",
"name": "claude-opus-4-6-thinking",
"description": "Claude Opus 4.6 (Thinking)",
"context_length": 200000,
"max_completion_tokens": 64000,
"thinking": { "thinking": {
"min": 1024, "min": 1024,
"max": 64000, "max": 64000,
"zero_allowed": true, "zero_allowed": true,
"dynamic_allowed": true "dynamic_allowed": true
}, }
"max_completion_tokens": 64000
}, },
"claude-sonnet-4-6": { {
"id": "claude-sonnet-4-6",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Claude Sonnet 4.6 (Thinking)",
"name": "claude-sonnet-4-6",
"description": "Claude Sonnet 4.6 (Thinking)",
"context_length": 200000,
"max_completion_tokens": 64000,
"thinking": { "thinking": {
"min": 1024, "min": 1024,
"max": 64000, "max": 64000,
"zero_allowed": true, "zero_allowed": true,
"dynamic_allowed": true "dynamic_allowed": true
}, }
"max_completion_tokens": 64000
}, },
"gemini-2.5-flash": { {
"id": "gemini-2.5-flash",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 2.5 Flash",
"name": "gemini-2.5-flash",
"description": "Gemini 2.5 Flash",
"context_length": 1048576,
"max_completion_tokens": 65535,
"thinking": { "thinking": {
"max": 24576, "max": 24576,
"zero_allowed": true, "zero_allowed": true,
"dynamic_allowed": true "dynamic_allowed": true
} }
}, },
"gemini-2.5-flash-lite": { {
"id": "gemini-2.5-flash-lite",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 2.5 Flash Lite",
"name": "gemini-2.5-flash-lite",
"description": "Gemini 2.5 Flash Lite",
"context_length": 1048576,
"max_completion_tokens": 65535,
"thinking": { "thinking": {
"max": 24576, "max": 24576,
"zero_allowed": true, "zero_allowed": true,
"dynamic_allowed": true "dynamic_allowed": true
} }
}, },
"gemini-3-flash": { {
"id": "gemini-3-flash",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 3 Flash",
"name": "gemini-3-flash",
"description": "Gemini 3 Flash",
"context_length": 1048576,
"max_completion_tokens": 65536,
"thinking": { "thinking": {
"min": 128, "min": 128,
"max": 32768, "max": 32768,
@@ -2527,7 +2570,16 @@
] ]
} }
}, },
"gemini-3-pro-high": { {
"id": "gemini-3-pro-high",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 3 Pro (High)",
"name": "gemini-3-pro-high",
"description": "Gemini 3 Pro (High)",
"context_length": 1048576,
"max_completion_tokens": 65535,
"thinking": { "thinking": {
"min": 128, "min": 128,
"max": 32768, "max": 32768,
@@ -2538,7 +2590,16 @@
] ]
} }
}, },
"gemini-3-pro-low": { {
"id": "gemini-3-pro-low",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 3 Pro (Low)",
"name": "gemini-3-pro-low",
"description": "Gemini 3 Pro (Low)",
"context_length": 1048576,
"max_completion_tokens": 65535,
"thinking": { "thinking": {
"min": 128, "min": 128,
"max": 32768, "max": 32768,
@@ -2549,7 +2610,14 @@
] ]
} }
}, },
"gemini-3.1-flash-image": { {
"id": "gemini-3.1-flash-image",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 3.1 Flash Image",
"name": "gemini-3.1-flash-image",
"description": "Gemini 3.1 Flash Image",
"thinking": { "thinking": {
"min": 128, "min": 128,
"max": 32768, "max": 32768,
@@ -2560,18 +2628,16 @@
] ]
} }
}, },
"gemini-3.1-flash-lite-preview": { {
"thinking": { "id": "gemini-3.1-pro-high",
"min": 128, "object": "model",
"max": 32768, "owned_by": "antigravity",
"dynamic_allowed": true, "type": "antigravity",
"levels": [ "display_name": "Gemini 3.1 Pro (High)",
"minimal", "name": "gemini-3.1-pro-high",
"high" "description": "Gemini 3.1 Pro (High)",
] "context_length": 1048576,
} "max_completion_tokens": 65535,
},
"gemini-3.1-pro-high": {
"thinking": { "thinking": {
"min": 128, "min": 128,
"max": 32768, "max": 32768,
@@ -2582,7 +2648,16 @@
] ]
} }
}, },
"gemini-3.1-pro-low": { {
"id": "gemini-3.1-pro-low",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 3.1 Pro (Low)",
"name": "gemini-3.1-pro-low",
"description": "Gemini 3.1 Pro (Low)",
"context_length": 1048576,
"max_completion_tokens": 65535,
"thinking": { "thinking": {
"min": 128, "min": 128,
"max": 32768, "max": 32768,
@@ -2593,6 +2668,16 @@
] ]
} }
}, },
"gpt-oss-120b-medium": {} {
} "id": "gpt-oss-120b-medium",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "GPT-OSS 120B (Medium)",
"name": "gpt-oss-120b-medium",
"description": "GPT-OSS 120B (Medium)",
"context_length": 114000,
"max_completion_tokens": 32768
}
]
} }

View File

@@ -164,7 +164,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
reporter.publish(ctx, parseGeminiUsage(wsResp.Body)) reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, &param) out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, &param)
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()} resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON(out), Headers: wsResp.Headers.Clone()}
return resp, nil return resp, nil
} }
@@ -280,7 +280,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
} }
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, &param) lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, &param)
for i := range lines { for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))} out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
} }
break break
} }
@@ -296,7 +296,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
} }
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, &param) lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, &param)
for i := range lines { for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))} out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
} }
reporter.publish(ctx, parseGeminiUsage(event.Payload)) reporter.publish(ctx, parseGeminiUsage(event.Payload))
return false return false
@@ -373,7 +373,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response") return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
} }
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, resp.Body) translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, resp.Body)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil return cliproxyexecutor.Response{Payload: translated}, nil
} }
// Refresh refreshes the authentication credentials (no-op for AI Studio). // Refresh refreshes the authentication credentials (no-op for AI Studio).

View File

@@ -24,7 +24,6 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
@@ -43,7 +42,6 @@ const (
antigravityCountTokensPath = "/v1internal:countTokens" antigravityCountTokensPath = "/v1internal:countTokens"
antigravityStreamPath = "/v1internal:streamGenerateContent" antigravityStreamPath = "/v1internal:streamGenerateContent"
antigravityGeneratePath = "/v1internal:generateContent" antigravityGeneratePath = "/v1internal:generateContent"
antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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" antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64" defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
@@ -55,78 +53,8 @@ const (
var ( var (
randSource = rand.New(rand.NewSource(time.Now().UnixNano())) randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex randSourceMutex sync.Mutex
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
// from any antigravity auth. Empty fetches never overwrite this cache.
antigravityPrimaryModelsCache struct {
mu sync.RWMutex
models []*registry.ModelInfo
}
) )
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
if len(models) == 0 {
return nil
}
out := make([]*registry.ModelInfo, 0, len(models))
for _, model := range models {
if model == nil || strings.TrimSpace(model.ID) == "" {
continue
}
out = append(out, cloneAntigravityModelInfo(model))
}
if len(out) == 0 {
return nil
}
return out
}
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
if model == nil {
return nil
}
clone := *model
if len(model.SupportedGenerationMethods) > 0 {
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
}
if len(model.SupportedParameters) > 0 {
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
}
if model.Thinking != nil {
thinkingClone := *model.Thinking
if len(model.Thinking.Levels) > 0 {
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
}
clone.Thinking = &thinkingClone
}
return &clone
}
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
cloned := cloneAntigravityModels(models)
if len(cloned) == 0 {
return false
}
antigravityPrimaryModelsCache.mu.Lock()
antigravityPrimaryModelsCache.models = cloned
antigravityPrimaryModelsCache.mu.Unlock()
return true
}
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
antigravityPrimaryModelsCache.mu.RLock()
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
antigravityPrimaryModelsCache.mu.RUnlock()
return cloned
}
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
models := loadAntigravityPrimaryModels()
if len(models) > 0 {
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
}
return models
}
// AntigravityExecutor proxies requests to the antigravity upstream. // AntigravityExecutor proxies requests to the antigravity upstream.
type AntigravityExecutor struct { type AntigravityExecutor struct {
cfg *config.Config cfg *config.Config
@@ -380,7 +308,7 @@ attemptLoop:
reporter.publish(ctx, parseAntigravityUsage(bodyBytes)) reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
var param any var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, &param) converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx) reporter.ensurePublished(ctx)
return resp, nil return resp, nil
} }
@@ -584,7 +512,7 @@ attemptLoop:
reporter.publish(ctx, parseAntigravityUsage(resp.Payload)) reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
var param any var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, &param) converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx) reporter.ensurePublished(ctx)
return resp, nil return resp, nil
@@ -763,31 +691,42 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
} }
partsJSON, _ := json.Marshal(parts) partsJSON, _ := json.Marshal(parts)
responseTemplate, _ = sjson.SetRaw(responseTemplate, "candidates.0.content.parts", string(partsJSON)) updatedTemplate, _ := sjson.SetRawBytes([]byte(responseTemplate), "candidates.0.content.parts", partsJSON)
responseTemplate = string(updatedTemplate)
if role != "" { if role != "" {
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.content.role", role) updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "candidates.0.content.role", role)
responseTemplate = string(updatedTemplate)
} }
if finishReason != "" { if finishReason != "" {
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.finishReason", finishReason) updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "candidates.0.finishReason", finishReason)
responseTemplate = string(updatedTemplate)
} }
if modelVersion != "" { if modelVersion != "" {
responseTemplate, _ = sjson.Set(responseTemplate, "modelVersion", modelVersion) updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "modelVersion", modelVersion)
responseTemplate = string(updatedTemplate)
} }
if responseID != "" { if responseID != "" {
responseTemplate, _ = sjson.Set(responseTemplate, "responseId", responseID) updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "responseId", responseID)
responseTemplate = string(updatedTemplate)
} }
if usageRaw != "" { if usageRaw != "" {
responseTemplate, _ = sjson.SetRaw(responseTemplate, "usageMetadata", usageRaw) updatedTemplate, _ = sjson.SetRawBytes([]byte(responseTemplate), "usageMetadata", []byte(usageRaw))
responseTemplate = string(updatedTemplate)
} else if !gjson.Get(responseTemplate, "usageMetadata").Exists() { } else if !gjson.Get(responseTemplate, "usageMetadata").Exists() {
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.promptTokenCount", 0) updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.promptTokenCount", 0)
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.candidatesTokenCount", 0) responseTemplate = string(updatedTemplate)
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.totalTokenCount", 0) updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.candidatesTokenCount", 0)
responseTemplate = string(updatedTemplate)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.totalTokenCount", 0)
responseTemplate = string(updatedTemplate)
} }
output := `{"response":{},"traceId":""}` output := `{"response":{},"traceId":""}`
output, _ = sjson.SetRaw(output, "response", responseTemplate) updatedOutput, _ := sjson.SetRawBytes([]byte(output), "response", []byte(responseTemplate))
output = string(updatedOutput)
if traceID != "" { if traceID != "" {
output, _ = sjson.Set(output, "traceId", traceID) updatedOutput, _ = sjson.SetBytes([]byte(output), "traceId", traceID)
output = string(updatedOutput)
} }
return []byte(output) return []byte(output)
} }
@@ -952,12 +891,12 @@ attemptLoop:
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), &param) chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), &param)
for i := range chunks { for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
} }
} }
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), &param) tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), &param)
for i := range tail { for i := range tail {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])} out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) recordAPIResponseError(ctx, e.cfg, errScan)
@@ -1115,7 +1054,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices { if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
count := gjson.GetBytes(bodyBytes, "totalTokens").Int() count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes) translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil return cliproxyexecutor.Response{Payload: translated, Headers: httpResp.Header.Clone()}, nil
} }
lastStatus = httpResp.StatusCode lastStatus = httpResp.StatusCode
@@ -1150,168 +1089,6 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
} }
} }
// FetchAntigravityModels retrieves available models using the supplied auth.
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
exec := &AntigravityExecutor{cfg: cfg}
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
if errToken != nil || token == "" {
return fallbackAntigravityPrimaryModels()
}
if updatedAuth != nil {
auth = updatedAuth
}
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, cfg, auth, 0)
for idx, baseURL := range baseURLs {
modelsURL := baseURL + antigravityModelsPath
var payload []byte
if auth != nil && auth.Metadata != nil {
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
}
}
if len(payload) == 0 {
payload = []byte(`{}`)
}
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload))
if errReq != nil {
return fallbackAntigravityPrimaryModels()
}
httpReq.Close = true
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+token)
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
if host := resolveHost(baseURL); host != "" {
httpReq.Host = host
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return fallbackAntigravityPrimaryModels()
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
now := time.Now().Unix()
modelConfig := registry.GetAntigravityModelConfig()
models := make([]*registry.ModelInfo, 0, len(result.Map()))
for originalName, modelData := range result.Map() {
modelID := strings.TrimSpace(originalName)
if modelID == "" {
continue
}
switch modelID {
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
continue
}
modelCfg := modelConfig[modelID]
// Extract displayName from upstream response, fallback to modelID
displayName := modelData.Get("displayName").String()
if displayName == "" {
displayName = modelID
}
modelInfo := &registry.ModelInfo{
ID: modelID,
Name: modelID,
Description: displayName,
DisplayName: displayName,
Version: modelID,
Object: "model",
Created: now,
OwnedBy: antigravityAuthType,
Type: antigravityAuthType,
}
// Build input modalities from upstream capability flags.
inputModalities := []string{"TEXT"}
if modelData.Get("supportsImages").Bool() {
inputModalities = append(inputModalities, "IMAGE")
}
if modelData.Get("supportsVideo").Bool() {
inputModalities = append(inputModalities, "VIDEO")
}
modelInfo.SupportedInputModalities = inputModalities
modelInfo.SupportedOutputModalities = []string{"TEXT"}
// Token limits from upstream.
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
modelInfo.InputTokenLimit = int(maxTok)
}
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
modelInfo.OutputTokenLimit = int(maxOut)
}
// Supported generation methods (Gemini v1beta convention).
modelInfo.SupportedGenerationMethods = []string{"generateContent", "countTokens"}
// Look up Thinking support from static config using upstream model name.
if modelCfg != nil {
if modelCfg.Thinking != nil {
modelInfo.Thinking = modelCfg.Thinking
}
if modelCfg.MaxCompletionTokens > 0 {
modelInfo.MaxCompletionTokens = modelCfg.MaxCompletionTokens
}
}
models = append(models, modelInfo)
}
if len(models) == 0 {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
return fallbackAntigravityPrimaryModels()
}
storeAntigravityPrimaryModels(models)
return models
}
return fallbackAntigravityPrimaryModels()
}
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) { func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
if auth == nil { if auth == nil {
return "", nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"} return "", nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
@@ -1499,19 +1276,20 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
// if useAntigravitySchema { // if useAntigravitySchema {
// systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts") // systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user") // payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.role", "user")
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction) // payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.parts.0.text", systemInstruction)
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction)) // payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
// if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() { // if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
// for _, partResult := range systemInstructionPartsResult.Array() { // for _, partResult := range systemInstructionPartsResult.Array() {
// payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw) // payloadStr, _ = sjson.SetRawBytes([]byte(payloadStr), "request.systemInstruction.parts.-1", []byte(partResult.Raw))
// } // }
// } // }
// } // }
if strings.Contains(modelName, "claude") { if strings.Contains(modelName, "claude") {
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED") updated, _ := sjson.SetBytes([]byte(payloadStr), "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
payloadStr = string(updated)
} else { } else {
payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens") payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens")
} }
@@ -1733,8 +1511,9 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
} }
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte { func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
template, _ := sjson.Set(string(payload), "model", modelName) template := payload
template, _ = sjson.Set(template, "userAgent", "antigravity") template, _ = sjson.SetBytes(template, "model", modelName)
template, _ = sjson.SetBytes(template, "userAgent", "antigravity")
isImageModel := strings.Contains(modelName, "image") isImageModel := strings.Contains(modelName, "image")
@@ -1744,28 +1523,28 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
} else { } else {
reqType = "agent" reqType = "agent"
} }
template, _ = sjson.Set(template, "requestType", reqType) template, _ = sjson.SetBytes(template, "requestType", reqType)
// Use real project ID from auth if available, otherwise generate random (legacy fallback) // Use real project ID from auth if available, otherwise generate random (legacy fallback)
if projectID != "" { if projectID != "" {
template, _ = sjson.Set(template, "project", projectID) template, _ = sjson.SetBytes(template, "project", projectID)
} else { } else {
template, _ = sjson.Set(template, "project", generateProjectID()) template, _ = sjson.SetBytes(template, "project", generateProjectID())
} }
if isImageModel { if isImageModel {
template, _ = sjson.Set(template, "requestId", generateImageGenRequestID()) template, _ = sjson.SetBytes(template, "requestId", generateImageGenRequestID())
} else { } else {
template, _ = sjson.Set(template, "requestId", generateRequestID()) template, _ = sjson.SetBytes(template, "requestId", generateRequestID())
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload)) template, _ = sjson.SetBytes(template, "request.sessionId", generateStableSessionID(payload))
} }
template, _ = sjson.Delete(template, "request.safetySettings") template, _ = sjson.DeleteBytes(template, "request.safetySettings")
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() { if toolConfig := gjson.GetBytes(template, "toolConfig"); toolConfig.Exists() && !gjson.GetBytes(template, "request.toolConfig").Exists() {
template, _ = sjson.SetRaw(template, "request.toolConfig", toolConfig.Raw) template, _ = sjson.SetRawBytes(template, "request.toolConfig", []byte(toolConfig.Raw))
template, _ = sjson.Delete(template, "toolConfig") template, _ = sjson.DeleteBytes(template, "toolConfig")
} }
return []byte(template) return template
} }
func generateRequestID() string { func generateRequestID() string {

View File

@@ -1,90 +0,0 @@
package executor
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
func resetAntigravityPrimaryModelsCacheForTest() {
antigravityPrimaryModelsCache.mu.Lock()
antigravityPrimaryModelsCache.models = nil
antigravityPrimaryModelsCache.mu.Unlock()
}
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
resetAntigravityPrimaryModelsCacheForTest()
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
seed := []*registry.ModelInfo{
{ID: "claude-sonnet-4-5"},
{ID: "gemini-2.5-pro"},
}
if updated := storeAntigravityPrimaryModels(seed); !updated {
t.Fatal("expected non-empty model list to update primary cache")
}
if updated := storeAntigravityPrimaryModels(nil); updated {
t.Fatal("expected nil model list not to overwrite primary cache")
}
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
t.Fatal("expected empty model list not to overwrite primary cache")
}
got := loadAntigravityPrimaryModels()
if len(got) != 2 {
t.Fatalf("expected cached model count 2, got %d", len(got))
}
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
}
}
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
resetAntigravityPrimaryModelsCacheForTest()
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
ID: "gpt-5",
DisplayName: "GPT-5",
SupportedGenerationMethods: []string{"generateContent"},
SupportedParameters: []string{"temperature"},
Thinking: &registry.ThinkingSupport{
Levels: []string{"high"},
},
}}); !updated {
t.Fatal("expected model cache update")
}
got := loadAntigravityPrimaryModels()
if len(got) != 1 {
t.Fatalf("expected one cached model, got %d", len(got))
}
got[0].ID = "mutated-id"
if len(got[0].SupportedGenerationMethods) > 0 {
got[0].SupportedGenerationMethods[0] = "mutated-method"
}
if len(got[0].SupportedParameters) > 0 {
got[0].SupportedParameters[0] = "mutated-parameter"
}
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
got[0].Thinking.Levels[0] = "mutated-level"
}
again := loadAntigravityPrimaryModels()
if len(again) != 1 {
t.Fatalf("expected one cached model after mutation, got %d", len(again))
}
if again[0].ID != "gpt-5" {
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
}
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
}
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
}
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
}
}

View File

@@ -0,0 +1,383 @@
package executor
import (
"crypto/sha256"
"encoding/hex"
"net/http"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
const (
defaultClaudeFingerprintUserAgent = "claude-cli/2.1.63 (external, cli)"
defaultClaudeFingerprintPackageVersion = "0.74.0"
defaultClaudeFingerprintRuntimeVersion = "v24.3.0"
defaultClaudeFingerprintOS = "MacOS"
defaultClaudeFingerprintArch = "arm64"
claudeDeviceProfileTTL = 7 * 24 * time.Hour
claudeDeviceProfileCleanupPeriod = time.Hour
)
var (
claudeCLIVersionPattern = regexp.MustCompile(`^claude-cli/(\d+)\.(\d+)\.(\d+)`)
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
claudeDeviceProfileCacheMu sync.RWMutex
claudeDeviceProfileCacheCleanupOnce sync.Once
claudeDeviceProfileBeforeCandidateStore func(claudeDeviceProfile)
)
type claudeCLIVersion struct {
major int
minor int
patch int
}
func (v claudeCLIVersion) Compare(other claudeCLIVersion) int {
switch {
case v.major != other.major:
if v.major > other.major {
return 1
}
return -1
case v.minor != other.minor:
if v.minor > other.minor {
return 1
}
return -1
case v.patch != other.patch:
if v.patch > other.patch {
return 1
}
return -1
default:
return 0
}
}
type claudeDeviceProfile struct {
UserAgent string
PackageVersion string
RuntimeVersion string
OS string
Arch string
Version claudeCLIVersion
HasVersion bool
}
type claudeDeviceProfileCacheEntry struct {
profile claudeDeviceProfile
expire time.Time
}
func claudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
if cfg == nil || cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
return false
}
return *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile
}
func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
hdrDefault := func(cfgVal, fallback string) string {
if strings.TrimSpace(cfgVal) != "" {
return strings.TrimSpace(cfgVal)
}
return fallback
}
var hd config.ClaudeHeaderDefaults
if cfg != nil {
hd = cfg.ClaudeHeaderDefaults
}
profile := claudeDeviceProfile{
UserAgent: hdrDefault(hd.UserAgent, defaultClaudeFingerprintUserAgent),
PackageVersion: hdrDefault(hd.PackageVersion, defaultClaudeFingerprintPackageVersion),
RuntimeVersion: hdrDefault(hd.RuntimeVersion, defaultClaudeFingerprintRuntimeVersion),
OS: hdrDefault(hd.OS, defaultClaudeFingerprintOS),
Arch: hdrDefault(hd.Arch, defaultClaudeFingerprintArch),
}
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
profile.Version = version
profile.HasVersion = true
}
return profile
}
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
func mapStainlessOS() string {
switch runtime.GOOS {
case "darwin":
return "MacOS"
case "windows":
return "Windows"
case "linux":
return "Linux"
case "freebsd":
return "FreeBSD"
default:
return "Other::" + runtime.GOOS
}
}
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
func mapStainlessArch() string {
switch runtime.GOARCH {
case "amd64":
return "x64"
case "arm64":
return "arm64"
case "386":
return "x86"
default:
return "other::" + runtime.GOARCH
}
}
func parseClaudeCLIVersion(userAgent string) (claudeCLIVersion, bool) {
matches := claudeCLIVersionPattern.FindStringSubmatch(strings.TrimSpace(userAgent))
if len(matches) != 4 {
return claudeCLIVersion{}, false
}
major, err := strconv.Atoi(matches[1])
if err != nil {
return claudeCLIVersion{}, false
}
minor, err := strconv.Atoi(matches[2])
if err != nil {
return claudeCLIVersion{}, false
}
patch, err := strconv.Atoi(matches[3])
if err != nil {
return claudeCLIVersion{}, false
}
return claudeCLIVersion{major: major, minor: minor, patch: patch}, true
}
func shouldUpgradeClaudeDeviceProfile(candidate, current claudeDeviceProfile) bool {
if candidate.UserAgent == "" || !candidate.HasVersion {
return false
}
if current.UserAgent == "" || !current.HasVersion {
return true
}
return candidate.Version.Compare(current.Version) > 0
}
func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
profile.OS = baseline.OS
profile.Arch = baseline.Arch
return profile
}
// normalizeClaudeDeviceProfile keeps stabilized profiles pinned to the current
// baseline platform and enforces the baseline software fingerprint as a floor.
func normalizeClaudeDeviceProfile(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
profile = pinClaudeDeviceProfilePlatform(profile, baseline)
if profile.UserAgent == "" || !profile.HasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
profile.UserAgent = baseline.UserAgent
profile.PackageVersion = baseline.PackageVersion
profile.RuntimeVersion = baseline.RuntimeVersion
profile.Version = baseline.Version
profile.HasVersion = baseline.HasVersion
}
return profile
}
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (claudeDeviceProfile, bool) {
if headers == nil {
return claudeDeviceProfile{}, false
}
userAgent := strings.TrimSpace(headers.Get("User-Agent"))
version, ok := parseClaudeCLIVersion(userAgent)
if !ok {
return claudeDeviceProfile{}, false
}
baseline := defaultClaudeDeviceProfile(cfg)
profile := claudeDeviceProfile{
UserAgent: userAgent,
PackageVersion: firstNonEmptyHeader(headers, "X-Stainless-Package-Version", baseline.PackageVersion),
RuntimeVersion: firstNonEmptyHeader(headers, "X-Stainless-Runtime-Version", baseline.RuntimeVersion),
OS: firstNonEmptyHeader(headers, "X-Stainless-Os", baseline.OS),
Arch: firstNonEmptyHeader(headers, "X-Stainless-Arch", baseline.Arch),
Version: version,
HasVersion: true,
}
return profile, true
}
func firstNonEmptyHeader(headers http.Header, name, fallback string) string {
if headers == nil {
return fallback
}
if value := strings.TrimSpace(headers.Get(name)); value != "" {
return value
}
return fallback
}
func claudeDeviceProfileScopeKey(auth *cliproxyauth.Auth, apiKey string) string {
switch {
case auth != nil && strings.TrimSpace(auth.ID) != "":
return "auth:" + strings.TrimSpace(auth.ID)
case strings.TrimSpace(apiKey) != "":
return "api_key:" + strings.TrimSpace(apiKey)
default:
return "global"
}
}
func claudeDeviceProfileCacheKey(auth *cliproxyauth.Auth, apiKey string) string {
sum := sha256.Sum256([]byte(claudeDeviceProfileScopeKey(auth, apiKey)))
return hex.EncodeToString(sum[:])
}
func startClaudeDeviceProfileCacheCleanup() {
go func() {
ticker := time.NewTicker(claudeDeviceProfileCleanupPeriod)
defer ticker.Stop()
for range ticker.C {
purgeExpiredClaudeDeviceProfiles()
}
}()
}
func purgeExpiredClaudeDeviceProfiles() {
now := time.Now()
claudeDeviceProfileCacheMu.Lock()
for key, entry := range claudeDeviceProfileCache {
if !entry.expire.After(now) {
delete(claudeDeviceProfileCache, key)
}
}
claudeDeviceProfileCacheMu.Unlock()
}
func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) claudeDeviceProfile {
claudeDeviceProfileCacheCleanupOnce.Do(startClaudeDeviceProfileCacheCleanup)
cacheKey := claudeDeviceProfileCacheKey(auth, apiKey)
now := time.Now()
baseline := defaultClaudeDeviceProfile(cfg)
candidate, hasCandidate := extractClaudeDeviceProfile(headers, cfg)
if hasCandidate {
candidate = pinClaudeDeviceProfilePlatform(candidate, baseline)
}
if hasCandidate && !shouldUpgradeClaudeDeviceProfile(candidate, baseline) {
hasCandidate = false
}
claudeDeviceProfileCacheMu.RLock()
entry, hasCached := claudeDeviceProfileCache[cacheKey]
cachedValid := hasCached && entry.expire.After(now) && entry.profile.UserAgent != ""
claudeDeviceProfileCacheMu.RUnlock()
if hasCandidate {
if claudeDeviceProfileBeforeCandidateStore != nil {
claudeDeviceProfileBeforeCandidateStore(candidate)
}
claudeDeviceProfileCacheMu.Lock()
entry, hasCached = claudeDeviceProfileCache[cacheKey]
cachedValid = hasCached && entry.expire.After(now) && entry.profile.UserAgent != ""
if cachedValid {
entry.profile = normalizeClaudeDeviceProfile(entry.profile, baseline)
}
if cachedValid && !shouldUpgradeClaudeDeviceProfile(candidate, entry.profile) {
entry.expire = now.Add(claudeDeviceProfileTTL)
claudeDeviceProfileCache[cacheKey] = entry
claudeDeviceProfileCacheMu.Unlock()
return entry.profile
}
claudeDeviceProfileCache[cacheKey] = claudeDeviceProfileCacheEntry{
profile: candidate,
expire: now.Add(claudeDeviceProfileTTL),
}
claudeDeviceProfileCacheMu.Unlock()
return candidate
}
if cachedValid {
claudeDeviceProfileCacheMu.Lock()
entry = claudeDeviceProfileCache[cacheKey]
if entry.expire.After(now) && entry.profile.UserAgent != "" {
entry.profile = normalizeClaudeDeviceProfile(entry.profile, baseline)
entry.expire = now.Add(claudeDeviceProfileTTL)
claudeDeviceProfileCache[cacheKey] = entry
claudeDeviceProfileCacheMu.Unlock()
return entry.profile
}
claudeDeviceProfileCacheMu.Unlock()
}
return baseline
}
func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfile) {
if r == nil {
return
}
for _, headerName := range []string{
"User-Agent",
"X-Stainless-Package-Version",
"X-Stainless-Runtime-Version",
"X-Stainless-Os",
"X-Stainless-Arch",
} {
r.Header.Del(headerName)
}
r.Header.Set("User-Agent", profile.UserAgent)
r.Header.Set("X-Stainless-Package-Version", profile.PackageVersion)
r.Header.Set("X-Stainless-Runtime-Version", profile.RuntimeVersion)
r.Header.Set("X-Stainless-Os", profile.OS)
r.Header.Set("X-Stainless-Arch", profile.Arch)
}
func applyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
if r == nil {
return
}
profile := defaultClaudeDeviceProfile(cfg)
miscEnsure := func(name, fallback string) {
if strings.TrimSpace(r.Header.Get(name)) != "" {
return
}
if strings.TrimSpace(ginHeaders.Get(name)) != "" {
r.Header.Set(name, strings.TrimSpace(ginHeaders.Get(name)))
return
}
r.Header.Set(name, fallback)
}
miscEnsure("X-Stainless-Runtime-Version", profile.RuntimeVersion)
miscEnsure("X-Stainless-Package-Version", profile.PackageVersion)
miscEnsure("X-Stainless-Os", mapStainlessOS())
miscEnsure("X-Stainless-Arch", mapStainlessArch())
// Legacy mode preserves per-auth custom header overrides. By the time we get
// here, ApplyCustomHeadersFromAttrs has already populated r.Header.
if strings.TrimSpace(r.Header.Get("User-Agent")) != "" {
return
}
clientUA := ""
if ginHeaders != nil {
clientUA = strings.TrimSpace(ginHeaders.Get("User-Agent"))
}
if isClaudeCodeClient(clientUA) {
r.Header.Set("User-Agent", clientUA)
return
}
r.Header.Set("User-Agent", profile.UserAgent)
}

View File

@@ -14,7 +14,6 @@ import (
"io" "io"
"net/http" "net/http"
"net/textproto" "net/textproto"
"runtime"
"strings" "strings"
"time" "time"
@@ -255,7 +254,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
data, data,
&param, &param,
) )
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil return resp, nil
} }
@@ -443,7 +442,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
&param, &param,
) )
for i := range chunks { for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
} }
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
@@ -561,7 +560,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
appendAPIResponseChunk(ctx, e.cfg, data) appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "input_tokens").Int() count := gjson.GetBytes(data, "input_tokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil return cliproxyexecutor.Response{Payload: out, Headers: resp.Header.Clone()}, nil
} }
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
@@ -767,36 +766,6 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
return body, nil return body, nil
} }
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
func mapStainlessOS() string {
switch runtime.GOOS {
case "darwin":
return "MacOS"
case "windows":
return "Windows"
case "linux":
return "Linux"
case "freebsd":
return "FreeBSD"
default:
return "Other::" + runtime.GOOS
}
}
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
func mapStainlessArch() string {
switch runtime.GOARCH {
case "amd64":
return "x64"
case "arm64":
return "arm64"
case "386":
return "x86"
default:
return "other::" + runtime.GOARCH
}
}
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) { func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
hdrDefault := func(cfgVal, fallback string) string { hdrDefault := func(cfgVal, fallback string) string {
if cfgVal != "" { if cfgVal != "" {
@@ -824,6 +793,11 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil { if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
ginHeaders = ginCtx.Request.Header ginHeaders = ginCtx.Request.Header
} }
stabilizeDeviceProfile := claudeDeviceProfileStabilizationEnabled(cfg)
var deviceProfile claudeDeviceProfile
if stabilizeDeviceProfile {
deviceProfile = resolveClaudeDeviceProfile(auth, apiKey, ginHeaders, cfg)
}
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05" baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" { if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
@@ -867,25 +841,9 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli") misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28). // Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0") misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node") misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js") misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600")) misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
// For User-Agent, only forward the client's header if it's already a Claude Code client.
// Non-Claude-Code clients (e.g. curl, OpenAI SDKs) get the default Claude Code User-Agent
// to avoid leaking the real client identity during cloaking.
clientUA := ""
if ginHeaders != nil {
clientUA = ginHeaders.Get("User-Agent")
}
if isClaudeCodeClient(clientUA) {
r.Header.Set("User-Agent", clientUA)
} else {
r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)"))
}
r.Header.Set("Connection", "keep-alive") r.Header.Set("Connection", "keep-alive")
if stream { if stream {
r.Header.Set("Accept", "text/event-stream") r.Header.Set("Accept", "text/event-stream")
@@ -897,13 +855,19 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd") r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
} }
// Keep OS/Arch mapping dynamic (not configurable). // Legacy mode keeps OS/Arch runtime-derived; stabilized mode pins OS/Arch
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH. // to the configured baseline while still allowing newer official
// User-Agent/package/runtime tuples to upgrade the software fingerprint.
var attrs map[string]string var attrs map[string]string
if auth != nil { if auth != nil {
attrs = auth.Attributes attrs = auth.Attributes
} }
util.ApplyCustomHeadersFromAttrs(r, attrs) util.ApplyCustomHeadersFromAttrs(r, attrs)
if stabilizeDeviceProfile {
applyClaudeDeviceProfileHeaders(r, deviceProfile)
} else {
applyClaudeLegacyDeviceHeaders(r, ginHeaders, cfg)
}
// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which // Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
// may override it with a user-configured value. Compressed SSE breaks the line // may override it with a user-configured value. Compressed SSE breaks the line
// scanner regardless of user preference, so this is non-negotiable for streams. // scanner regardless of user preference, so this is non-negotiable for streams.
@@ -1260,7 +1224,8 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
// TTL ordering violations with the prompt-caching-scope-2026-01-05 beta. // TTL ordering violations with the prompt-caching-scope-2026-01-05 beta.
partJSON := part.Raw partJSON := part.Raw
if !part.Get("cache_control").Exists() { if !part.Get("cache_control").Exists() {
partJSON, _ = sjson.Set(partJSON, "cache_control.type", "ephemeral") updated, _ := sjson.SetBytes([]byte(partJSON), "cache_control.type", "ephemeral")
partJSON = string(updated)
} }
result += "," + partJSON result += "," + partJSON
} }
@@ -1268,7 +1233,8 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
}) })
} else if system.Type == gjson.String && system.String() != "" { } else if system.Type == gjson.String && system.String() != "" {
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}` partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
partJSON, _ = sjson.Set(partJSON, "text", system.String()) updated, _ := sjson.SetBytes([]byte(partJSON), "text", system.String())
partJSON = string(updated)
result += "," + partJSON result += "," + partJSON
} }
result += "]" result += "]"

View File

@@ -8,8 +8,11 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"sync"
"testing" "testing"
"time"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -19,6 +22,587 @@ import (
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
func resetClaudeDeviceProfileCache() {
claudeDeviceProfileCacheMu.Lock()
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
claudeDeviceProfileCacheMu.Unlock()
}
func newClaudeHeaderTestRequest(t *testing.T, incoming http.Header) *http.Request {
t.Helper()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
ginReq := httptest.NewRequest(http.MethodPost, "http://localhost/v1/messages", nil)
ginReq.Header = incoming.Clone()
ginCtx.Request = ginReq
req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil)
return req.WithContext(context.WithValue(req.Context(), "gin", ginCtx))
}
func assertClaudeFingerprint(t *testing.T, headers http.Header, userAgent, pkgVersion, runtimeVersion, osName, arch string) {
t.Helper()
if got := headers.Get("User-Agent"); got != userAgent {
t.Fatalf("User-Agent = %q, want %q", got, userAgent)
}
if got := headers.Get("X-Stainless-Package-Version"); got != pkgVersion {
t.Fatalf("X-Stainless-Package-Version = %q, want %q", got, pkgVersion)
}
if got := headers.Get("X-Stainless-Runtime-Version"); got != runtimeVersion {
t.Fatalf("X-Stainless-Runtime-Version = %q, want %q", got, runtimeVersion)
}
if got := headers.Get("X-Stainless-Os"); got != osName {
t.Fatalf("X-Stainless-Os = %q, want %q", got, osName)
}
if got := headers.Get("X-Stainless-Arch"); got != arch {
t.Fatalf("X-Stainless-Arch = %q, want %q", got, arch)
}
}
func TestApplyClaudeHeaders_UsesConfiguredBaselineFingerprint(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := true
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.70 (external, cli)",
PackageVersion: "0.80.0",
RuntimeVersion: "v24.5.0",
OS: "MacOS",
Arch: "arm64",
Timeout: "900",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-baseline",
Attributes: map[string]string{
"api_key": "key-baseline",
"header:User-Agent": "evil-client/9.9",
"header:X-Stainless-Os": "Linux",
"header:X-Stainless-Arch": "x64",
"header:X-Stainless-Package-Version": "9.9.9",
},
}
incoming := http.Header{
"User-Agent": []string{"curl/8.7.1"},
"X-Stainless-Package-Version": []string{"0.10.0"},
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
}
req := newClaudeHeaderTestRequest(t, incoming)
applyClaudeHeaders(req, auth, "key-baseline", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
if got := req.Header.Get("X-Stainless-Timeout"); got != "900" {
t.Fatalf("X-Stainless-Timeout = %q, want %q", got, "900")
}
}
func TestApplyClaudeHeaders_TracksHighestClaudeCLIFingerprint(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := true
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.60 (external, cli)",
PackageVersion: "0.70.0",
RuntimeVersion: "v22.0.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-upgrade",
Attributes: map[string]string{
"api_key": "key-upgrade",
},
}
firstReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.74.0"},
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(firstReq, auth, "key-upgrade", false, nil, cfg)
assertClaudeFingerprint(t, firstReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "MacOS", "arm64")
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"lobe-chat/1.0"},
"X-Stainless-Package-Version": []string{"0.10.0"},
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
"X-Stainless-Os": []string{"Windows"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(thirdPartyReq, auth, "key-upgrade", false, nil, cfg)
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "MacOS", "arm64")
higherReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.75.0"},
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
"X-Stainless-Os": []string{"MacOS"},
"X-Stainless-Arch": []string{"arm64"},
})
applyClaudeHeaders(higherReq, auth, "key-upgrade", false, nil, cfg)
assertClaudeFingerprint(t, higherReq.Header, "claude-cli/2.1.63 (external, cli)", "0.75.0", "v24.4.0", "MacOS", "arm64")
lowerReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.61 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.73.0"},
"X-Stainless-Runtime-Version": []string{"v24.2.0"},
"X-Stainless-Os": []string{"Windows"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(lowerReq, auth, "key-upgrade", false, nil, cfg)
assertClaudeFingerprint(t, lowerReq.Header, "claude-cli/2.1.63 (external, cli)", "0.75.0", "v24.4.0", "MacOS", "arm64")
}
func TestApplyClaudeHeaders_DoesNotDowngradeConfiguredBaselineOnFirstClaudeClient(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := true
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.70 (external, cli)",
PackageVersion: "0.80.0",
RuntimeVersion: "v24.5.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-baseline-floor",
Attributes: map[string]string{
"api_key": "key-baseline-floor",
},
}
olderClaudeReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.74.0"},
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(olderClaudeReq, auth, "key-baseline-floor", false, nil, cfg)
assertClaudeFingerprint(t, olderClaudeReq.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
newerClaudeReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.71 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.81.0"},
"X-Stainless-Runtime-Version": []string{"v24.6.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(newerClaudeReq, auth, "key-baseline-floor", false, nil, cfg)
assertClaudeFingerprint(t, newerClaudeReq.Header, "claude-cli/2.1.71 (external, cli)", "0.81.0", "v24.6.0", "MacOS", "arm64")
}
func TestApplyClaudeHeaders_UpgradesCachedSoftwareFingerprintWhenBaselineAdvances(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := true
oldCfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.70 (external, cli)",
PackageVersion: "0.80.0",
RuntimeVersion: "v24.5.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
newCfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.77 (external, cli)",
PackageVersion: "0.87.0",
RuntimeVersion: "v24.8.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-baseline-reload",
Attributes: map[string]string{
"api_key": "key-baseline-reload",
},
}
officialReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.71 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.81.0"},
"X-Stainless-Runtime-Version": []string{"v24.6.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(officialReq, auth, "key-baseline-reload", false, nil, oldCfg)
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.71 (external, cli)", "0.81.0", "v24.6.0", "MacOS", "arm64")
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"curl/8.7.1"},
"X-Stainless-Package-Version": []string{"0.10.0"},
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(thirdPartyReq, auth, "key-baseline-reload", false, nil, newCfg)
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
}
func TestApplyClaudeHeaders_LearnsOfficialFingerprintAfterCustomBaselineFallback(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := true
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "my-gateway/1.0",
PackageVersion: "custom-pkg",
RuntimeVersion: "custom-runtime",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-custom-baseline-learning",
Attributes: map[string]string{
"api_key": "key-custom-baseline-learning",
},
}
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"curl/8.7.1"},
"X-Stainless-Package-Version": []string{"0.10.0"},
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(thirdPartyReq, auth, "key-custom-baseline-learning", false, nil, cfg)
assertClaudeFingerprint(t, thirdPartyReq.Header, "my-gateway/1.0", "custom-pkg", "custom-runtime", "MacOS", "arm64")
officialReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.77 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.87.0"},
"X-Stainless-Runtime-Version": []string{"v24.8.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(officialReq, auth, "key-custom-baseline-learning", false, nil, cfg)
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
postLearningThirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"curl/8.7.1"},
"X-Stainless-Package-Version": []string{"0.10.0"},
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(postLearningThirdPartyReq, auth, "key-custom-baseline-learning", false, nil, cfg)
assertClaudeFingerprint(t, postLearningThirdPartyReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
}
func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := true
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.60 (external, cli)",
PackageVersion: "0.70.0",
RuntimeVersion: "v22.0.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-racy-upgrade",
Attributes: map[string]string{
"api_key": "key-racy-upgrade",
},
}
lowPaused := make(chan struct{})
releaseLow := make(chan struct{})
var pauseOnce sync.Once
var releaseOnce sync.Once
claudeDeviceProfileBeforeCandidateStore = func(candidate claudeDeviceProfile) {
if candidate.UserAgent != "claude-cli/2.1.62 (external, cli)" {
return
}
pauseOnce.Do(func() { close(lowPaused) })
<-releaseLow
}
t.Cleanup(func() {
claudeDeviceProfileBeforeCandidateStore = nil
releaseOnce.Do(func() { close(releaseLow) })
})
lowResultCh := make(chan claudeDeviceProfile, 1)
go func() {
lowResultCh <- resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.74.0"},
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
}, cfg)
}()
select {
case <-lowPaused:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for lower candidate to pause before storing")
}
highResult := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.75.0"},
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
"X-Stainless-Os": []string{"MacOS"},
"X-Stainless-Arch": []string{"arm64"},
}, cfg)
releaseOnce.Do(func() { close(releaseLow) })
select {
case lowResult := <-lowResultCh:
if lowResult.UserAgent != "claude-cli/2.1.63 (external, cli)" {
t.Fatalf("lowResult.UserAgent = %q, want %q", lowResult.UserAgent, "claude-cli/2.1.63 (external, cli)")
}
if lowResult.PackageVersion != "0.75.0" {
t.Fatalf("lowResult.PackageVersion = %q, want %q", lowResult.PackageVersion, "0.75.0")
}
if lowResult.OS != "MacOS" || lowResult.Arch != "arm64" {
t.Fatalf("lowResult platform = %s/%s, want %s/%s", lowResult.OS, lowResult.Arch, "MacOS", "arm64")
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for lower candidate result")
}
if highResult.UserAgent != "claude-cli/2.1.63 (external, cli)" {
t.Fatalf("highResult.UserAgent = %q, want %q", highResult.UserAgent, "claude-cli/2.1.63 (external, cli)")
}
if highResult.OS != "MacOS" || highResult.Arch != "arm64" {
t.Fatalf("highResult platform = %s/%s, want %s/%s", highResult.OS, highResult.Arch, "MacOS", "arm64")
}
cached := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"curl/8.7.1"},
}, cfg)
if cached.UserAgent != "claude-cli/2.1.63 (external, cli)" {
t.Fatalf("cached.UserAgent = %q, want %q", cached.UserAgent, "claude-cli/2.1.63 (external, cli)")
}
if cached.PackageVersion != "0.75.0" {
t.Fatalf("cached.PackageVersion = %q, want %q", cached.PackageVersion, "0.75.0")
}
if cached.OS != "MacOS" || cached.Arch != "arm64" {
t.Fatalf("cached platform = %s/%s, want %s/%s", cached.OS, cached.Arch, "MacOS", "arm64")
}
}
func TestApplyClaudeHeaders_ThirdPartyBaselineThenOfficialUpgradeKeepsPinnedPlatform(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := true
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.70 (external, cli)",
PackageVersion: "0.80.0",
RuntimeVersion: "v24.5.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-third-party-then-official",
Attributes: map[string]string{
"api_key": "key-third-party-then-official",
},
}
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"curl/8.7.1"},
"X-Stainless-Package-Version": []string{"0.10.0"},
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(thirdPartyReq, auth, "key-third-party-then-official", false, nil, cfg)
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
officialReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.77 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.87.0"},
"X-Stainless-Runtime-Version": []string{"v24.8.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(officialReq, auth, "key-third-party-then-official", false, nil, cfg)
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
}
func TestApplyClaudeHeaders_DisableDeviceProfileStabilization(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := false
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.60 (external, cli)",
PackageVersion: "0.70.0",
RuntimeVersion: "v22.0.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-disable-stability",
Attributes: map[string]string{
"api_key": "key-disable-stability",
},
}
firstReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.74.0"},
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(firstReq, auth, "key-disable-stability", false, nil, cfg)
assertClaudeFingerprint(t, firstReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "Linux", "x64")
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"lobe-chat/1.0"},
"X-Stainless-Package-Version": []string{"0.10.0"},
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
"X-Stainless-Os": []string{"Windows"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(thirdPartyReq, auth, "key-disable-stability", false, nil, cfg)
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.60 (external, cli)", "0.10.0", "v18.0.0", "Windows", "x64")
lowerReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.61 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.73.0"},
"X-Stainless-Runtime-Version": []string{"v24.2.0"},
"X-Stainless-Os": []string{"Windows"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(lowerReq, auth, "key-disable-stability", false, nil, cfg)
assertClaudeFingerprint(t, lowerReq.Header, "claude-cli/2.1.61 (external, cli)", "0.73.0", "v24.2.0", "Windows", "x64")
}
func TestApplyClaudeHeaders_LegacyModePreservesConfiguredUserAgentOverrideForClaudeClients(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := false
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.60 (external, cli)",
PackageVersion: "0.70.0",
RuntimeVersion: "v22.0.0",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-legacy-ua-override",
Attributes: map[string]string{
"api_key": "key-legacy-ua-override",
"header:User-Agent": "config-ua/1.0",
},
}
req := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.74.0"},
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(req, auth, "key-legacy-ua-override", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "config-ua/1.0", "0.74.0", "v24.3.0", "Linux", "x64")
}
func TestApplyClaudeHeaders_LegacyModeFallsBackToRuntimeOSArchWhenMissing(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := false
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.60 (external, cli)",
PackageVersion: "0.70.0",
RuntimeVersion: "v22.0.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-legacy-runtime-os-arch",
Attributes: map[string]string{
"api_key": "key-legacy-runtime-os-arch",
},
}
req := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"curl/8.7.1"},
})
applyClaudeHeaders(req, auth, "key-legacy-runtime-os-arch", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
}
func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallback(t *testing.T) {
resetClaudeDeviceProfileCache()
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.60 (external, cli)",
PackageVersion: "0.70.0",
RuntimeVersion: "v22.0.0",
OS: "MacOS",
Arch: "arm64",
},
}
auth := &cliproxyauth.Auth{
ID: "auth-unset-runtime-os-arch",
Attributes: map[string]string{
"api_key": "key-unset-runtime-os-arch",
},
}
req := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"curl/8.7.1"},
})
applyClaudeHeaders(req, auth, "key-unset-runtime-os-arch", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
}
func TestClaudeDeviceProfileStabilizationEnabled_DefaultFalse(t *testing.T) {
if claudeDeviceProfileStabilizationEnabled(nil) {
t.Fatal("expected nil config to default to disabled stabilization")
}
if claudeDeviceProfileStabilizationEnabled(&config.Config{}) {
t.Fatal("expected unset stabilize-device-profile to default to disabled stabilization")
}
}
func TestApplyClaudeToolPrefix(t *testing.T) { func TestApplyClaudeToolPrefix(t *testing.T) {
input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`) input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_") out := applyClaudeToolPrefix(input, "proxy_")

View File

@@ -0,0 +1,343 @@
package executor
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/http"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
)
const (
codeBuddyChatPath = "/v2/chat/completions"
codeBuddyAuthType = "codebuddy"
)
// CodeBuddyExecutor handles requests to the CodeBuddy API.
type CodeBuddyExecutor struct {
cfg *config.Config
}
// NewCodeBuddyExecutor creates a new CodeBuddy executor instance.
func NewCodeBuddyExecutor(cfg *config.Config) *CodeBuddyExecutor {
return &CodeBuddyExecutor{cfg: cfg}
}
// Identifier returns the unique identifier for this executor.
func (e *CodeBuddyExecutor) Identifier() string { return codeBuddyAuthType }
// codeBuddyCredentials extracts the access token and domain from auth metadata.
func codeBuddyCredentials(auth *cliproxyauth.Auth) (accessToken, userID, domain string) {
if auth == nil {
return "", "", ""
}
accessToken = metaStringValue(auth.Metadata, "access_token")
userID = metaStringValue(auth.Metadata, "user_id")
domain = metaStringValue(auth.Metadata, "domain")
if domain == "" {
domain = codebuddy.DefaultDomain
}
return
}
// PrepareRequest prepares the HTTP request before execution.
func (e *CodeBuddyExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
accessToken, userID, domain := codeBuddyCredentials(auth)
if accessToken == "" {
return fmt.Errorf("codebuddy: missing access token")
}
e.applyHeaders(req, accessToken, userID, domain)
return nil
}
// HttpRequest executes a raw HTTP request.
func (e *CodeBuddyExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("codebuddy executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
// Execute performs a non-streaming request.
func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
accessToken, userID, domain := codeBuddyCredentials(auth)
if accessToken == "" {
return resp, fmt.Errorf("codebuddy: missing access token")
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, false)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
url := codebuddy.BaseURL + codeBuddyChatPath
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
return resp, err
}
e.applyHeaders(httpReq, accessToken, userID, domain)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translated,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codebuddy executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if !isHTTPSuccess(httpResp.StatusCode) {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("codebuddy executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
body, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, body)
reporter.publish(ctx, parseOpenAIUsage(body))
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming request.
func (e *CodeBuddyExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
accessToken, userID, domain := codeBuddyCredentials(auth)
if accessToken == "" {
return nil, fmt.Errorf("codebuddy: missing access token")
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
url := codebuddy.BaseURL + codeBuddyChatPath
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
return nil, err
}
e.applyHeaders(httpReq, accessToken, userID, domain)
httpReq.Header.Set("Accept", "text/event-stream")
httpReq.Header.Set("Cache-Control", "no-cache")
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translated,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if !isHTTPSuccess(httpResp.StatusCode) {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
httpResp.Body.Close()
log.Debugf("codebuddy executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codebuddy executor: close stream body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, maxScannerBufferSize)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
reporter.ensurePublished(ctx)
}()
return &cliproxyexecutor.StreamResult{
Headers: httpResp.Header.Clone(),
Chunks: out,
}, nil
}
// Refresh exchanges the CodeBuddy refresh token for a new access token.
func (e *CodeBuddyExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
if auth == nil {
return nil, fmt.Errorf("codebuddy: missing auth")
}
refreshToken := metaStringValue(auth.Metadata, "refresh_token")
if refreshToken == "" {
log.Debugf("codebuddy executor: no refresh token available, skipping refresh")
return auth, nil
}
accessToken, userID, domain := codeBuddyCredentials(auth)
authSvc := codebuddy.NewCodeBuddyAuth(e.cfg)
storage, err := authSvc.RefreshToken(ctx, accessToken, refreshToken, userID, domain)
if err != nil {
return nil, fmt.Errorf("codebuddy: token refresh failed: %w", err)
}
updated := auth.Clone()
updated.Metadata["access_token"] = storage.AccessToken
if storage.RefreshToken != "" {
updated.Metadata["refresh_token"] = storage.RefreshToken
}
updated.Metadata["expires_in"] = storage.ExpiresIn
updated.Metadata["domain"] = storage.Domain
if storage.UserID != "" {
updated.Metadata["user_id"] = storage.UserID
}
now := time.Now()
updated.UpdatedAt = now
updated.LastRefreshedAt = now
return updated, nil
}
// CountTokens is not supported for CodeBuddy.
func (e *CodeBuddyExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, fmt.Errorf("codebuddy: count tokens not supported")
}
// applyHeaders sets required headers for CodeBuddy API requests.
func (e *CodeBuddyExecutor) applyHeaders(req *http.Request, accessToken, userID, domain string) {
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", codebuddy.UserAgent)
req.Header.Set("X-User-Id", userID)
req.Header.Set("X-Domain", domain)
req.Header.Set("X-Product", "SaaS")
req.Header.Set("X-IDE-Type", "CLI")
req.Header.Set("X-IDE-Name", "CLI")
req.Header.Set("X-IDE-Version", "2.63.2")
req.Header.Set("X-Requested-With", "XMLHttpRequest")
}

View File

@@ -28,8 +28,8 @@ import (
) )
const ( const (
codexClientVersion = "0.101.0" codexUserAgent = "codex_cli_rs/0.116.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
codexUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464" codexOriginator = "codex_cli_rs"
) )
var dataTag = []byte("data:") var dataTag = []byte("data:")
@@ -122,7 +122,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
if err != nil { if err != nil {
return resp, err return resp, err
} }
applyCodexHeaders(httpReq, auth, apiKey, true) applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
authID = auth.ID authID = auth.ID
@@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil return resp, nil
} }
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"} err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
@@ -226,7 +226,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
if err != nil { if err != nil {
return resp, err return resp, err
} }
applyCodexHeaders(httpReq, auth, apiKey, false) applyCodexHeaders(httpReq, auth, apiKey, false, e.cfg)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
authID = auth.ID authID = auth.ID
@@ -273,7 +273,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
reporter.ensurePublished(ctx) reporter.ensurePublished(ctx)
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil return resp, nil
} }
@@ -321,7 +321,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
if err != nil { if err != nil {
return nil, err return nil, err
} }
applyCodexHeaders(httpReq, auth, apiKey, true) applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
authID = auth.ID authID = auth.ID
@@ -387,7 +387,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, bytes.Clone(line), &param) chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, bytes.Clone(line), &param)
for i := range chunks { for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
} }
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
@@ -432,7 +432,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count) usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, []byte(usageJSON)) translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, []byte(usageJSON))
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil return cliproxyexecutor.Response{Payload: translated}, nil
} }
func tokenizerForCodexModel(model string) (tokenizer.Codec, error) { func tokenizerForCodexModel(model string) (tokenizer.Codec, error) {
@@ -636,7 +636,7 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
return httpReq, nil return httpReq, nil
} }
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool) { func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, cfg *config.Config) {
r.Header.Set("Content-Type", "application/json") r.Header.Set("Content-Type", "application/json")
r.Header.Set("Authorization", "Bearer "+token) r.Header.Set("Authorization", "Bearer "+token)
@@ -645,9 +645,12 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
ginHeaders = ginCtx.Request.Header ginHeaders = ginCtx.Request.Header
} }
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion) misc.EnsureHeader(r.Header, ginHeaders, "Version", "")
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString()) misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent) misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "")
misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "")
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
if stream { if stream {
r.Header.Set("Accept", "text/event-stream") r.Header.Set("Accept", "text/event-stream")
@@ -662,8 +665,12 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
isAPIKey = true isAPIKey = true
} }
} }
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
r.Header.Set("Originator", originator)
} else if !isAPIKey {
r.Header.Set("Originator", codexOriginator)
}
if !isAPIKey { if !isAPIKey {
r.Header.Set("Originator", "codex_cli_rs")
if auth != nil && auth.Metadata != nil { if auth != nil && auth.Metadata != nil {
if accountID, ok := auth.Metadata["account_id"].(string); ok { if accountID, ok := auth.Metadata["account_id"].(string); ok {
r.Header.Set("Chatgpt-Account-Id", accountID) r.Header.Set("Chatgpt-Account-Id", accountID)

View File

@@ -23,6 +23,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -190,7 +191,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
} }
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body) body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey) wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
if auth != nil { if auth != nil {
@@ -342,7 +343,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
} }
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)} resp = cliproxyexecutor.Response{Payload: out}
return resp, nil return resp, nil
} }
} }
@@ -385,7 +386,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
} }
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body) body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey) wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
var authID, authLabel, authType, authValue string var authID, authLabel, authType, authValue string
authID = auth.ID authID = auth.ID
@@ -591,7 +592,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
line := encodeCodexWebsocketAsSSE(payload) line := encodeCodexWebsocketAsSSE(payload)
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, body, body, line, &param) chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, body, body, line, &param)
for i := range chunks { for i := range chunks {
if !send(cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}) { if !send(cliproxyexecutor.StreamChunk{Payload: chunks[i]}) {
terminateReason = "context_done" terminateReason = "context_done"
terminateErr = ctx.Err() terminateErr = ctx.Err()
return return
@@ -705,21 +706,30 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
return dialer return dialer
} }
parsedURL, errParse := url.Parse(proxyURL) setting, errParse := proxyutil.Parse(proxyURL)
if errParse != nil { if errParse != nil {
log.Errorf("codex websockets executor: parse proxy URL failed: %v", errParse) log.Errorf("codex websockets executor: %v", errParse)
return dialer return dialer
} }
switch parsedURL.Scheme { switch setting.Mode {
case proxyutil.ModeDirect:
dialer.Proxy = nil
return dialer
case proxyutil.ModeProxy:
default:
return dialer
}
switch setting.URL.Scheme {
case "socks5": case "socks5":
var proxyAuth *proxy.Auth var proxyAuth *proxy.Auth
if parsedURL.User != nil { if setting.URL.User != nil {
username := parsedURL.User.Username() username := setting.URL.User.Username()
password, _ := parsedURL.User.Password() password, _ := setting.URL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password} proxyAuth = &proxy.Auth{User: username, Password: password}
} }
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct) socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil { if errSOCKS5 != nil {
log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5) log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5)
return dialer return dialer
@@ -729,9 +739,9 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
return socksDialer.Dial(network, addr) return socksDialer.Dial(network, addr)
} }
case "http", "https": case "http", "https":
dialer.Proxy = http.ProxyURL(parsedURL) dialer.Proxy = http.ProxyURL(setting.URL)
default: default:
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", parsedURL.Scheme) log.Errorf("codex websockets executor: unsupported proxy scheme: %s", setting.URL.Scheme)
} }
return dialer return dialer
@@ -787,7 +797,7 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
return rawJSON, headers return rawJSON, headers
} }
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string) http.Header { func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header {
if headers == nil { if headers == nil {
headers = http.Header{} headers = http.Header{}
} }
@@ -800,12 +810,14 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
ginHeaders = ginCtx.Request.Header ginHeaders = ginCtx.Request.Header
} }
misc.EnsureHeader(headers, ginHeaders, "x-codex-beta-features", "") cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "") misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "") misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
misc.EnsureHeader(headers, ginHeaders, "x-client-request-id", "")
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "") misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
misc.EnsureHeader(headers, ginHeaders, "Version", "")
misc.EnsureHeader(headers, ginHeaders, "Version", codexClientVersion)
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta")) betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
if betaHeader == "" && ginHeaders != nil { if betaHeader == "" && ginHeaders != nil {
betaHeader = strings.TrimSpace(ginHeaders.Get("OpenAI-Beta")) betaHeader = strings.TrimSpace(ginHeaders.Get("OpenAI-Beta"))
@@ -815,7 +827,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
} }
headers.Set("OpenAI-Beta", betaHeader) headers.Set("OpenAI-Beta", betaHeader)
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString()) misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(headers, ginHeaders, "User-Agent", codexUserAgent) ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
isAPIKey := false isAPIKey := false
if auth != nil && auth.Attributes != nil { if auth != nil && auth.Attributes != nil {
@@ -823,8 +835,12 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
isAPIKey = true isAPIKey = true
} }
} }
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
headers.Set("Originator", originator)
} else if !isAPIKey {
headers.Set("Originator", codexOriginator)
}
if !isAPIKey { if !isAPIKey {
headers.Set("Originator", "codex_cli_rs")
if auth != nil && auth.Metadata != nil { if auth != nil && auth.Metadata != nil {
if accountID, ok := auth.Metadata["account_id"].(string); ok { if accountID, ok := auth.Metadata["account_id"].(string); ok {
if trimmed := strings.TrimSpace(accountID); trimmed != "" { if trimmed := strings.TrimSpace(accountID); trimmed != "" {
@@ -843,6 +859,62 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
return headers return headers
} }
func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) {
if cfg == nil || auth == nil {
return "", ""
}
if auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
return "", ""
}
}
return strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent), strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
}
func ensureHeaderWithPriority(target http.Header, source http.Header, key, configValue, fallbackValue string) {
if target == nil {
return
}
if strings.TrimSpace(target.Get(key)) != "" {
return
}
if source != nil {
if val := strings.TrimSpace(source.Get(key)); val != "" {
target.Set(key, val)
return
}
}
if val := strings.TrimSpace(configValue); val != "" {
target.Set(key, val)
return
}
if val := strings.TrimSpace(fallbackValue); val != "" {
target.Set(key, val)
}
}
func ensureHeaderWithConfigPrecedence(target http.Header, source http.Header, key, configValue, fallbackValue string) {
if target == nil {
return
}
if strings.TrimSpace(target.Get(key)) != "" {
return
}
if val := strings.TrimSpace(configValue); val != "" {
target.Set(key, val)
return
}
if source != nil {
if val := strings.TrimSpace(source.Get(key)); val != "" {
target.Set(key, val)
return
}
}
if val := strings.TrimSpace(fallbackValue); val != "" {
target.Set(key, val)
}
}
type statusErrWithHeaders struct { type statusErrWithHeaders struct {
statusErr statusErr
headers http.Header headers http.Header

View File

@@ -3,8 +3,13 @@ package executor
import ( import (
"context" "context"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -28,9 +33,259 @@ func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T)
} }
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) { func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "") headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue { if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue) t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
} }
if got := headers.Get("User-Agent"); got != codexUserAgent {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
}
if got := headers.Get("Version"); got != "" {
t.Fatalf("Version = %q, want empty", got)
}
if got := headers.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)
}
if got := headers.Get("X-Codex-Turn-Metadata"); got != "" {
t.Fatalf("X-Codex-Turn-Metadata = %q, want empty", got)
}
if got := headers.Get("X-Client-Request-Id"); got != "" {
t.Fatalf("X-Client-Request-Id = %q, want empty", got)
}
}
func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing.T) {
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
ctx := contextWithGinHeaders(map[string]string{
"Originator": "Codex Desktop",
"Version": "0.115.0-alpha.27",
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
})
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", nil)
if got := headers.Get("Originator"); got != "Codex Desktop" {
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
}
if got := headers.Get("Version"); got != "0.115.0-alpha.27" {
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
}
if got := headers.Get("X-Codex-Turn-Metadata"); got != `{"turn_id":"turn-1"}` {
t.Fatalf("X-Codex-Turn-Metadata = %s, want %s", got, `{"turn_id":"turn-1"}`)
}
if got := headers.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
}
}
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "my-codex-client/1.0",
BetaFeatures: "feature-a,feature-b",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
}
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
}
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
}
}
func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *testing.T) {
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "config-ua",
BetaFeatures: "config-beta",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
ctx := contextWithGinHeaders(map[string]string{
"User-Agent": "client-ua",
"X-Codex-Beta-Features": "client-beta",
})
headers := http.Header{}
headers.Set("User-Agent", "existing-ua")
headers.Set("X-Codex-Beta-Features", "existing-beta")
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
}
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
}
}
func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testing.T) {
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "config-ua",
BetaFeatures: "config-beta",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
ctx := contextWithGinHeaders(map[string]string{
"User-Agent": "client-ua",
"X-Codex-Beta-Features": "client-beta",
})
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "config-ua" {
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
}
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
}
}
func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "config-ua",
BetaFeatures: "config-beta",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Attributes: map[string]string{"api_key": "sk-test"},
}
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg)
if got := headers.Get("User-Agent"); got != codexUserAgent {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
}
if got := headers.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)
}
}
func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
if err != nil {
t.Fatalf("NewRequest() error = %v", err)
}
cfg := &config.Config{
CodexHeaderDefaults: config.CodexHeaderDefaults{
UserAgent: "config-ua",
BetaFeatures: "config-beta",
},
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
req = req.WithContext(contextWithGinHeaders(map[string]string{
"User-Agent": "client-ua",
}))
applyCodexHeaders(req, auth, "oauth-token", true, cfg)
if got := req.Header.Get("User-Agent"); got != "config-ua" {
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
}
if got := req.Header.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)
}
}
func TestApplyCodexHeadersPassesThroughClientIdentityHeaders(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
if err != nil {
t.Fatalf("NewRequest() error = %v", err)
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
req = req.WithContext(contextWithGinHeaders(map[string]string{
"Originator": "Codex Desktop",
"Version": "0.115.0-alpha.27",
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
}))
applyCodexHeaders(req, auth, "oauth-token", true, nil)
if got := req.Header.Get("Originator"); got != "Codex Desktop" {
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
}
if got := req.Header.Get("Version"); got != "0.115.0-alpha.27" {
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
}
if got := req.Header.Get("X-Codex-Turn-Metadata"); got != `{"turn_id":"turn-1"}` {
t.Fatalf("X-Codex-Turn-Metadata = %s, want %s", got, `{"turn_id":"turn-1"}`)
}
if got := req.Header.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
}
}
func TestApplyCodexHeadersDoesNotInjectClientOnlyHeadersByDefault(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
if err != nil {
t.Fatalf("NewRequest() error = %v", err)
}
applyCodexHeaders(req, nil, "oauth-token", true, nil)
if got := req.Header.Get("Version"); got != "" {
t.Fatalf("Version = %q, want empty", got)
}
if got := req.Header.Get("X-Codex-Turn-Metadata"); got != "" {
t.Fatalf("X-Codex-Turn-Metadata = %q, want empty", got)
}
if got := req.Header.Get("X-Client-Request-Id"); got != "" {
t.Fatalf("X-Client-Request-Id = %q, want empty", got)
}
}
func contextWithGinHeaders(headers map[string]string) context.Context {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)
ginCtx.Request.Header = make(http.Header, len(headers))
for key, value := range headers {
ginCtx.Request.Header.Set(key, value)
}
return context.WithValue(context.Background(), "gin", ginCtx)
}
func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) {
t.Parallel()
dialer := newProxyAwareWebsocketDialer(
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
&cliproxyauth.Auth{ProxyURL: "direct"},
)
if dialer.Proxy != nil {
t.Fatal("expected websocket proxy function to be nil for direct mode")
}
} }

View File

@@ -224,7 +224,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
reporter.publish(ctx, parseGeminiCLIUsage(data)) reporter.publish(ctx, parseGeminiCLIUsage(data))
var param any var param any
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, &param) out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil return resp, nil
} }
@@ -401,14 +401,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
if bytes.HasPrefix(line, dataTag) { if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), &param) segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), &param)
for i := range segments { for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
} }
} }
} }
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), &param) segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), &param)
for i := range segments { for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) recordAPIResponseError(ctx, e.cfg, errScan)
@@ -430,12 +430,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
var param any var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, &param) segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, &param)
for i := range segments { for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
} }
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), &param) segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), &param)
for i := range segments { for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
} }
}(httpResp, append([]byte(nil), payload...), attemptModel) }(httpResp, append([]byte(nil), payload...), attemptModel)
@@ -544,7 +544,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
if resp.StatusCode >= 200 && resp.StatusCode < 300 { if resp.StatusCode >= 200 && resp.StatusCode < 300 {
count := gjson.GetBytes(data, "totalTokens").Int() count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data) translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil
} }
lastStatus = resp.StatusCode lastStatus = resp.StatusCode
lastBody = append([]byte(nil), data...) lastBody = append([]byte(nil), data...)
@@ -811,18 +811,18 @@ func fixGeminiCLIImageAspectRatio(modelName string, rawJSON []byte) []byte {
if !hasInlineData { if !hasInlineData {
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String()) emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
emptyImagePart := `{"inlineData":{"mime_type":"image/png","data":""}}` emptyImagePart := []byte(`{"inlineData":{"mime_type":"image/png","data":""}}`)
emptyImagePart, _ = sjson.Set(emptyImagePart, "inlineData.data", emptyImageBase64ed) emptyImagePart, _ = sjson.SetBytes(emptyImagePart, "inlineData.data", emptyImageBase64ed)
newPartsJson := `[]` newPartsJson := []byte(`[]`)
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`) newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(`{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`))
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", emptyImagePart) newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", emptyImagePart)
parts := contentArray[0].Get("parts").Array() parts := contentArray[0].Get("parts").Array()
for j := 0; j < len(parts); j++ { for j := 0; j < len(parts); j++ {
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", parts[j].Raw) newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(parts[j].Raw))
} }
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", []byte(newPartsJson)) rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", newPartsJson)
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`)) rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
} }
} }

View File

@@ -205,7 +205,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
reporter.publish(ctx, parseGeminiUsage(data)) reporter.publish(ctx, parseGeminiUsage(data))
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil return resp, nil
} }
@@ -321,12 +321,12 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
} }
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), &param) lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), &param)
for i := range lines { for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
} }
} }
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param) lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines { for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) recordAPIResponseError(ctx, e.cfg, errScan)
@@ -415,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
count := gjson.GetBytes(data, "totalTokens").Int() count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data) translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil
} }
// Refresh refreshes the authentication credentials (no-op for Gemini API key). // Refresh refreshes the authentication credentials (no-op for Gemini API key).
@@ -527,18 +527,18 @@ func fixGeminiImageAspectRatio(modelName string, rawJSON []byte) []byte {
if !hasInlineData { if !hasInlineData {
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String()) emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
emptyImagePart := `{"inlineData":{"mime_type":"image/png","data":""}}` emptyImagePart := []byte(`{"inlineData":{"mime_type":"image/png","data":""}}`)
emptyImagePart, _ = sjson.Set(emptyImagePart, "inlineData.data", emptyImageBase64ed) emptyImagePart, _ = sjson.SetBytes(emptyImagePart, "inlineData.data", emptyImageBase64ed)
newPartsJson := `[]` newPartsJson := []byte(`[]`)
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`) newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(`{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`))
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", emptyImagePart) newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", emptyImagePart)
parts := contentArray[0].Get("parts").Array() parts := contentArray[0].Get("parts").Array()
for j := 0; j < len(parts); j++ { for j := 0; j < len(parts); j++ {
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", parts[j].Raw) newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(parts[j].Raw))
} }
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", []byte(newPartsJson)) rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", newPartsJson)
rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`)) rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
} }
} }

View File

@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil return resp, nil
} }
@@ -524,7 +524,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
reporter.publish(ctx, parseGeminiUsage(data)) reporter.publish(ctx, parseGeminiUsage(data))
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil return resp, nil
} }
@@ -636,12 +636,12 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
} }
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param) lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines { for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
} }
} }
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param) lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines { for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) recordAPIResponseError(ctx, e.cfg, errScan)
@@ -760,12 +760,12 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
} }
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param) lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines { for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
} }
} }
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param) lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines { for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) recordAPIResponseError(ctx, e.cfg, errScan)
@@ -857,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
appendAPIResponseChunk(ctx, e.cfg, data) appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int() count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
} }
// countTokensWithAPIKey handles token counting using API key credentials. // countTokensWithAPIKey handles token counting using API key credentials.
@@ -941,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
appendAPIResponseChunk(ctx, e.cfg, data) appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int() count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
} }
// vertexCreds extracts project, location and raw service account JSON from auth metadata. // vertexCreds extracts project, location and raw service account JSON from auth metadata.

View File

@@ -221,13 +221,13 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
} }
var param any var param any
converted := "" var converted []byte
if useResponses && from.String() == "claude" { if useResponses && from.String() == "claude" {
converted = translateGitHubCopilotResponsesNonStreamToClaude(data) converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
} else { } else {
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param) converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
} }
resp = cliproxyexecutor.Response{Payload: []byte(converted)} resp = cliproxyexecutor.Response{Payload: converted}
reporter.ensurePublished(ctx) reporter.ensurePublished(ctx)
return resp, nil return resp, nil
} }
@@ -374,14 +374,14 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
} }
} }
var chunks []string var chunks [][]byte
if useResponses && from.String() == "claude" { if useResponses && from.String() == "claude" {
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), &param) chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), &param)
} else { } else {
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param) chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
} }
for i := range chunks { for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} out <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(chunks[i])}
} }
} }
@@ -577,9 +577,33 @@ func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model
return true return true
} }
baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName) baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName)
if info := registry.GetGlobalRegistry().GetModelInfo(baseModel, githubCopilotAuthType); info != nil {
return len(info.SupportedEndpoints) > 0 && !containsEndpoint(info.SupportedEndpoints, githubCopilotChatPath) && containsEndpoint(info.SupportedEndpoints, githubCopilotResponsesPath)
}
if info := lookupGitHubCopilotStaticModelInfo(baseModel); info != nil {
return len(info.SupportedEndpoints) > 0 && !containsEndpoint(info.SupportedEndpoints, githubCopilotChatPath) && containsEndpoint(info.SupportedEndpoints, githubCopilotResponsesPath)
}
return strings.Contains(baseModel, "codex") return strings.Contains(baseModel, "codex")
} }
func lookupGitHubCopilotStaticModelInfo(model string) *registry.ModelInfo {
for _, info := range registry.GetStaticModelDefinitionsByChannel(githubCopilotAuthType) {
if info != nil && strings.EqualFold(info.ID, model) {
return info
}
}
return nil
}
func containsEndpoint(endpoints []string, endpoint string) bool {
for _, item := range endpoints {
if item == endpoint {
return true
}
}
return false
}
// flattenAssistantContent converts assistant message content from array format // flattenAssistantContent converts assistant message content from array format
// to a joined string. GitHub Copilot requires assistant content as a string; // to a joined string. GitHub Copilot requires assistant content as a string;
// sending it as an array causes Claude models to re-answer all previous prompts. // sending it as an array causes Claude models to re-answer all previous prompts.
@@ -653,6 +677,7 @@ func normalizeGitHubCopilotChatTools(body []byte) []byte {
} }
func normalizeGitHubCopilotResponsesInput(body []byte) []byte { func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
body = stripGitHubCopilotResponsesUnsupportedFields(body)
input := gjson.GetBytes(body, "input") input := gjson.GetBytes(body, "input")
if input.Exists() { if input.Exists() {
// If input is already a string or array, keep it as-is. // If input is already a string or array, keep it as-is.
@@ -825,6 +850,12 @@ func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
return body return body
} }
func stripGitHubCopilotResponsesUnsupportedFields(body []byte) []byte {
// GitHub Copilot /responses rejects service_tier, so always remove it.
body, _ = sjson.DeleteBytes(body, "service_tier")
return body
}
func normalizeGitHubCopilotResponsesTools(body []byte) []byte { func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools") tools := gjson.GetBytes(body, "tools")
if tools.Exists() { if tools.Exists() {
@@ -970,7 +1001,7 @@ type githubCopilotResponsesStreamState struct {
ItemIDToTool map[string]*githubCopilotResponsesStreamToolState ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
} }
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string { func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) []byte {
root := gjson.ParseBytes(data) root := gjson.ParseBytes(data)
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("id").String()) out, _ = sjson.Set(out, "id", root.Get("id").String())
@@ -1060,10 +1091,10 @@ func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
} else { } else {
out, _ = sjson.Set(out, "stop_reason", "end_turn") out, _ = sjson.Set(out, "stop_reason", "end_turn")
} }
return out return []byte(out)
} }
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string { func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) [][]byte {
if *param == nil { if *param == nil {
*param = &githubCopilotResponsesStreamState{ *param = &githubCopilotResponsesStreamState{
TextBlockIndex: -1, TextBlockIndex: -1,
@@ -1085,7 +1116,10 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
} }
event := gjson.GetBytes(payload, "type").String() event := gjson.GetBytes(payload, "type").String()
results := make([]string, 0, 4) results := make([][]byte, 0, 4)
appendResult := func(chunk string) {
results = append(results, []byte(chunk))
}
ensureMessageStart := func() { ensureMessageStart := func() {
if state.MessageStarted { if state.MessageStarted {
return return
@@ -1093,7 +1127,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}` messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String()) messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String())
messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String()) messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String())
results = append(results, "event: message_start\ndata: "+messageStart+"\n\n") appendResult("event: message_start\ndata: " + messageStart + "\n\n")
state.MessageStarted = true state.MessageStarted = true
} }
startTextBlockIfNeeded := func() { startTextBlockIfNeeded := func() {
@@ -1106,7 +1140,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
} }
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex) contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex)
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n") appendResult("event: content_block_start\ndata: " + contentBlockStart + "\n\n")
state.TextBlockStarted = true state.TextBlockStarted = true
} }
stopTextBlockIfNeeded := func() { stopTextBlockIfNeeded := func() {
@@ -1115,7 +1149,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
} }
contentBlockStop := `{"type":"content_block_stop","index":0}` contentBlockStop := `{"type":"content_block_stop","index":0}`
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex) contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n") appendResult("event: content_block_stop\ndata: " + contentBlockStop + "\n\n")
state.TextBlockStarted = false state.TextBlockStarted = false
state.TextBlockIndex = -1 state.TextBlockIndex = -1
} }
@@ -1145,7 +1179,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}` contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex) contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex)
contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta) contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n") appendResult("event: content_block_delta\ndata: " + contentDelta + "\n\n")
} }
case "response.reasoning_summary_part.added": case "response.reasoning_summary_part.added":
ensureMessageStart() ensureMessageStart()
@@ -1154,7 +1188,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
state.NextContentIndex++ state.NextContentIndex++
thinkingStart := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}` thinkingStart := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
thinkingStart, _ = sjson.Set(thinkingStart, "index", state.ReasoningIndex) thinkingStart, _ = sjson.Set(thinkingStart, "index", state.ReasoningIndex)
results = append(results, "event: content_block_start\ndata: "+thinkingStart+"\n\n") appendResult("event: content_block_start\ndata: " + thinkingStart + "\n\n")
case "response.reasoning_summary_text.delta": case "response.reasoning_summary_text.delta":
if state.ReasoningActive { if state.ReasoningActive {
delta := gjson.GetBytes(payload, "delta").String() delta := gjson.GetBytes(payload, "delta").String()
@@ -1162,14 +1196,14 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
thinkingDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}` thinkingDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
thinkingDelta, _ = sjson.Set(thinkingDelta, "index", state.ReasoningIndex) thinkingDelta, _ = sjson.Set(thinkingDelta, "index", state.ReasoningIndex)
thinkingDelta, _ = sjson.Set(thinkingDelta, "delta.thinking", delta) thinkingDelta, _ = sjson.Set(thinkingDelta, "delta.thinking", delta)
results = append(results, "event: content_block_delta\ndata: "+thinkingDelta+"\n\n") appendResult("event: content_block_delta\ndata: " + thinkingDelta + "\n\n")
} }
} }
case "response.reasoning_summary_part.done": case "response.reasoning_summary_part.done":
if state.ReasoningActive { if state.ReasoningActive {
thinkingStop := `{"type":"content_block_stop","index":0}` thinkingStop := `{"type":"content_block_stop","index":0}`
thinkingStop, _ = sjson.Set(thinkingStop, "index", state.ReasoningIndex) thinkingStop, _ = sjson.Set(thinkingStop, "index", state.ReasoningIndex)
results = append(results, "event: content_block_stop\ndata: "+thinkingStop+"\n\n") appendResult("event: content_block_stop\ndata: " + thinkingStop + "\n\n")
state.ReasoningActive = false state.ReasoningActive = false
} }
case "response.output_item.added": case "response.output_item.added":
@@ -1197,7 +1231,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index) contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index)
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID) contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID)
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name) contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name)
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n") appendResult("event: content_block_start\ndata: " + contentBlockStart + "\n\n")
case "response.output_item.delta": case "response.output_item.delta":
item := gjson.GetBytes(payload, "item") item := gjson.GetBytes(payload, "item")
if item.Get("type").String() != "function_call" { if item.Get("type").String() != "function_call" {
@@ -1217,7 +1251,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}` inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index) inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial) inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n") appendResult("event: content_block_delta\ndata: " + inputDelta + "\n\n")
case "response.function_call_arguments.delta": case "response.function_call_arguments.delta":
// Copilot sends tool call arguments via this event type (not response.output_item.delta). // Copilot sends tool call arguments via this event type (not response.output_item.delta).
// Data format: {"delta":"...", "item_id":"...", "output_index":N, ...} // Data format: {"delta":"...", "item_id":"...", "output_index":N, ...}
@@ -1234,7 +1268,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}` inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index) inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial) inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n") appendResult("event: content_block_delta\ndata: " + inputDelta + "\n\n")
case "response.output_item.done": case "response.output_item.done":
if gjson.GetBytes(payload, "item.type").String() != "function_call" { if gjson.GetBytes(payload, "item.type").String() != "function_call" {
break break
@@ -1245,7 +1279,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
} }
contentBlockStop := `{"type":"content_block_stop","index":0}` contentBlockStop := `{"type":"content_block_stop","index":0}`
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index) contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n") appendResult("event: content_block_stop\ndata: " + contentBlockStop + "\n\n")
case "response.completed": case "response.completed":
ensureMessageStart() ensureMessageStart()
stopTextBlockIfNeeded() stopTextBlockIfNeeded()
@@ -1269,8 +1303,8 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
if cachedTokens > 0 { if cachedTokens > 0 {
messageDelta, _ = sjson.Set(messageDelta, "usage.cache_read_input_tokens", cachedTokens) messageDelta, _ = sjson.Set(messageDelta, "usage.cache_read_input_tokens", cachedTokens)
} }
results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n") appendResult("event: message_delta\ndata: " + messageDelta + "\n\n")
results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n") appendResult("event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
state.MessageStopSent = true state.MessageStopSent = true
} }
} }

View File

@@ -5,6 +5,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -70,6 +71,29 @@ func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
} }
} }
func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing.T) {
t.Parallel()
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
t.Fatal("expected responses-only registry model to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
t.Parallel()
reg := registry.GetGlobalRegistry()
clientID := "github-copilot-test-client"
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{{
ID: "gpt-5.4",
SupportedEndpoints: []string{"/chat/completions", "/responses"},
}})
defer reg.UnregisterClient(clientID)
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) { func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
t.Parallel() t.Parallel()
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") { if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
@@ -132,6 +156,19 @@ func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testi
} }
} }
func TestNormalizeGitHubCopilotResponsesInput_StripsServiceTier(t *testing.T) {
t.Parallel()
body := []byte(`{"input":"user text","service_tier":"default"}`)
got := normalizeGitHubCopilotResponsesInput(body)
if gjson.GetBytes(got, "service_tier").Exists() {
t.Fatalf("service_tier should be removed, got %s", gjson.GetBytes(got, "service_tier").Raw)
}
if gjson.GetBytes(got, "input").String() != "user text" {
t.Fatalf("input = %q, want %q", gjson.GetBytes(got, "input").String(), "user text")
}
}
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) { func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
t.Parallel() t.Parallel()
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`) body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)

View File

@@ -169,7 +169,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve // Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility. // the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil return resp, nil
} }
@@ -281,7 +281,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
} }
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param) chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks { for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
} }
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
@@ -315,7 +315,7 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
usageJSON := buildOpenAIUsageJSON(count) usageJSON := buildOpenAIUsageJSON(count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON) translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil return cliproxyexecutor.Response{Payload: translated}, nil
} }
// Refresh refreshes OAuth tokens or cookie-based API keys and updates the stored API key. // Refresh refreshes OAuth tokens or cookie-based API keys and updates the stored API key.

View File

@@ -161,7 +161,7 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve // Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility. // the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil return resp, nil
} }
@@ -271,12 +271,12 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
} }
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param) chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks { for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
} }
} }
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param) doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range doneChunks { for i := range doneChunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])} out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) recordAPIResponseError(ctx, e.cfg, errScan)

View File

@@ -89,6 +89,13 @@ var endpointAliases = map[string]string{
"cli": "amazonq", "cli": "amazonq",
} }
func enqueueTranslatedSSE(out chan<- cliproxyexecutor.StreamChunk, chunk []byte) {
if len(chunk) == 0 {
return
}
out <- cliproxyexecutor.StreamChunk{Payload: append(bytes.Clone(chunk), '\n', '\n')}
}
// retryConfig holds configuration for socket retry logic. // retryConfig holds configuration for socket retry logic.
// Based on kiro2Api Python implementation patterns. // Based on kiro2Api Python implementation patterns.
type retryConfig struct { type retryConfig struct {
@@ -2573,9 +2580,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.ToolUseID, currentToolUse.Name) blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.ToolUseID, currentToolUse.Name)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
// Send tool input as delta // Send tool input as delta
@@ -2583,18 +2588,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex) inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
// Close block // Close block
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
hasToolUses = true hasToolUses = true
@@ -2664,9 +2665,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
msgStart := kiroclaude.BuildClaudeMessageStartEvent(model, totalUsage.InputTokens) msgStart := kiroclaude.BuildClaudeMessageStartEvent(model, totalUsage.InputTokens)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
messageStartSent = true messageStartSent = true
} }
@@ -2916,9 +2915,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
pingEvent := kiroclaude.BuildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens) pingEvent := kiroclaude.BuildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
lastReportedOutputTokens = currentOutputTokens lastReportedOutputTokens = currentOutputTokens
@@ -2939,17 +2936,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
} }
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processText, contentBlockIndex) claudeEvent := kiroclaude.BuildClaudeStreamEvent(processText, contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
} }
continue continue
@@ -2978,18 +2971,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
} }
// Send thinking delta // Send thinking delta
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex) thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
accumulatedThinkingContent.WriteString(thinkingText) accumulatedThinkingContent.WriteString(thinkingText)
} }
@@ -2998,9 +2987,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex) blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
isThinkingBlockOpen = false isThinkingBlockOpen = false
} }
@@ -3029,17 +3016,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
} }
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(processContent, thinkingBlockIndex) thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(processContent, thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
accumulatedThinkingContent.WriteString(processContent) accumulatedThinkingContent.WriteString(processContent)
} }
@@ -3058,9 +3041,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex) blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
isThinkingBlockOpen = false isThinkingBlockOpen = false
} }
@@ -3071,18 +3052,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
} }
// Send text delta // Send text delta
claudeEvent := kiroclaude.BuildClaudeStreamEvent(textBefore, contentBlockIndex) claudeEvent := kiroclaude.BuildClaudeStreamEvent(textBefore, contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
} }
// Close text block before entering thinking // Close text block before entering thinking
@@ -3090,9 +3067,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
isTextBlockOpen = false isTextBlockOpen = false
} }
@@ -3120,17 +3095,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
} }
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processContent, contentBlockIndex) claudeEvent := kiroclaude.BuildClaudeStreamEvent(processContent, contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
} }
} }
@@ -3158,9 +3129,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
isTextBlockOpen = false isTextBlockOpen = false
} }
@@ -3171,9 +3140,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName) blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
// Send input_json_delta with the tool input // Send input_json_delta with the tool input
@@ -3186,9 +3153,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
} }
} }
@@ -3197,9 +3162,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
} }
@@ -3239,9 +3202,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
isTextBlockOpen = false isTextBlockOpen = false
} }
@@ -3254,9 +3215,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
} }
@@ -3264,9 +3223,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex) thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
// Accumulate for token counting // Accumulate for token counting
@@ -3298,9 +3255,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
isTextBlockOpen = false isTextBlockOpen = false
} }
@@ -3310,9 +3265,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name) blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
if tu.Input != nil { if tu.Input != nil {
@@ -3323,9 +3276,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
} }
} }
@@ -3333,9 +3284,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
} }
@@ -3522,9 +3471,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
} }
@@ -3609,18 +3556,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
msgDelta := kiroclaude.BuildClaudeMessageDeltaEvent(stopReason, totalUsage) msgDelta := kiroclaude.BuildClaudeMessageDeltaEvent(stopReason, totalUsage)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam) sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
// Send message_stop event separately // Send message_stop event separately
msgStop := kiroclaude.BuildClaudeMessageStopOnlyEvent() msgStop := kiroclaude.BuildClaudeMessageStopOnlyEvent()
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam)
for _, chunk := range sseData { for _, chunk := range sseData {
if chunk != "" { enqueueTranslatedSSE(out, chunk)
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
} }
// reporter.publish is called via defer // reporter.publish is called via defer
} }

View File

@@ -172,7 +172,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
// Translate response back to source format when needed // Translate response back to source format when needed
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil return resp, nil
} }
@@ -205,6 +205,10 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
return nil, err return nil, err
} }
// Request usage data in the final streaming chunk so that token statistics
// are captured even when the upstream is an OpenAI-compatible provider.
translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated)) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil { if err != nil {
@@ -286,7 +290,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
// Pass through translator; it yields one or more chunks for the target schema. // Pass through translator; it yields one or more chunks for the target schema.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param) chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param)
for i := range chunks { for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
} }
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
@@ -326,7 +330,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
usageJSON := buildOpenAIUsageJSON(count) usageJSON := buildOpenAIUsageJSON(count)
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON) translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: []byte(translatedUsage)}, nil return cliproxyexecutor.Response{Payload: translatedUsage}, nil
} }
// Refresh is a no-op for API-key based compatibility providers. // Refresh is a no-op for API-key based compatibility providers.

View File

@@ -2,17 +2,15 @@ package executor
import ( import (
"context" "context"
"net"
"net/http" "net/http"
"net/url"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
) )
// httpClientCache caches HTTP clients by proxy URL to enable connection reuse // httpClientCache caches HTTP clients by proxy URL to enable connection reuse
@@ -111,45 +109,10 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
// Returns: // Returns:
// - *http.Transport: A configured transport, or nil if the proxy URL is invalid // - *http.Transport: A configured transport, or nil if the proxy URL is invalid
func buildProxyTransport(proxyURL string) *http.Transport { func buildProxyTransport(proxyURL string) *http.Transport {
if proxyURL == "" { transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyURL)
if errBuild != nil {
log.Errorf("%v", errBuild)
return nil return nil
} }
parsedURL, errParse := url.Parse(proxyURL)
if errParse != nil {
log.Errorf("parse proxy URL failed: %v", errParse)
return nil
}
var transport *http.Transport
// Handle different proxy schemes
if parsedURL.Scheme == "socks5" {
// Configure SOCKS5 proxy with optional authentication
var proxyAuth *proxy.Auth
if parsedURL.User != nil {
username := parsedURL.User.Username()
password, _ := parsedURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil
}
// Set up a custom transport using the SOCKS5 dialer
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" {
// Configure HTTP or HTTPS proxy
transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)}
} else {
log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
return nil
}
return transport return transport
} }

View File

@@ -0,0 +1,30 @@
package executor
import (
"context"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
t.Parallel()
client := newProxyAwareHTTPClient(
context.Background(),
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
&cliproxyauth.Auth{ProxyURL: "direct"},
0,
)
transport, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", client.Transport)
}
if transport.Proxy != nil {
t.Fatal("expected direct transport to disable proxy function")
}
}

View File

@@ -305,7 +305,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve // Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility. // the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param) out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()} resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil return resp, nil
} }
@@ -421,12 +421,12 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
} }
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param) chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks { for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
} }
} }
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param) doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range doneChunks { for i := range doneChunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])} out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
} }
if errScan := scanner.Err(); errScan != nil { if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan) recordAPIResponseError(ctx, e.cfg, errScan)
@@ -461,7 +461,7 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
usageJSON := buildOpenAIUsageJSON(count) usageJSON := buildOpenAIUsageJSON(count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON) translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil return cliproxyexecutor.Response{Payload: translated}, nil
} }
func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {

View File

@@ -73,17 +73,7 @@ func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
return return
} }
r.once.Do(func() { r.once.Do(func() {
usage.PublishRecord(ctx, usage.Record{ usage.PublishRecord(ctx, r.buildRecord(detail, failed))
Provider: r.provider,
Model: r.model,
Source: r.source,
APIKey: r.apiKey,
AuthID: r.authID,
AuthIndex: r.authIndex,
RequestedAt: r.requestedAt,
Failed: failed,
Detail: detail,
})
}) })
} }
@@ -96,20 +86,39 @@ func (r *usageReporter) ensurePublished(ctx context.Context) {
return return
} }
r.once.Do(func() { r.once.Do(func() {
usage.PublishRecord(ctx, usage.Record{ usage.PublishRecord(ctx, r.buildRecord(usage.Detail{}, false))
Provider: r.provider,
Model: r.model,
Source: r.source,
APIKey: r.apiKey,
AuthID: r.authID,
AuthIndex: r.authIndex,
RequestedAt: r.requestedAt,
Failed: false,
Detail: usage.Detail{},
})
}) })
} }
func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
if r == nil {
return usage.Record{Detail: detail, Failed: failed}
}
return usage.Record{
Provider: r.provider,
Model: r.model,
Source: r.source,
APIKey: r.apiKey,
AuthID: r.authID,
AuthIndex: r.authIndex,
RequestedAt: r.requestedAt,
Latency: r.latency(),
Failed: failed,
Detail: detail,
}
}
func (r *usageReporter) latency() time.Duration {
if r == nil || r.requestedAt.IsZero() {
return 0
}
latency := time.Since(r.requestedAt)
if latency < 0 {
return 0
}
return latency
}
func apiKeyFromContext(ctx context.Context) string { func apiKeyFromContext(ctx context.Context) string {
if ctx == nil { if ctx == nil {
return "" return ""

View File

@@ -1,6 +1,11 @@
package executor package executor
import "testing" import (
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
func TestParseOpenAIUsageChatCompletions(t *testing.T) { func TestParseOpenAIUsageChatCompletions(t *testing.T) {
data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`) data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
@@ -41,3 +46,19 @@ func TestParseOpenAIUsageResponses(t *testing.T) {
t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 9) t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 9)
} }
} }
func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
reporter := &usageReporter{
provider: "openai",
model: "gpt-5.4",
requestedAt: time.Now().Add(-1500 * time.Millisecond),
}
record := reporter.buildRecord(usage.Detail{TotalTokens: 3}, false)
if record.Latency < time.Second {
t.Fatalf("latency = %v, want >= 1s", record.Latency)
}
if record.Latency > 3*time.Second {
t.Fatalf("latency = %v, want <= 3s", record.Latency)
}
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -39,35 +40,39 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
rawJSON := inputRawJSON rawJSON := inputRawJSON
// system instruction // system instruction
systemInstructionJSON := "" var systemInstructionJSON []byte
hasSystemInstruction := false hasSystemInstruction := false
systemResult := gjson.GetBytes(rawJSON, "system") systemResult := gjson.GetBytes(rawJSON, "system")
if systemResult.IsArray() { if systemResult.IsArray() {
systemResults := systemResult.Array() systemResults := systemResult.Array()
systemInstructionJSON = `{"role":"user","parts":[]}` systemInstructionJSON = []byte(`{"role":"user","parts":[]}`)
for i := 0; i < len(systemResults); i++ { for i := 0; i < len(systemResults); i++ {
systemPromptResult := systemResults[i] systemPromptResult := systemResults[i]
systemTypePromptResult := systemPromptResult.Get("type") systemTypePromptResult := systemPromptResult.Get("type")
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" { if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
systemPrompt := systemPromptResult.Get("text").String() systemPrompt := systemPromptResult.Get("text").String()
partJSON := `{}` partJSON := []byte(`{}`)
if systemPrompt != "" { if systemPrompt != "" {
partJSON, _ = sjson.Set(partJSON, "text", systemPrompt) partJSON, _ = sjson.SetBytes(partJSON, "text", systemPrompt)
} }
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", partJSON) systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", partJSON)
hasSystemInstruction = true hasSystemInstruction = true
} }
} }
} else if systemResult.Type == gjson.String { } else if systemResult.Type == gjson.String {
systemInstructionJSON = `{"role":"user","parts":[{"text":""}]}` systemInstructionJSON = []byte(`{"role":"user","parts":[{"text":""}]}`)
systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.0.text", systemResult.String()) systemInstructionJSON, _ = sjson.SetBytes(systemInstructionJSON, "parts.0.text", systemResult.String())
hasSystemInstruction = true hasSystemInstruction = true
} }
// contents // contents
contentsJSON := "[]" contentsJSON := []byte(`[]`)
hasContents := false hasContents := false
// tool_use_id → tool_name lookup, populated incrementally during the main loop.
// Claude's tool_result references tool_use by ID; Gemini requires functionResponse.name.
toolNameByID := make(map[string]string)
messagesResult := gjson.GetBytes(rawJSON, "messages") messagesResult := gjson.GetBytes(rawJSON, "messages")
if messagesResult.IsArray() { if messagesResult.IsArray() {
messageResults := messagesResult.Array() messageResults := messagesResult.Array()
@@ -83,8 +88,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if role == "assistant" { if role == "assistant" {
role = "model" role = "model"
} }
clientContentJSON := `{"role":"","parts":[]}` clientContentJSON := []byte(`{"role":"","parts":[]}`)
clientContentJSON, _ = sjson.Set(clientContentJSON, "role", role) clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "role", role)
contentsResult := messageResult.Get("content") contentsResult := messageResult.Get("content")
if contentsResult.IsArray() { if contentsResult.IsArray() {
contentResults := contentsResult.Array() contentResults := contentsResult.Array()
@@ -143,15 +148,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
} }
// Valid signature, send as thought block // Valid signature, send as thought block
partJSON := `{}` // Always include "text" field — Google Antigravity API requires it
partJSON, _ = sjson.Set(partJSON, "thought", true) // even for redacted thinking where the text is empty.
if thinkingText != "" { partJSON := []byte(`{}`)
partJSON, _ = sjson.Set(partJSON, "text", thinkingText) partJSON, _ = sjson.SetBytes(partJSON, "thought", true)
} partJSON, _ = sjson.SetBytes(partJSON, "text", thinkingText)
if signature != "" { if signature != "" {
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature) partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", signature)
} }
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON) clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" { } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String() prompt := contentResult.Get("text").String()
// Skip empty text parts to avoid Gemini API error: // Skip empty text parts to avoid Gemini API error:
@@ -159,17 +164,21 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if prompt == "" { if prompt == "" {
continue continue
} }
partJSON := `{}` partJSON := []byte(`{}`)
partJSON, _ = sjson.Set(partJSON, "text", prompt) partJSON, _ = sjson.SetBytes(partJSON, "text", prompt)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON) clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" { } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
// NOTE: Do NOT inject dummy thinking blocks here. // NOTE: Do NOT inject dummy thinking blocks here.
// Antigravity API validates signatures, so dummy values are rejected. // Antigravity API validates signatures, so dummy values are rejected.
functionName := contentResult.Get("name").String() functionName := util.SanitizeFunctionName(contentResult.Get("name").String())
argsResult := contentResult.Get("input") argsResult := contentResult.Get("input")
functionID := contentResult.Get("id").String() functionID := contentResult.Get("id").String()
if functionID != "" && functionName != "" {
toolNameByID[functionID] = functionName
}
// Handle both object and string input formats // Handle both object and string input formats
var argsRaw string var argsRaw string
if argsResult.IsObject() { if argsResult.IsObject() {
@@ -183,138 +192,147 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
} }
if argsRaw != "" { if argsRaw != "" {
partJSON := `{}` partJSON := []byte(`{}`)
// Use skip_thought_signature_validator for tool calls without valid thinking signature // Use skip_thought_signature_validator for tool calls without valid thinking signature
// This is the approach used in opencode-google-antigravity-auth for Gemini // This is the approach used in opencode-google-antigravity-auth for Gemini
// and also works for Claude through Antigravity API // and also works for Claude through Antigravity API
const skipSentinel = "skip_thought_signature_validator" const skipSentinel = "skip_thought_signature_validator"
if cache.HasValidSignature(modelName, currentMessageThinkingSignature) { if cache.HasValidSignature(modelName, currentMessageThinkingSignature) {
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature) partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", currentMessageThinkingSignature)
} else { } else {
// No valid signature - use skip sentinel to bypass validation // No valid signature - use skip sentinel to bypass validation
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", skipSentinel) partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", skipSentinel)
} }
if functionID != "" { if functionID != "" {
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID) partJSON, _ = sjson.SetBytes(partJSON, "functionCall.id", functionID)
} }
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName) partJSON, _ = sjson.SetBytes(partJSON, "functionCall.name", functionName)
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsRaw) partJSON, _ = sjson.SetRawBytes(partJSON, "functionCall.args", []byte(argsRaw))
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON) clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
} }
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" { } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
toolCallID := contentResult.Get("tool_use_id").String() toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" { if toolCallID != "" {
funcName := toolCallID funcName, ok := toolNameByID[toolCallID]
toolCallIDs := strings.Split(toolCallID, "-") if !ok {
if len(toolCallIDs) > 1 { // Fallback: derive a semantic name from the ID by stripping
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-2], "-") // the last two dash-separated segments (e.g. "get_weather-call-123" → "get_weather").
// Only use the raw ID as a last resort when the heuristic produces an empty string.
parts := strings.Split(toolCallID, "-")
if len(parts) > 2 {
funcName = strings.Join(parts[:len(parts)-2], "-")
}
if funcName == "" {
funcName = toolCallID
}
log.Warnf("antigravity claude request: tool_result references unknown tool_use_id=%s, derived function name=%s", toolCallID, funcName)
} }
functionResponseResult := contentResult.Get("content") functionResponseResult := contentResult.Get("content")
functionResponseJSON := `{}` functionResponseJSON := []byte(`{}`)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "id", toolCallID) functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "id", toolCallID)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "name", funcName) functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "name", util.SanitizeFunctionName(funcName))
responseData := "" responseData := ""
if functionResponseResult.Type == gjson.String { if functionResponseResult.Type == gjson.String {
responseData = functionResponseResult.String() responseData = functionResponseResult.String()
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData) functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", responseData)
} else if functionResponseResult.IsArray() { } else if functionResponseResult.IsArray() {
frResults := functionResponseResult.Array() frResults := functionResponseResult.Array()
nonImageCount := 0 nonImageCount := 0
lastNonImageRaw := "" lastNonImageRaw := ""
filteredJSON := "[]" filteredJSON := []byte(`[]`)
imagePartsJSON := "[]" imagePartsJSON := []byte(`[]`)
for _, fr := range frResults { for _, fr := range frResults {
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" { if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
inlineDataJSON := `{}` inlineDataJSON := []byte(`{}`)
if mimeType := fr.Get("source.media_type").String(); mimeType != "" { if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType) inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
} }
if data := fr.Get("source.data").String(); data != "" { if data := fr.Get("source.data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data) inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
} }
imagePartJSON := `{}` imagePartJSON := []byte(`{}`)
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON) imagePartJSON, _ = sjson.SetRawBytes(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON) imagePartsJSON, _ = sjson.SetRawBytes(imagePartsJSON, "-1", imagePartJSON)
continue continue
} }
nonImageCount++ nonImageCount++
lastNonImageRaw = fr.Raw lastNonImageRaw = fr.Raw
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw) filteredJSON, _ = sjson.SetRawBytes(filteredJSON, "-1", []byte(fr.Raw))
} }
if nonImageCount == 1 { if nonImageCount == 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw) functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(lastNonImageRaw))
} else if nonImageCount > 1 { } else if nonImageCount > 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON) functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", filteredJSON)
} else { } else {
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "") functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
} }
// Place image data inside functionResponse.parts as inlineData // Place image data inside functionResponse.parts as inlineData
// instead of as sibling parts in the outer content, to avoid // instead of as sibling parts in the outer content, to avoid
// base64 data bloating the text context. // base64 data bloating the text context.
if gjson.Get(imagePartsJSON, "#").Int() > 0 { if gjson.GetBytes(imagePartsJSON, "#").Int() > 0 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON) functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "parts", imagePartsJSON)
} }
} else if functionResponseResult.IsObject() { } else if functionResponseResult.IsObject() {
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" { if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
inlineDataJSON := `{}` inlineDataJSON := []byte(`{}`)
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" { if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType) inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
} }
if data := functionResponseResult.Get("source.data").String(); data != "" { if data := functionResponseResult.Get("source.data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data) inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
} }
imagePartJSON := `{}` imagePartJSON := []byte(`{}`)
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON) imagePartJSON, _ = sjson.SetRawBytes(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON := "[]" imagePartsJSON := []byte(`[]`)
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON) imagePartsJSON, _ = sjson.SetRawBytes(imagePartsJSON, "-1", imagePartJSON)
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON) functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "parts", imagePartsJSON)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "") functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
} else { } else {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw) functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(functionResponseResult.Raw))
} }
} else if functionResponseResult.Raw != "" { } else if functionResponseResult.Raw != "" {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw) functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(functionResponseResult.Raw))
} else { } else {
// Content field is missing entirely — .Raw is empty which // Content field is missing entirely — .Raw is empty which
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}). // causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "") functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
} }
partJSON := `{}` partJSON := []byte(`{}`)
partJSON, _ = sjson.SetRaw(partJSON, "functionResponse", functionResponseJSON) partJSON, _ = sjson.SetRawBytes(partJSON, "functionResponse", functionResponseJSON)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON) clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
} }
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" { } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" {
sourceResult := contentResult.Get("source") sourceResult := contentResult.Get("source")
if sourceResult.Get("type").String() == "base64" { if sourceResult.Get("type").String() == "base64" {
inlineDataJSON := `{}` inlineDataJSON := []byte(`{}`)
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" { if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType) inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
} }
if data := sourceResult.Get("data").String(); data != "" { if data := sourceResult.Get("data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data) inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
} }
partJSON := `{}` partJSON := []byte(`{}`)
partJSON, _ = sjson.SetRaw(partJSON, "inlineData", inlineDataJSON) partJSON, _ = sjson.SetRawBytes(partJSON, "inlineData", inlineDataJSON)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON) clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
} }
} }
} }
// Reorder parts for 'model' role to ensure thinking block is first // Reorder parts for 'model' role to ensure thinking block is first
if role == "model" { if role == "model" {
partsResult := gjson.Get(clientContentJSON, "parts") partsResult := gjson.GetBytes(clientContentJSON, "parts")
if partsResult.IsArray() { if partsResult.IsArray() {
parts := partsResult.Array() parts := partsResult.Array()
var thinkingParts []gjson.Result var thinkingParts []gjson.Result
@@ -336,7 +354,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
for _, p := range otherParts { for _, p := range otherParts {
newParts = append(newParts, p.Value()) newParts = append(newParts, p.Value())
} }
clientContentJSON, _ = sjson.Set(clientContentJSON, "parts", newParts) clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
} }
} }
} }
@@ -344,33 +362,33 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Skip messages with empty parts array to avoid Gemini API error: // Skip messages with empty parts array to avoid Gemini API error:
// "required oneof field 'data' must have one initialized field" // "required oneof field 'data' must have one initialized field"
partsCheck := gjson.Get(clientContentJSON, "parts") partsCheck := gjson.GetBytes(clientContentJSON, "parts")
if !partsCheck.IsArray() || len(partsCheck.Array()) == 0 { if !partsCheck.IsArray() || len(partsCheck.Array()) == 0 {
continue continue
} }
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON) contentsJSON, _ = sjson.SetRawBytes(contentsJSON, "-1", clientContentJSON)
hasContents = true hasContents = true
} else if contentsResult.Type == gjson.String { } else if contentsResult.Type == gjson.String {
prompt := contentsResult.String() prompt := contentsResult.String()
partJSON := `{}` partJSON := []byte(`{}`)
if prompt != "" { if prompt != "" {
partJSON, _ = sjson.Set(partJSON, "text", prompt) partJSON, _ = sjson.SetBytes(partJSON, "text", prompt)
} }
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON) clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON) contentsJSON, _ = sjson.SetRawBytes(contentsJSON, "-1", clientContentJSON)
hasContents = true hasContents = true
} }
} }
} }
// tools // tools
toolsJSON := "" var toolsJSON []byte
toolDeclCount := 0 toolDeclCount := 0
allowedToolKeys := []string{"name", "description", "behavior", "parameters", "parametersJsonSchema", "response", "responseJsonSchema"} allowedToolKeys := []string{"name", "description", "behavior", "parameters", "parametersJsonSchema", "response", "responseJsonSchema"}
toolsResult := gjson.GetBytes(rawJSON, "tools") toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() { if toolsResult.IsArray() {
toolsJSON = `[{"functionDeclarations":[]}]` toolsJSON = []byte(`[{"functionDeclarations":[]}]`)
toolsResults := toolsResult.Array() toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ { for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i] toolResult := toolsResults[i]
@@ -378,23 +396,24 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
// Sanitize the input schema for Antigravity API compatibility // Sanitize the input schema for Antigravity API compatibility
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw) inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
tool, _ := sjson.Delete(toolResult.Raw, "input_schema") tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema) tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
for toolKey := range gjson.Parse(tool).Map() { tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
for toolKey := range gjson.ParseBytes(tool).Map() {
if util.InArray(allowedToolKeys, toolKey) { if util.InArray(allowedToolKeys, toolKey) {
continue continue
} }
tool, _ = sjson.Delete(tool, toolKey) tool, _ = sjson.DeleteBytes(tool, toolKey)
} }
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool) toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "0.functionDeclarations.-1", tool)
toolDeclCount++ toolDeclCount++
} }
} }
} }
// Build output Gemini CLI request JSON // Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}` out := []byte(`{"model":"","request":{"contents":[]}}`)
out, _ = sjson.Set(out, "model", modelName) out, _ = sjson.SetBytes(out, "model", modelName)
// Inject interleaved thinking hint when both tools and thinking are active // Inject interleaved thinking hint when both tools and thinking are active
hasTools := toolDeclCount > 0 hasTools := toolDeclCount > 0
@@ -408,27 +427,27 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if hasSystemInstruction { if hasSystemInstruction {
// Append hint as a new part to existing system instruction // Append hint as a new part to existing system instruction
hintPart := `{"text":""}` hintPart := []byte(`{"text":""}`)
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint) hintPart, _ = sjson.SetBytes(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart) systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", hintPart)
} else { } else {
// Create new system instruction with hint // Create new system instruction with hint
systemInstructionJSON = `{"role":"user","parts":[]}` systemInstructionJSON = []byte(`{"role":"user","parts":[]}`)
hintPart := `{"text":""}` hintPart := []byte(`{"text":""}`)
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint) hintPart, _ = sjson.SetBytes(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart) systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", hintPart)
hasSystemInstruction = true hasSystemInstruction = true
} }
} }
if hasSystemInstruction { if hasSystemInstruction {
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON) out, _ = sjson.SetRawBytes(out, "request.systemInstruction", systemInstructionJSON)
} }
if hasContents { if hasContents {
out, _ = sjson.SetRaw(out, "request.contents", contentsJSON) out, _ = sjson.SetRawBytes(out, "request.contents", contentsJSON)
} }
if toolDeclCount > 0 { if toolDeclCount > 0 {
out, _ = sjson.SetRaw(out, "request.tools", toolsJSON) out, _ = sjson.SetRawBytes(out, "request.tools", toolsJSON)
} }
// tool_choice // tool_choice
@@ -445,15 +464,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
switch toolChoiceType { switch toolChoiceType {
case "auto": case "auto":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO") out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
case "none": case "none":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE") out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
case "any": case "any":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY") out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
case "tool": case "tool":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY") out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
if toolChoiceName != "" { if toolChoiceName != "" {
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName}) out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)})
} }
} }
} }
@@ -464,8 +483,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
case "enabled": case "enabled":
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number { if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := int(b.Int()) budget := int(b.Int())
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
} }
case "adaptive", "auto": case "adaptive", "auto":
// For adaptive thinking: // For adaptive thinking:
@@ -477,28 +496,27 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
effort = strings.ToLower(strings.TrimSpace(v.String())) effort = strings.ToLower(strings.TrimSpace(v.String()))
} }
if effort != "" { if effort != "" {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
} else { } else {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high") out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
} }
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
} }
} }
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number { if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.temperature", v.Num) out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", v.Num)
} }
if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number { if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topP", v.Num) out, _ = sjson.SetBytes(out, "request.generationConfig.topP", v.Num)
} }
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number { if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num) out, _ = sjson.SetBytes(out, "request.generationConfig.topK", v.Num)
} }
if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() && v.Type == gjson.Number { if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.maxOutputTokens", v.Num) out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", v.Num)
} }
outBytes := []byte(out) out = common.AttachDefaultSafetySettings(out, "request.safetySettings")
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")
return outBytes return out
} }

View File

@@ -365,6 +365,17 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
inputJSON := []byte(`{ inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620", "model": "claude-3-5-sonnet-20240620",
"messages": [ "messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "get_weather-call-123",
"name": "get_weather",
"input": {"location": "Paris"}
}
]
},
{ {
"role": "user", "role": "user",
"content": [ "content": [
@@ -382,13 +393,177 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
outputStr := string(output) outputStr := string(output)
// Check function response conversion // Check function response conversion
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse") funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
if !funcResp.Exists() { if !funcResp.Exists() {
t.Error("functionResponse should exist") t.Error("functionResponse should exist")
} }
if funcResp.Get("id").String() != "get_weather-call-123" { if funcResp.Get("id").String() != "get_weather-call-123" {
t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String()) t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String())
} }
if funcResp.Get("name").String() != "get_weather" {
t.Errorf("Expected function name 'get_weather', got '%s'", funcResp.Get("name").String())
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_TouluFormat(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-haiku-4-5-20251001",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
"name": "Glob",
"input": {"pattern": "**/*.py"}
},
{
"type": "tool_use",
"id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
"name": "Bash",
"input": {"command": "ls"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
"content": "file1.py\nfile2.py"
},
{
"type": "tool_result",
"tool_use_id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
"content": "total 10"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
outputStr := string(output)
funcResp0 := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
if !funcResp0.Exists() {
t.Fatal("first functionResponse should exist")
}
if got := funcResp0.Get("name").String(); got != "Glob" {
t.Errorf("Expected name 'Glob' for toolu_ format, got '%s'", got)
}
funcResp1 := gjson.Get(outputStr, "request.contents.1.parts.1.functionResponse")
if !funcResp1.Exists() {
t.Fatal("second functionResponse should exist")
}
if got := funcResp1.Get("name").String(); got != "Bash" {
t.Errorf("Expected name 'Bash' for toolu_ format, got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_CustomFormat(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-haiku-4-5-20251001",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "Read-1773420180464065165-1327",
"name": "Read",
"input": {"file_path": "/tmp/test.py"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Read-1773420180464065165-1327",
"content": "file content here"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
outputStr := string(output)
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
if got := funcResp.Get("name").String(); got != "Read" {
t.Errorf("Expected name 'Read', got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_Heuristic(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "get_weather-call-123",
"content": "22C sunny"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
if got := funcResp.Get("name").String(); got != "get_weather" {
t.Errorf("Expected heuristic-derived name 'get_weather', got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_RawID(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
"content": "result data"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
got := funcResp.Get("name").String()
if got == "" {
t.Error("functionResponse.name must not be empty")
}
if got != "toolu_tool-48fca351f12844eabf49dad8b63886d2" {
t.Errorf("Expected raw ID as last-resort name, got '%s'", got)
}
} }
func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) { func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {

View File

@@ -15,6 +15,7 @@ import (
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache" "github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -43,6 +44,10 @@ type Params struct {
// Signature caching support // Signature caching support
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
// Reverse map: sanitized Gemini function name → original Claude tool name.
// Populated lazily on the first response chunk from the original request JSON.
ToolNameMap map[string]string
} }
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers. // toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
@@ -63,13 +68,14 @@ var toolUseIDCounter uint64
// - param: A pointer to a parameter object for maintaining state between calls // - param: A pointer to a parameter object for maintaining state between calls
// //
// Returns: // Returns:
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response // - [][]byte: A slice of bytes, each containing a Claude Code-compatible SSE payload.
func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil { if *param == nil {
*param = &Params{ *param = &Params{
HasFirstResponse: false, HasFirstResponse: false,
ResponseType: 0, ResponseType: 0,
ResponseIndex: 0, ResponseIndex: 0,
ToolNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
} }
} }
modelName := gjson.GetBytes(requestRawJSON, "model").String() modelName := gjson.GetBytes(requestRawJSON, "model").String()
@@ -77,44 +83,44 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
params := (*param).(*Params) params := (*param).(*Params)
if bytes.Equal(rawJSON, []byte("[DONE]")) { if bytes.Equal(rawJSON, []byte("[DONE]")) {
output := "" output := make([]byte, 0, 256)
// Only send final events if we have actually output content // Only send final events if we have actually output content
if params.HasContent { if params.HasContent {
appendFinalEvents(params, &output, true) appendFinalEvents(params, &output, true)
return []string{ output = translatorcommon.AppendSSEEventString(output, "message_stop", `{"type":"message_stop"}`, 3)
output + "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n\n", return [][]byte{output}
}
} }
return []string{} return [][]byte{}
} }
output := "" output := make([]byte, 0, 1024)
appendEvent := func(event, payload string) {
output = translatorcommon.AppendSSEEventString(output, event, payload, 3)
}
// Initialize the streaming session with a message_start event // Initialize the streaming session with a message_start event
// This is only sent for the very first response chunk to establish the streaming session // This is only sent for the very first response chunk to establish the streaming session
if !params.HasFirstResponse { if !params.HasFirstResponse {
output = "event: message_start\n"
// Create the initial message structure with default values according to Claude Code API specification // Create the initial message structure with default values according to Claude Code API specification
// This follows the Claude Code API specification for streaming message initialization // This follows the Claude Code API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}` messageStartTemplate := []byte(`{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`)
// Use cpaUsageMetadata within the message_start event for Claude. // Use cpaUsageMetadata within the message_start event for Claude.
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() { if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int()) messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
} }
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() { if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int()) messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
} }
// Override default values with actual response metadata if available from the Gemini CLI response // Override default values with actual response metadata if available from the Gemini CLI response
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String()) messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.model", modelVersionResult.String())
} }
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() { if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String()) messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.id", responseIDResult.String())
} }
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate) appendEvent("message_start", string(messageStartTemplate))
params.HasFirstResponse = true params.HasFirstResponse = true
} }
@@ -144,15 +150,13 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
params.CurrentThinkingText.Reset() params.CurrentThinkingText.Reset()
} }
output = output + "event: content_block_delta\n" data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex)), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String())) appendEvent("content_block_delta", string(data))
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.HasContent = true params.HasContent = true
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state } else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
params.CurrentThinkingText.WriteString(partTextResult.String()) params.CurrentThinkingText.WriteString(partTextResult.String())
output = output + "event: content_block_delta\n" data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex)), "delta.thinking", partTextResult.String())
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String()) appendEvent("content_block_delta", string(data))
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.HasContent = true params.HasContent = true
} else { } else {
// Transition from another state to thinking // Transition from another state to thinking
@@ -163,19 +167,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex) // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
// output = output + "\n\n\n" // output = output + "\n\n\n"
} }
output = output + "event: content_block_stop\n" appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
params.ResponseIndex++ params.ResponseIndex++
} }
// Start a new thinking content block // Start a new thinking content block
output = output + "event: content_block_start\n" appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex))
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex) data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex)), "delta.thinking", partTextResult.String())
output = output + "\n\n\n" appendEvent("content_block_delta", string(data))
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.ResponseType = 2 // Set state to thinking params.ResponseType = 2 // Set state to thinking
params.HasContent = true params.HasContent = true
// Start accumulating thinking text for signature caching // Start accumulating thinking text for signature caching
@@ -188,9 +187,8 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Process regular text content (user-visible output) // Process regular text content (user-visible output)
// Continue existing text block if already in content state // Continue existing text block if already in content state
if params.ResponseType == 1 { if params.ResponseType == 1 {
output = output + "event: content_block_delta\n" data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex)), "delta.text", partTextResult.String())
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String()) appendEvent("content_block_delta", string(data))
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.HasContent = true params.HasContent = true
} else { } else {
// Transition from another state to text content // Transition from another state to text content
@@ -201,19 +199,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex) // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
// output = output + "\n\n\n" // output = output + "\n\n\n"
} }
output = output + "event: content_block_stop\n" appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
params.ResponseIndex++ params.ResponseIndex++
} }
if partTextResult.String() != "" { if partTextResult.String() != "" {
// Start a new text content block // Start a new text content block
output = output + "event: content_block_start\n" appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex))
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex) data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex)), "delta.text", partTextResult.String())
output = output + "\n\n\n" appendEvent("content_block_delta", string(data))
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.ResponseType = 1 // Set state to content params.ResponseType = 1 // Set state to content
params.HasContent = true params.HasContent = true
} }
@@ -224,14 +217,12 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Handle function/tool calls from the AI model // Handle function/tool calls from the AI model
// This processes tool usage requests and formats them for Claude Code API compatibility // This processes tool usage requests and formats them for Claude Code API compatibility
params.HasToolUse = true params.HasToolUse = true
fcName := functionCallResult.Get("name").String() fcName := util.RestoreSanitizedToolName(params.ToolNameMap, functionCallResult.Get("name").String())
// Handle state transitions when switching to function calls // Handle state transitions when switching to function calls
// Close any existing function call block first // Close any existing function call block first
if params.ResponseType == 3 { if params.ResponseType == 3 {
output = output + "event: content_block_stop\n" appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
params.ResponseIndex++ params.ResponseIndex++
params.ResponseType = 0 params.ResponseType = 0
} }
@@ -245,26 +236,21 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Close any other existing content block // Close any other existing content block
if params.ResponseType != 0 { if params.ResponseType != 0 {
output = output + "event: content_block_stop\n" appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
params.ResponseIndex++ params.ResponseIndex++
} }
// Start a new tool use content block // Start a new tool use content block
// This creates the structure for a function call in Claude Code format // This creates the structure for a function call in Claude Code format
output = output + "event: content_block_start\n"
// Create the tool use block with unique ID and function details // Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex) data := []byte(fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex))
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))) data, _ = sjson.SetBytes(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", fcName) data, _ = sjson.SetBytes(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data) appendEvent("content_block_start", string(data))
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n" data, _ = sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex)), "delta.partial_json", fcArgsResult.Raw)
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex), "delta.partial_json", fcArgsResult.Raw) appendEvent("content_block_delta", string(data))
output = output + fmt.Sprintf("data: %s\n\n\n", data)
} }
params.ResponseType = 3 params.ResponseType = 3
params.HasContent = true params.HasContent = true
@@ -296,10 +282,10 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
appendFinalEvents(params, &output, false) appendFinalEvents(params, &output, false)
} }
return []string{output} return [][]byte{output}
} }
func appendFinalEvents(params *Params, output *string, force bool) { func appendFinalEvents(params *Params, output *[]byte, force bool) {
if params.HasSentFinalEvents { if params.HasSentFinalEvents {
return return
} }
@@ -314,9 +300,7 @@ func appendFinalEvents(params *Params, output *string, force bool) {
} }
if params.ResponseType != 0 { if params.ResponseType != 0 {
*output = *output + "event: content_block_stop\n" *output = translatorcommon.AppendSSEEventString(*output, "content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex), 3)
*output = *output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
*output = *output + "\n\n\n"
params.ResponseType = 0 params.ResponseType = 0
} }
@@ -329,18 +313,16 @@ func appendFinalEvents(params *Params, output *string, force bool) {
} }
} }
*output = *output + "event: message_delta\n" delta := []byte(fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens))
*output = *output + "data: "
delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working) // Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
if params.CachedTokenCount > 0 { if params.CachedTokenCount > 0 {
var err error var err error
delta, err = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount) delta, err = sjson.SetBytes(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
if err != nil { if err != nil {
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err) log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
} }
} }
*output = *output + delta + "\n\n\n" *output = translatorcommon.AppendSSEEventString(*output, "message_delta", string(delta), 3)
params.HasSentFinalEvents = true params.HasSentFinalEvents = true
} }
@@ -369,9 +351,9 @@ func resolveStopReason(params *Params) string {
// - param: A pointer to a parameter object for the conversion. // - param: A pointer to a parameter object for the conversion.
// //
// Returns: // Returns:
// - string: A Claude-compatible JSON response. // - []byte: A Claude-compatible JSON response.
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
_ = originalRequestRawJSON toolNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
modelName := gjson.GetBytes(requestRawJSON, "model").String() modelName := gjson.GetBytes(requestRawJSON, "model").String()
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
@@ -388,15 +370,15 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
} }
} }
responseJSON := `{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` responseJSON := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
responseJSON, _ = sjson.Set(responseJSON, "id", root.Get("response.responseId").String()) responseJSON, _ = sjson.SetBytes(responseJSON, "id", root.Get("response.responseId").String())
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String()) responseJSON, _ = sjson.SetBytes(responseJSON, "model", root.Get("response.modelVersion").String())
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens) responseJSON, _ = sjson.SetBytes(responseJSON, "usage.input_tokens", promptTokens)
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens) responseJSON, _ = sjson.SetBytes(responseJSON, "usage.output_tokens", outputTokens)
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working) // Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
if cachedTokens > 0 { if cachedTokens > 0 {
var err error var err error
responseJSON, err = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens) responseJSON, err = sjson.SetBytes(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
if err != nil { if err != nil {
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err) log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
} }
@@ -407,7 +389,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
if contentArrayInitialized { if contentArrayInitialized {
return return
} }
responseJSON, _ = sjson.SetRaw(responseJSON, "content", "[]") responseJSON, _ = sjson.SetRawBytes(responseJSON, "content", []byte("[]"))
contentArrayInitialized = true contentArrayInitialized = true
} }
@@ -423,9 +405,9 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
return return
} }
ensureContentArray() ensureContentArray()
block := `{"type":"text","text":""}` block := []byte(`{"type":"text","text":""}`)
block, _ = sjson.Set(block, "text", textBuilder.String()) block, _ = sjson.SetBytes(block, "text", textBuilder.String())
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block) responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", block)
textBuilder.Reset() textBuilder.Reset()
} }
@@ -434,12 +416,12 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
return return
} }
ensureContentArray() ensureContentArray()
block := `{"type":"thinking","thinking":""}` block := []byte(`{"type":"thinking","thinking":""}`)
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String()) block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
if thinkingSignature != "" { if thinkingSignature != "" {
block, _ = sjson.Set(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature)) block, _ = sjson.SetBytes(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
} }
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block) responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", block)
thinkingBuilder.Reset() thinkingBuilder.Reset()
thinkingSignature = "" thinkingSignature = ""
} }
@@ -473,18 +455,18 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
flushText() flushText()
hasToolCall = true hasToolCall = true
name := functionCall.Get("name").String() name := util.RestoreSanitizedToolName(toolNameMap, functionCall.Get("name").String())
toolIDCounter++ toolIDCounter++
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}` toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter)) toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "name", name) toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) && args.IsObject() { if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) && args.IsObject() {
toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw) toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(args.Raw))
} }
ensureContentArray() ensureContentArray()
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", toolBlock) responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", toolBlock)
continue continue
} }
} }
@@ -508,17 +490,17 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
} }
} }
} }
responseJSON, _ = sjson.Set(responseJSON, "stop_reason", stopReason) responseJSON, _ = sjson.SetBytes(responseJSON, "stop_reason", stopReason)
if promptTokens == 0 && outputTokens == 0 { if promptTokens == 0 && outputTokens == 0 {
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() { if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
responseJSON, _ = sjson.Delete(responseJSON, "usage") responseJSON, _ = sjson.DeleteBytes(responseJSON, "usage")
} }
} }
return responseJSON return responseJSON
} }
func ClaudeTokenCount(ctx context.Context, count int64) string { func ClaudeTokenCount(ctx context.Context, count int64) []byte {
return fmt.Sprintf(`{"input_tokens":%d}`, count) return translatorcommon.ClaudeInputTokensJSON(count)
} }

View File

@@ -34,10 +34,10 @@ import (
// - []byte: The transformed request data in Gemini API format // - []byte: The transformed request data in Gemini API format
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte { func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON rawJSON := inputRawJSON
template := "" template := `{"project":"","request":{},"model":""}`
template = `{"project":"","request":{},"model":""}` templateBytes, _ := sjson.SetRawBytes([]byte(template), "request", rawJSON)
template, _ = sjson.SetRaw(template, "request", string(rawJSON)) templateBytes, _ = sjson.SetBytes(templateBytes, "model", modelName)
template, _ = sjson.Set(template, "model", modelName) template = string(templateBytes)
template, _ = sjson.Delete(template, "request.model") template, _ = sjson.Delete(template, "request.model")
template, errFixCLIToolResponse := fixCLIToolResponse(template) template, errFixCLIToolResponse := fixCLIToolResponse(template)
@@ -47,7 +47,8 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
systemInstructionResult := gjson.Get(template, "request.system_instruction") systemInstructionResult := gjson.Get(template, "request.system_instruction")
if systemInstructionResult.Exists() { if systemInstructionResult.Exists() {
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) templateBytes, _ = sjson.SetRawBytes([]byte(template), "request.systemInstruction", []byte(systemInstructionResult.Raw))
template = string(templateBytes)
template, _ = sjson.Delete(template, "request.system_instruction") template, _ = sjson.Delete(template, "request.system_instruction")
} }
rawJSON = []byte(template) rawJSON = []byte(template)
@@ -138,30 +139,47 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
// FunctionCallGroup represents a group of function calls and their responses // FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct { type FunctionCallGroup struct {
ResponsesNeeded int ResponsesNeeded int
CallNames []string // ordered function call names for backfilling empty response names
} }
// parseFunctionResponseRaw attempts to normalize a function response part into a JSON object string. // parseFunctionResponseRaw attempts to normalize a function response part into a JSON object string.
// Falls back to a minimal "functionResponse" object when parsing fails. // Falls back to a minimal "functionResponse" object when parsing fails.
func parseFunctionResponseRaw(response gjson.Result) string { // fallbackName is used when the response's own name is empty.
func parseFunctionResponseRaw(response gjson.Result, fallbackName string) string {
if response.IsObject() && gjson.Valid(response.Raw) { if response.IsObject() && gjson.Valid(response.Raw) {
return response.Raw raw := response.Raw
name := response.Get("functionResponse.name").String()
if strings.TrimSpace(name) == "" && fallbackName != "" {
updated, _ := sjson.SetBytes([]byte(raw), "functionResponse.name", fallbackName)
raw = string(updated)
}
return raw
} }
log.Debugf("parse function response failed, using fallback") log.Debugf("parse function response failed, using fallback")
funcResp := response.Get("functionResponse") funcResp := response.Get("functionResponse")
if funcResp.Exists() { if funcResp.Exists() {
fr := `{"functionResponse":{"name":"","response":{"result":""}}}` fr := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
fr, _ = sjson.Set(fr, "functionResponse.name", funcResp.Get("name").String()) name := funcResp.Get("name").String()
fr, _ = sjson.Set(fr, "functionResponse.response.result", funcResp.Get("response").String()) if strings.TrimSpace(name) == "" {
if id := funcResp.Get("id").String(); id != "" { name = fallbackName
fr, _ = sjson.Set(fr, "functionResponse.id", id)
} }
return fr fr, _ = sjson.SetBytes(fr, "functionResponse.name", name)
fr, _ = sjson.SetBytes(fr, "functionResponse.response.result", funcResp.Get("response").String())
if id := funcResp.Get("id").String(); id != "" {
fr, _ = sjson.SetBytes(fr, "functionResponse.id", id)
}
return string(fr)
} }
fr := `{"functionResponse":{"name":"unknown","response":{"result":""}}}` useName := fallbackName
fr, _ = sjson.Set(fr, "functionResponse.response.result", response.String()) if useName == "" {
return fr useName = "unknown"
}
fr := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
fr, _ = sjson.SetBytes(fr, "functionResponse.name", useName)
fr, _ = sjson.SetBytes(fr, "functionResponse.response.result", response.String())
return string(fr)
} }
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping. // fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
@@ -188,7 +206,7 @@ func fixCLIToolResponse(input string) (string, error) {
} }
// Initialize data structures for processing and grouping // Initialize data structures for processing and grouping
contentsWrapper := `{"contents":[]}` contentsWrapper := []byte(`{"contents":[]}`)
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched var collectedResponses []gjson.Result // Standalone responses to be matched
@@ -211,30 +229,26 @@ func fixCLIToolResponse(input string) (string, error) {
if len(responsePartsInThisContent) > 0 { if len(responsePartsInThisContent) > 0 {
collectedResponses = append(collectedResponses, responsePartsInThisContent...) collectedResponses = append(collectedResponses, responsePartsInThisContent...)
// Check if any pending groups can be satisfied // Check if pending groups can be satisfied (FIFO: oldest group first)
for i := len(pendingGroups) - 1; i >= 0; i-- { for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded {
group := pendingGroups[i] group := pendingGroups[0]
if len(collectedResponses) >= group.ResponsesNeeded { pendingGroups = pendingGroups[1:]
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content // Take the needed responses for this group
functionResponseContent := `{"parts":[],"role":"function"}` groupResponses := collectedResponses[:group.ResponsesNeeded]
for _, response := range groupResponses { collectedResponses = collectedResponses[group.ResponsesNeeded:]
partRaw := parseFunctionResponseRaw(response)
if partRaw != "" { // Create merged function response content
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw) functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
} for ri, response := range groupResponses {
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
if partRaw != "" {
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(partRaw))
} }
}
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 { if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent) contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
}
// Remove this group as it's been satisfied
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
break
} }
} }
@@ -243,25 +257,26 @@ func fixCLIToolResponse(input string) (string, error) {
// If this is a model with function calls, create a new group // If this is a model with function calls, create a new group
if role == "model" { if role == "model" {
functionCallsCount := 0 var callNames []string
parts.ForEach(func(_, part gjson.Result) bool { parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() { if part.Get("functionCall").Exists() {
functionCallsCount++ callNames = append(callNames, part.Get("functionCall.name").String())
} }
return true return true
}) })
if functionCallsCount > 0 { if len(callNames) > 0 {
// Add the model content // Add the model content
if !value.IsObject() { if !value.IsObject() {
log.Warnf("failed to parse model content") log.Warnf("failed to parse model content")
return true return true
} }
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw) contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
// Create a new group for tracking responses // Create a new group for tracking responses
group := &FunctionCallGroup{ group := &FunctionCallGroup{
ResponsesNeeded: functionCallsCount, ResponsesNeeded: len(callNames),
CallNames: callNames,
} }
pendingGroups = append(pendingGroups, group) pendingGroups = append(pendingGroups, group)
} else { } else {
@@ -270,7 +285,7 @@ func fixCLIToolResponse(input string) (string, error) {
log.Warnf("failed to parse content") log.Warnf("failed to parse content")
return true return true
} }
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw) contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
} }
} else { } else {
// Non-model content (user, etc.) // Non-model content (user, etc.)
@@ -278,7 +293,7 @@ func fixCLIToolResponse(input string) (string, error) {
log.Warnf("failed to parse content") log.Warnf("failed to parse content")
return true return true
} }
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw) contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
} }
return true return true
@@ -290,23 +305,22 @@ func fixCLIToolResponse(input string) (string, error) {
groupResponses := collectedResponses[:group.ResponsesNeeded] groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:] collectedResponses = collectedResponses[group.ResponsesNeeded:]
functionResponseContent := `{"parts":[],"role":"function"}` functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
for _, response := range groupResponses { for ri, response := range groupResponses {
partRaw := parseFunctionResponseRaw(response) partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
if partRaw != "" { if partRaw != "" {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw) functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(partRaw))
} }
} }
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 { if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent) contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
} }
} }
} }
// Update the original JSON with the new contents // Update the original JSON with the new contents
result := input result, _ := sjson.SetRawBytes([]byte(input), "request.contents", []byte(gjson.GetBytes(contentsWrapper, "contents").Raw))
result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
return result, nil return string(result), nil
} }

View File

@@ -171,3 +171,257 @@ func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String()) t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
} }
} }
func TestFixCLIToolResponse_BackfillsEmptyFunctionResponseName(t *testing.T) {
// When the Amp client sends functionResponse with an empty name,
// fixCLIToolResponse should backfill it from the corresponding functionCall.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"output": "file1.txt"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
name := funcContent.Get("parts.0.functionResponse.name").String()
if name != "Bash" {
t.Errorf("Expected backfilled name 'Bash', got '%s'", name)
}
}
func TestFixCLIToolResponse_BackfillsMultipleEmptyNames(t *testing.T) {
// Parallel function calls: both responses have empty names.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Read", "args": {"path": "/a"}}},
{"functionCall": {"name": "Grep", "args": {"pattern": "x"}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "content a"}}},
{"functionResponse": {"name": "", "response": {"result": "match x"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
parts := funcContent.Get("parts").Array()
if len(parts) != 2 {
t.Fatalf("Expected 2 function response parts, got %d", len(parts))
}
name0 := parts[0].Get("functionResponse.name").String()
name1 := parts[1].Get("functionResponse.name").String()
if name0 != "Read" {
t.Errorf("Expected first response name 'Read', got '%s'", name0)
}
if name1 != "Grep" {
t.Errorf("Expected second response name 'Grep', got '%s'", name1)
}
}
func TestFixCLIToolResponse_PreservesExistingName(t *testing.T) {
// When functionResponse already has a valid name, it should be preserved.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "Bash", "response": {"result": "ok"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
name := funcContent.Get("parts.0.functionResponse.name").String()
if name != "Bash" {
t.Errorf("Expected preserved name 'Bash', got '%s'", name)
}
}
func TestFixCLIToolResponse_MoreResponsesThanCalls(t *testing.T) {
// If there are more function responses than calls, unmatched extras are discarded by grouping.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "ok"}}},
{"functionResponse": {"name": "", "response": {"result": "extra"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
// First response should be backfilled from the call
name0 := funcContent.Get("parts.0.functionResponse.name").String()
if name0 != "Bash" {
t.Errorf("Expected first response name 'Bash', got '%s'", name0)
}
}
func TestFixCLIToolResponse_MultipleGroupsFIFO(t *testing.T) {
// Two sequential function call groups should be matched FIFO.
input := `{
"model": "gemini-3-pro-preview",
"request": {
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Read", "args": {}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "file content"}}}
]
},
{
"role": "model",
"parts": [
{"functionCall": {"name": "Grep", "args": {}}}
]
},
{
"role": "function",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "match"}}}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
contents := gjson.Get(result, "request.contents").Array()
var funcContents []gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContents = append(funcContents, c)
}
}
if len(funcContents) != 2 {
t.Fatalf("Expected 2 function contents, got %d", len(funcContents))
}
name0 := funcContents[0].Get("parts.0.functionResponse.name").String()
name1 := funcContents[1].Get("parts.0.functionResponse.name").String()
if name0 != "Read" {
t.Errorf("Expected first group name 'Read', got '%s'", name0)
}
if name1 != "Grep" {
t.Errorf("Expected second group name 'Grep', got '%s'", name1)
}
}

View File

@@ -8,8 +8,8 @@ package gemini
import ( import (
"bytes" "bytes"
"context" "context"
"fmt"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -29,8 +29,8 @@ import (
// - param: A pointer to a parameter object for the conversion (unused in current implementation) // - param: A pointer to a parameter object for the conversion (unused in current implementation)
// //
// Returns: // Returns:
// - []string: The transformed request data in Gemini API format // - [][]byte: The transformed response data in Gemini API format.
func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string { func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) [][]byte {
if bytes.HasPrefix(rawJSON, []byte("data:")) { if bytes.HasPrefix(rawJSON, []byte("data:")) {
rawJSON = bytes.TrimSpace(rawJSON[5:]) rawJSON = bytes.TrimSpace(rawJSON[5:])
} }
@@ -44,22 +44,22 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR
chunk = restoreUsageMetadata(chunk) chunk = restoreUsageMetadata(chunk)
} }
} else { } else {
chunkTemplate := "[]" chunkTemplate := []byte("[]")
responseResult := gjson.ParseBytes(chunk) responseResult := gjson.ParseBytes(chunk)
if responseResult.IsArray() { if responseResult.IsArray() {
responseResultItems := responseResult.Array() responseResultItems := responseResult.Array()
for i := 0; i < len(responseResultItems); i++ { for i := 0; i < len(responseResultItems); i++ {
responseResultItem := responseResultItems[i] responseResultItem := responseResultItems[i]
if responseResultItem.Get("response").Exists() { if responseResultItem.Get("response").Exists() {
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw) chunkTemplate, _ = sjson.SetRawBytes(chunkTemplate, "-1", []byte(responseResultItem.Get("response").Raw))
} }
} }
} }
chunk = []byte(chunkTemplate) chunk = chunkTemplate
} }
return []string{string(chunk)} return [][]byte{chunk}
} }
return []string{} return [][]byte{}
} }
// ConvertAntigravityResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response. // ConvertAntigravityResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response.
@@ -73,18 +73,18 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR
// - param: A pointer to a parameter object for the conversion (unused in current implementation) // - param: A pointer to a parameter object for the conversion (unused in current implementation)
// //
// Returns: // Returns:
// - string: A Gemini-compatible JSON response containing the response data // - []byte: A Gemini-compatible JSON response containing the response data.
func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
responseResult := gjson.GetBytes(rawJSON, "response") responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() { if responseResult.Exists() {
chunk := restoreUsageMetadata([]byte(responseResult.Raw)) chunk := restoreUsageMetadata([]byte(responseResult.Raw))
return string(chunk) return chunk
} }
return string(rawJSON) return rawJSON
} }
func GeminiTokenCount(ctx context.Context, count int64) string { func GeminiTokenCount(ctx context.Context, count int64) []byte {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count) return translatorcommon.GeminiTokenCountJSON(count)
} }
// restoreUsageMetadata renames cpaUsageMetadata back to usageMetadata. // restoreUsageMetadata renames cpaUsageMetadata back to usageMetadata.

View File

@@ -59,8 +59,8 @@ func TestConvertAntigravityResponseToGeminiNonStream(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := ConvertAntigravityResponseToGeminiNonStream(context.Background(), "", nil, nil, tt.input, nil) result := ConvertAntigravityResponseToGeminiNonStream(context.Background(), "", nil, nil, tt.input, nil)
if result != tt.expected { if string(result) != tt.expected {
t.Errorf("ConvertAntigravityResponseToGeminiNonStream() = %s, want %s", result, tt.expected) t.Errorf("ConvertAntigravityResponseToGeminiNonStream() = %s, want %s", string(result), tt.expected)
} }
}) })
} }
@@ -87,8 +87,8 @@ func TestConvertAntigravityResponseToGeminiStream(t *testing.T) {
if len(results) != 1 { if len(results) != 1 {
t.Fatalf("expected 1 result, got %d", len(results)) t.Fatalf("expected 1 result, got %d", len(results))
} }
if results[0] != tt.expected { if string(results[0]) != tt.expected {
t.Errorf("ConvertAntigravityResponseToGemini() = %s, want %s", results[0], tt.expected) t.Errorf("ConvertAntigravityResponseToGemini() = %s, want %s", string(results[0]), tt.expected)
} }
}) })
} }

View File

@@ -286,7 +286,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
continue continue
} }
fid := tc.Get("id").String() fid := tc.Get("id").String()
fname := tc.Get("function.name").String() fname := util.SanitizeFunctionName(tc.Get("function.name").String())
fargs := tc.Get("function.arguments").String() fargs := tc.Get("function.arguments").String()
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
@@ -309,7 +309,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
for _, fid := range fIDs { for _, fid := range fIDs {
if name, ok := tcID2Name[fid]; ok { if name, ok := tcID2Name[fid]; ok {
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid) toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name) toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", util.SanitizeFunctionName(name))
resp := toolResponses[fid] resp := toolResponses[fid]
if resp == "" { if resp == "" {
resp = "{}" resp = "{}"
@@ -354,33 +354,39 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if errRename != nil { if errRename != nil {
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename) log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
var errSet error var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object") fnRawBytes, errSet := sjson.SetBytes([]byte(fnRaw), "parametersJsonSchema.type", "object")
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`) fnRaw = string(fnRawBytes)
fnRawBytes, errSet = sjson.SetRawBytes([]byte(fnRaw), "parametersJsonSchema.properties", []byte(`{}`))
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw = string(fnRawBytes)
} else { } else {
fnRaw = renamed fnRaw = renamed
} }
} else { } else {
var errSet error var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object") fnRawBytes, errSet := sjson.SetBytes([]byte(fnRaw), "parametersJsonSchema.type", "object")
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`) fnRaw = string(fnRawBytes)
fnRawBytes, errSet = sjson.SetRawBytes([]byte(fnRaw), "parametersJsonSchema.properties", []byte(`{}`))
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw = string(fnRawBytes)
} }
fnRaw, _ = sjson.Delete(fnRaw, "strict") fnRawBytes := []byte(fnRaw)
fnRawBytes, _ = sjson.SetBytes(fnRawBytes, "name", util.SanitizeFunctionName(fn.Get("name").String()))
fnRaw, _ = sjson.Delete(string(fnRawBytes), "strict")
if !hasFunction { if !hasFunction {
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]")) functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
} }

View File

@@ -13,6 +13,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions" . "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
@@ -26,6 +27,7 @@ type convertCliResponseToOpenAIChatParams struct {
FunctionIndex int FunctionIndex int
SawToolCall bool // Tracks if any tool call was seen in the entire stream SawToolCall bool // Tracks if any tool call was seen in the entire stream
UpstreamFinishReason string // Caches the upstream finish reason for final chunk UpstreamFinishReason string // Caches the upstream finish reason for final chunk
SanitizedNameMap map[string]string
} }
// functionCallIDCounter provides a process-wide unique counter for function call identifiers. // functionCallIDCounter provides a process-wide unique counter for function call identifiers.
@@ -44,25 +46,29 @@ var functionCallIDCounter uint64
// - param: A pointer to a parameter object for maintaining state between calls // - param: A pointer to a parameter object for maintaining state between calls
// //
// Returns: // Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response // - [][]byte: A slice of OpenAI-compatible JSON responses
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil { if *param == nil {
*param = &convertCliResponseToOpenAIChatParams{ *param = &convertCliResponseToOpenAIChatParams{
UnixTimestamp: 0, UnixTimestamp: 0,
FunctionIndex: 0, FunctionIndex: 0,
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
} }
} }
if (*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap == nil {
(*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
}
if bytes.Equal(rawJSON, []byte("[DONE]")) { if bytes.Equal(rawJSON, []byte("[DONE]")) {
return []string{} return [][]byte{}
} }
// Initialize the OpenAI SSE template. // Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
// Extract and set the model version. // Extract and set the model version.
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String()) template, _ = sjson.SetBytes(template, "model", modelVersionResult.String())
} }
// Extract and set the creation timestamp. // Extract and set the creation timestamp.
@@ -71,14 +77,14 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
if err == nil { if err == nil {
(*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp = t.Unix() (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp = t.Unix()
} }
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp) template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
} else { } else {
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp) template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
} }
// Extract and set the response ID. // Extract and set the response ID.
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() { if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
template, _ = sjson.Set(template, "id", responseIDResult.String()) template, _ = sjson.SetBytes(template, "id", responseIDResult.String())
} }
// Cache the finish reason - do NOT set it in output yet (will be set on final chunk) // Cache the finish reason - do NOT set it in output yet (will be set on final chunk)
@@ -90,21 +96,21 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() { if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int() cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) template, _ = sjson.SetBytes(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
} }
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokenCountResult.Int())
} }
promptTokenCount := usageResult.Get("promptTokenCount").Int() promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount) template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 { if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
} }
// Include cached token count if present (indicates prompt caching is working) // Include cached token count if present (indicates prompt caching is working)
if cachedTokenCount > 0 { if cachedTokenCount > 0 {
var err error var err error
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount) template, err = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
if err != nil { if err != nil {
log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err) log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
} }
@@ -141,33 +147,33 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
// Handle text content, distinguishing between regular content and reasoning/thoughts. // Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() { if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", textContent) template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", textContent)
} else { } else {
template, _ = sjson.Set(template, "choices.0.delta.content", textContent) template, _ = sjson.SetBytes(template, "choices.0.delta.content", textContent)
} }
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() { } else if functionCallResult.Exists() {
// Handle function call content. // Handle function call content.
(*param).(*convertCliResponseToOpenAIChatParams).SawToolCall = true // Persist across chunks (*param).(*convertCliResponseToOpenAIChatParams).SawToolCall = true // Persist across chunks
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls") toolCallsResult := gjson.GetBytes(template, "choices.0.delta.tool_calls")
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++ (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
if toolCallsResult.Exists() && toolCallsResult.IsArray() { if toolCallsResult.Exists() && toolCallsResult.IsArray() {
functionCallIndex = len(toolCallsResult.Array()) functionCallIndex = len(toolCallsResult.Array())
} else { } else {
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`) template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
} }
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}` functionCallTemplate := []byte(`{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`)
fcName := functionCallResult.Get("name").String() fcName := util.RestoreSanitizedToolName((*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap, functionCallResult.Get("name").String())
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1))) functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex) functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex)
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName) functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw) functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
} }
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate) template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
} else if inlineDataResult.Exists() { } else if inlineDataResult.Exists() {
data := inlineDataResult.Get("data").String() data := inlineDataResult.Get("data").String()
if data == "" { if data == "" {
@@ -181,16 +187,16 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
mimeType = "image/png" mimeType = "image/png"
} }
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagesResult := gjson.Get(template, "choices.0.delta.images") imagesResult := gjson.GetBytes(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() { if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`) template, _ = sjson.SetRawBytes(template, "choices.0.delta.images", []byte(`[]`))
} }
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array()) imageIndex := len(gjson.GetBytes(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}` imagePayload := []byte(`{"type":"image_url","image_url":{"url":""}}`)
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex) imagePayload, _ = sjson.SetBytes(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL) imagePayload, _ = sjson.SetBytes(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload) template, _ = sjson.SetRawBytes(template, "choices.0.delta.images.-1", imagePayload)
} }
} }
} }
@@ -212,11 +218,11 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
} else { } else {
finishReason = "stop" finishReason = "stop"
} }
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason) template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason)) template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
} }
return []string{template} return [][]byte{template}
} }
// ConvertAntigravityResponseToOpenAINonStream converts a non-streaming Gemini CLI response to a non-streaming OpenAI response. // ConvertAntigravityResponseToOpenAINonStream converts a non-streaming Gemini CLI response to a non-streaming OpenAI response.
@@ -231,11 +237,11 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
// - param: A pointer to a parameter object for the conversion // - param: A pointer to a parameter object for the conversion
// //
// Returns: // Returns:
// - string: An OpenAI-compatible JSON response containing all message content and metadata // - []byte: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertAntigravityResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string { func ConvertAntigravityResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
responseResult := gjson.GetBytes(rawJSON, "response") responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() { if responseResult.Exists() {
return ConvertGeminiResponseToOpenAINonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, []byte(responseResult.Raw), param) return ConvertGeminiResponseToOpenAINonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, []byte(responseResult.Raw), param)
} }
return "" return []byte{}
} }

View File

@@ -19,7 +19,7 @@ func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
if len(result1) != 1 { if len(result1) != 1 {
t.Fatalf("Expected 1 result from chunk1, got %d", len(result1)) t.Fatalf("Expected 1 result from chunk1, got %d", len(result1))
} }
fr1 := gjson.Get(result1[0], "choices.0.finish_reason") fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason")
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" { if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
t.Errorf("Expected finish_reason to be null in chunk1, got: %v", fr1.String()) t.Errorf("Expected finish_reason to be null in chunk1, got: %v", fr1.String())
} }
@@ -33,13 +33,13 @@ func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
if len(result2) != 1 { if len(result2) != 1 {
t.Fatalf("Expected 1 result from chunk2, got %d", len(result2)) t.Fatalf("Expected 1 result from chunk2, got %d", len(result2))
} }
fr2 := gjson.Get(result2[0], "choices.0.finish_reason").String() fr2 := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
if fr2 != "tool_calls" { if fr2 != "tool_calls" {
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr2) t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr2)
} }
// Verify native_finish_reason is lowercase upstream value // Verify native_finish_reason is lowercase upstream value
nfr2 := gjson.Get(result2[0], "choices.0.native_finish_reason").String() nfr2 := gjson.GetBytes(result2[0], "choices.0.native_finish_reason").String()
if nfr2 != "stop" { if nfr2 != "stop" {
t.Errorf("Expected native_finish_reason 'stop', got: %s", nfr2) t.Errorf("Expected native_finish_reason 'stop', got: %s", nfr2)
} }
@@ -58,7 +58,7 @@ func TestFinishReasonStopForNormalText(t *testing.T) {
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param) result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param)
// Verify finish_reason is "stop" (no tool calls were made) // Verify finish_reason is "stop" (no tool calls were made)
fr := gjson.Get(result2[0], "choices.0.finish_reason").String() fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
if fr != "stop" { if fr != "stop" {
t.Errorf("Expected finish_reason 'stop', got: %s", fr) t.Errorf("Expected finish_reason 'stop', got: %s", fr)
} }
@@ -77,7 +77,7 @@ func TestFinishReasonMaxTokens(t *testing.T) {
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param) result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param)
// Verify finish_reason is "max_tokens" // Verify finish_reason is "max_tokens"
fr := gjson.Get(result2[0], "choices.0.finish_reason").String() fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
if fr != "max_tokens" { if fr != "max_tokens" {
t.Errorf("Expected finish_reason 'max_tokens', got: %s", fr) t.Errorf("Expected finish_reason 'max_tokens', got: %s", fr)
} }
@@ -96,7 +96,7 @@ func TestToolCallTakesPriorityOverMaxTokens(t *testing.T) {
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param) result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param)
// Verify finish_reason is "tool_calls" (takes priority over max_tokens) // Verify finish_reason is "tool_calls" (takes priority over max_tokens)
fr := gjson.Get(result2[0], "choices.0.finish_reason").String() fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
if fr != "tool_calls" { if fr != "tool_calls" {
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr) t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr)
} }
@@ -111,7 +111,7 @@ func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, &param) result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, &param)
// Verify no finish_reason on intermediate chunk // Verify no finish_reason on intermediate chunk
fr1 := gjson.Get(result1[0], "choices.0.finish_reason") fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason")
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" { if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr1) t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr1)
} }
@@ -121,7 +121,7 @@ func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param) result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param)
// Verify no finish_reason on intermediate chunk // Verify no finish_reason on intermediate chunk
fr2 := gjson.Get(result2[0], "choices.0.finish_reason") fr2 := gjson.GetBytes(result2[0], "choices.0.finish_reason")
if fr2.Exists() && fr2.String() != "" && fr2.Type.String() != "Null" { if fr2.Exists() && fr2.String() != "" && fr2.Type.String() != "Null" {
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr2) t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr2)
} }

View File

@@ -7,7 +7,7 @@ import (
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
responseResult := gjson.GetBytes(rawJSON, "response") responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() { if responseResult.Exists() {
rawJSON = []byte(responseResult.Raw) rawJSON = []byte(responseResult.Raw)
@@ -15,7 +15,7 @@ func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName
return ConvertGeminiResponseToOpenAIResponses(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param) return ConvertGeminiResponseToOpenAIResponses(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
} }
func ConvertAntigravityResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string { func ConvertAntigravityResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
responseResult := gjson.GetBytes(rawJSON, "response") responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() { if responseResult.Exists() {
rawJSON = []byte(responseResult.Raw) rawJSON = []byte(responseResult.Raw)

View File

@@ -8,7 +8,7 @@ import (
"context" "context"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini" . "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini"
"github.com/tidwall/sjson" translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
) )
// ConvertClaudeResponseToGeminiCLI converts Claude Code streaming response format to Gemini CLI format. // ConvertClaudeResponseToGeminiCLI converts Claude Code streaming response format to Gemini CLI format.
@@ -23,15 +23,13 @@ import (
// - param: A pointer to a parameter object for maintaining state between calls // - param: A pointer to a parameter object for maintaining state between calls
// //
// Returns: // Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response wrapped in a response object // - [][]byte: A slice of Gemini-compatible JSON responses wrapped in a response object
func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
outputs := ConvertClaudeResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param) outputs := ConvertClaudeResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
// Wrap each converted response in a "response" object to match Gemini CLI API structure // Wrap each converted response in a "response" object to match Gemini CLI API structure
newOutputs := make([]string, 0) newOutputs := make([][]byte, 0, len(outputs))
for i := 0; i < len(outputs); i++ { for i := 0; i < len(outputs); i++ {
json := `{"response": {}}` newOutputs = append(newOutputs, translatorcommon.WrapGeminiCLIResponse(outputs[i]))
output, _ := sjson.SetRaw(json, "response", outputs[i])
newOutputs = append(newOutputs, output)
} }
return newOutputs return newOutputs
} }
@@ -47,15 +45,13 @@ func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, ori
// - param: A pointer to a parameter object for the conversion // - param: A pointer to a parameter object for the conversion
// //
// Returns: // Returns:
// - string: A Gemini-compatible JSON response wrapped in a response object // - []byte: A Gemini-compatible JSON response wrapped in a response object
func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string { func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
strJSON := ConvertClaudeResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param) out := ConvertClaudeResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
// Wrap the converted response in a "response" object to match Gemini CLI API structure // Wrap the converted response in a "response" object to match Gemini CLI API structure
json := `{"response": {}}` return translatorcommon.WrapGeminiCLIResponse(out)
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
return strJSON
} }
func GeminiCLITokenCount(ctx context.Context, count int64) string { func GeminiCLITokenCount(ctx context.Context, count int64) []byte {
return GeminiTokenCount(ctx, count) return GeminiTokenCount(ctx, count)
} }

View File

@@ -63,7 +63,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session) userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
// Base Claude message payload // Base Claude message payload
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID) out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
@@ -87,20 +87,20 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
var pendingToolIDs []string var pendingToolIDs []string
// Model mapping to specify which Claude Code model to use // Model mapping to specify which Claude Code model to use
out, _ = sjson.Set(out, "model", modelName) out, _ = sjson.SetBytes(out, "model", modelName)
// Generation config extraction from Gemini format // Generation config extraction from Gemini format
if genConfig := root.Get("generationConfig"); genConfig.Exists() { if genConfig := root.Get("generationConfig"); genConfig.Exists() {
// Max output tokens configuration // Max output tokens configuration
if maxTokens := genConfig.Get("maxOutputTokens"); maxTokens.Exists() { if maxTokens := genConfig.Get("maxOutputTokens"); maxTokens.Exists() {
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int()) out, _ = sjson.SetBytes(out, "max_tokens", maxTokens.Int())
} }
// Temperature setting for controlling response randomness // Temperature setting for controlling response randomness
if temp := genConfig.Get("temperature"); temp.Exists() { if temp := genConfig.Get("temperature"); temp.Exists() {
out, _ = sjson.Set(out, "temperature", temp.Float()) out, _ = sjson.SetBytes(out, "temperature", temp.Float())
} else if topP := genConfig.Get("topP"); topP.Exists() { } else if topP := genConfig.Get("topP"); topP.Exists() {
// Top P setting for nucleus sampling (filtered out if temperature is set) // Top P setting for nucleus sampling (filtered out if temperature is set)
out, _ = sjson.Set(out, "top_p", topP.Float()) out, _ = sjson.SetBytes(out, "top_p", topP.Float())
} }
// Stop sequences configuration for custom termination conditions // Stop sequences configuration for custom termination conditions
if stopSeqs := genConfig.Get("stopSequences"); stopSeqs.Exists() && stopSeqs.IsArray() { if stopSeqs := genConfig.Get("stopSequences"); stopSeqs.Exists() && stopSeqs.IsArray() {
@@ -110,7 +110,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
return true return true
}) })
if len(stopSequences) > 0 { if len(stopSequences) > 0 {
out, _ = sjson.Set(out, "stop_sequences", stopSequences) out, _ = sjson.SetBytes(out, "stop_sequences", stopSequences)
} }
} }
// Include thoughts configuration for reasoning process visibility // Include thoughts configuration for reasoning process visibility
@@ -132,30 +132,30 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
switch level { switch level {
case "": case "":
case "none": case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled") out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens") out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort") out, _ = sjson.DeleteBytes(out, "output_config.effort")
default: default:
if mapped, ok := thinking.MapToClaudeEffort(level, supportsMax); ok { if mapped, ok := thinking.MapToClaudeEffort(level, supportsMax); ok {
level = mapped level = mapped
} }
out, _ = sjson.Set(out, "thinking.type", "adaptive") out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens") out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.Set(out, "output_config.effort", level) out, _ = sjson.SetBytes(out, "output_config.effort", level)
} }
} else { } else {
switch level { switch level {
case "": case "":
case "none": case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled") out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens") out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
case "auto": case "auto":
out, _ = sjson.Set(out, "thinking.type", "enabled") out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens") out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
default: default:
if budget, ok := thinking.ConvertLevelToBudget(level); ok { if budget, ok := thinking.ConvertLevelToBudget(level); ok {
out, _ = sjson.Set(out, "thinking.type", "enabled") out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget) out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
} }
} }
} }
@@ -169,37 +169,37 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
if supportsAdaptive { if supportsAdaptive {
switch budget { switch budget {
case 0: case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled") out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens") out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort") out, _ = sjson.DeleteBytes(out, "output_config.effort")
default: default:
level, ok := thinking.ConvertBudgetToLevel(budget) level, ok := thinking.ConvertBudgetToLevel(budget)
if ok { if ok {
if mapped, okM := thinking.MapToClaudeEffort(level, supportsMax); okM { if mapped, okM := thinking.MapToClaudeEffort(level, supportsMax); okM {
level = mapped level = mapped
} }
out, _ = sjson.Set(out, "thinking.type", "adaptive") out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens") out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.Set(out, "output_config.effort", level) out, _ = sjson.SetBytes(out, "output_config.effort", level)
} }
} }
} else { } else {
switch budget { switch budget {
case 0: case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled") out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens") out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
case -1: case -1:
out, _ = sjson.Set(out, "thinking.type", "enabled") out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens") out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
default: default:
out, _ = sjson.Set(out, "thinking.type", "enabled") out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget) out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
} }
} }
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True { } else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
out, _ = sjson.Set(out, "thinking.type", "enabled") out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True { } else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
out, _ = sjson.Set(out, "thinking.type", "enabled") out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
} }
} }
} }
@@ -220,9 +220,9 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
}) })
if systemText.Len() > 0 { if systemText.Len() > 0 {
// Create system message in Claude Code format // Create system message in Claude Code format
systemMessage := `{"role":"user","content":[{"type":"text","text":""}]}` systemMessage := []byte(`{"role":"user","content":[{"type":"text","text":""}]}`)
systemMessage, _ = sjson.Set(systemMessage, "content.0.text", systemText.String()) systemMessage, _ = sjson.SetBytes(systemMessage, "content.0.text", systemText.String())
out, _ = sjson.SetRaw(out, "messages.-1", systemMessage) out, _ = sjson.SetRawBytes(out, "messages.-1", systemMessage)
} }
} }
} }
@@ -245,42 +245,42 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
} }
// Create message structure in Claude Code format // Create message structure in Claude Code format
msg := `{"role":"","content":[]}` msg := []byte(`{"role":"","content":[]}`)
msg, _ = sjson.Set(msg, "role", role) msg, _ = sjson.SetBytes(msg, "role", role)
if parts := content.Get("parts"); parts.Exists() && parts.IsArray() { if parts := content.Get("parts"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool { parts.ForEach(func(_, part gjson.Result) bool {
// Text content conversion // Text content conversion
if text := part.Get("text"); text.Exists() { if text := part.Get("text"); text.Exists() {
textContent := `{"type":"text","text":""}` textContent := []byte(`{"type":"text","text":""}`)
textContent, _ = sjson.Set(textContent, "text", text.String()) textContent, _ = sjson.SetBytes(textContent, "text", text.String())
msg, _ = sjson.SetRaw(msg, "content.-1", textContent) msg, _ = sjson.SetRawBytes(msg, "content.-1", textContent)
return true return true
} }
// Function call (from model/assistant) conversion to tool use // Function call (from model/assistant) conversion to tool use
if fc := part.Get("functionCall"); fc.Exists() && role == "assistant" { if fc := part.Get("functionCall"); fc.Exists() && role == "assistant" {
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}` toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
// Generate a unique tool ID and enqueue it for later matching // Generate a unique tool ID and enqueue it for later matching
// with the corresponding functionResponse // with the corresponding functionResponse
toolID := genToolCallID() toolID := genToolCallID()
pendingToolIDs = append(pendingToolIDs, toolID) pendingToolIDs = append(pendingToolIDs, toolID)
toolUse, _ = sjson.Set(toolUse, "id", toolID) toolUse, _ = sjson.SetBytes(toolUse, "id", toolID)
if name := fc.Get("name"); name.Exists() { if name := fc.Get("name"); name.Exists() {
toolUse, _ = sjson.Set(toolUse, "name", name.String()) toolUse, _ = sjson.SetBytes(toolUse, "name", name.String())
} }
if args := fc.Get("args"); args.Exists() && args.IsObject() { if args := fc.Get("args"); args.Exists() && args.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", args.Raw) toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(args.Raw))
} }
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse) msg, _ = sjson.SetRawBytes(msg, "content.-1", toolUse)
return true return true
} }
// Function response (from user) conversion to tool result // Function response (from user) conversion to tool result
if fr := part.Get("functionResponse"); fr.Exists() { if fr := part.Get("functionResponse"); fr.Exists() {
toolResult := `{"type":"tool_result","tool_use_id":"","content":""}` toolResult := []byte(`{"type":"tool_result","tool_use_id":"","content":""}`)
// Attach the oldest queued tool_id to pair the response // Attach the oldest queued tool_id to pair the response
// with its call. If the queue is empty, generate a new id. // with its call. If the queue is empty, generate a new id.
@@ -293,41 +293,41 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
// Fallback: generate new ID if no pending tool_use found // Fallback: generate new ID if no pending tool_use found
toolID = genToolCallID() toolID = genToolCallID()
} }
toolResult, _ = sjson.Set(toolResult, "tool_use_id", toolID) toolResult, _ = sjson.SetBytes(toolResult, "tool_use_id", toolID)
// Extract result content from the function response // Extract result content from the function response
if result := fr.Get("response.result"); result.Exists() { if result := fr.Get("response.result"); result.Exists() {
toolResult, _ = sjson.Set(toolResult, "content", result.String()) toolResult, _ = sjson.SetBytes(toolResult, "content", result.String())
} else if response := fr.Get("response"); response.Exists() { } else if response := fr.Get("response"); response.Exists() {
toolResult, _ = sjson.Set(toolResult, "content", response.Raw) toolResult, _ = sjson.SetBytes(toolResult, "content", response.Raw)
} }
msg, _ = sjson.SetRaw(msg, "content.-1", toolResult) msg, _ = sjson.SetRawBytes(msg, "content.-1", toolResult)
return true return true
} }
// Image content (inline_data) conversion to Claude Code format // Image content (inline_data) conversion to Claude Code format
if inlineData := part.Get("inline_data"); inlineData.Exists() { if inlineData := part.Get("inline_data"); inlineData.Exists() {
imageContent := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}` imageContent := []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
if mimeType := inlineData.Get("mime_type"); mimeType.Exists() { if mimeType := inlineData.Get("mime_type"); mimeType.Exists() {
imageContent, _ = sjson.Set(imageContent, "source.media_type", mimeType.String()) imageContent, _ = sjson.SetBytes(imageContent, "source.media_type", mimeType.String())
} }
if data := inlineData.Get("data"); data.Exists() { if data := inlineData.Get("data"); data.Exists() {
imageContent, _ = sjson.Set(imageContent, "source.data", data.String()) imageContent, _ = sjson.SetBytes(imageContent, "source.data", data.String())
} }
msg, _ = sjson.SetRaw(msg, "content.-1", imageContent) msg, _ = sjson.SetRawBytes(msg, "content.-1", imageContent)
return true return true
} }
// File data conversion to text content with file info // File data conversion to text content with file info
if fileData := part.Get("file_data"); fileData.Exists() { if fileData := part.Get("file_data"); fileData.Exists() {
// For file data, we'll convert to text content with file info // For file data, we'll convert to text content with file info
textContent := `{"type":"text","text":""}` textContent := []byte(`{"type":"text","text":""}`)
fileInfo := "File: " + fileData.Get("file_uri").String() fileInfo := "File: " + fileData.Get("file_uri").String()
if mimeType := fileData.Get("mime_type"); mimeType.Exists() { if mimeType := fileData.Get("mime_type"); mimeType.Exists() {
fileInfo += " (Type: " + mimeType.String() + ")" fileInfo += " (Type: " + mimeType.String() + ")"
} }
textContent, _ = sjson.Set(textContent, "text", fileInfo) textContent, _ = sjson.SetBytes(textContent, "text", fileInfo)
msg, _ = sjson.SetRaw(msg, "content.-1", textContent) msg, _ = sjson.SetRawBytes(msg, "content.-1", textContent)
return true return true
} }
@@ -336,8 +336,8 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
} }
// Only add message if it has content // Only add message if it has content
if contentArray := gjson.Get(msg, "content"); contentArray.Exists() && len(contentArray.Array()) > 0 { if contentArray := gjson.GetBytes(msg, "content"); contentArray.Exists() && len(contentArray.Array()) > 0 {
out, _ = sjson.SetRaw(out, "messages.-1", msg) out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
} }
return true return true
@@ -351,29 +351,29 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
tools.ForEach(func(_, tool gjson.Result) bool { tools.ForEach(func(_, tool gjson.Result) bool {
if funcDecls := tool.Get("functionDeclarations"); funcDecls.Exists() && funcDecls.IsArray() { if funcDecls := tool.Get("functionDeclarations"); funcDecls.Exists() && funcDecls.IsArray() {
funcDecls.ForEach(func(_, funcDecl gjson.Result) bool { funcDecls.ForEach(func(_, funcDecl gjson.Result) bool {
anthropicTool := `{"name":"","description":"","input_schema":{}}` anthropicTool := []byte(`{"name":"","description":"","input_schema":{}}`)
if name := funcDecl.Get("name"); name.Exists() { if name := funcDecl.Get("name"); name.Exists() {
anthropicTool, _ = sjson.Set(anthropicTool, "name", name.String()) anthropicTool, _ = sjson.SetBytes(anthropicTool, "name", name.String())
} }
if desc := funcDecl.Get("description"); desc.Exists() { if desc := funcDecl.Get("description"); desc.Exists() {
anthropicTool, _ = sjson.Set(anthropicTool, "description", desc.String()) anthropicTool, _ = sjson.SetBytes(anthropicTool, "description", desc.String())
} }
if params := funcDecl.Get("parameters"); params.Exists() { if params := funcDecl.Get("parameters"); params.Exists() {
// Clean up the parameters schema for Claude Code compatibility // Clean up the parameters schema for Claude Code compatibility
cleaned := params.Raw cleaned := []byte(params.Raw)
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false) cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
cleaned, _ = sjson.Set(cleaned, "$schema", "http://json-schema.org/draft-07/schema#") cleaned, _ = sjson.SetBytes(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", cleaned) anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", cleaned)
} else if params = funcDecl.Get("parametersJsonSchema"); params.Exists() { } else if params = funcDecl.Get("parametersJsonSchema"); params.Exists() {
// Clean up the parameters schema for Claude Code compatibility // Clean up the parameters schema for Claude Code compatibility
cleaned := params.Raw cleaned := []byte(params.Raw)
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false) cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
cleaned, _ = sjson.Set(cleaned, "$schema", "http://json-schema.org/draft-07/schema#") cleaned, _ = sjson.SetBytes(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", cleaned) anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", cleaned)
} }
anthropicTools = append(anthropicTools, gjson.Parse(anthropicTool).Value()) anthropicTools = append(anthropicTools, gjson.ParseBytes(anthropicTool).Value())
return true return true
}) })
} }
@@ -381,7 +381,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
}) })
if len(anthropicTools) > 0 { if len(anthropicTools) > 0 {
out, _ = sjson.Set(out, "tools", anthropicTools) out, _ = sjson.SetBytes(out, "tools", anthropicTools)
} }
} }
@@ -391,27 +391,27 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
if mode := funcCalling.Get("mode"); mode.Exists() { if mode := funcCalling.Get("mode"); mode.Exists() {
switch mode.String() { switch mode.String() {
case "AUTO": case "AUTO":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`) out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
case "NONE": case "NONE":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"none"}`) out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"none"}`))
case "ANY": case "ANY":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`) out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
} }
} }
} }
} }
// Stream setting configuration // Stream setting configuration
out, _ = sjson.Set(out, "stream", stream) out, _ = sjson.SetBytes(out, "stream", stream)
// Convert tool parameter types to lowercase for Claude Code compatibility // Convert tool parameter types to lowercase for Claude Code compatibility
var pathsToLower []string var pathsToLower []string
toolsResult := gjson.Get(out, "tools") toolsResult := gjson.GetBytes(out, "tools")
util.Walk(toolsResult, "", "type", &pathsToLower) util.Walk(toolsResult, "", "type", &pathsToLower)
for _, p := range pathsToLower { for _, p := range pathsToLower {
fullPath := fmt.Sprintf("tools.%s", p) fullPath := fmt.Sprintf("tools.%s", p)
out, _ = sjson.Set(out, fullPath, strings.ToLower(gjson.Get(out, fullPath).String())) out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(gjson.GetBytes(out, fullPath).String()))
} }
return []byte(out) return out
} }

View File

@@ -9,10 +9,10 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"fmt"
"strings" "strings"
"time" "time"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -30,7 +30,7 @@ type ConvertAnthropicResponseToGeminiParams struct {
Model string Model string
CreatedAt int64 CreatedAt int64
ResponseID string ResponseID string
LastStorageOutput string LastStorageOutput []byte
IsStreaming bool IsStreaming bool
// Streaming state for tool_use assembly // Streaming state for tool_use assembly
@@ -52,8 +52,8 @@ type ConvertAnthropicResponseToGeminiParams struct {
// - param: A pointer to a parameter object for maintaining state between calls // - param: A pointer to a parameter object for maintaining state between calls
// //
// Returns: // Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response // - [][]byte: A slice of Gemini-compatible JSON responses
func ConvertClaudeResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { func ConvertClaudeResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil { if *param == nil {
*param = &ConvertAnthropicResponseToGeminiParams{ *param = &ConvertAnthropicResponseToGeminiParams{
Model: modelName, Model: modelName,
@@ -63,7 +63,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
} }
if !bytes.HasPrefix(rawJSON, dataTag) { if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{} return [][]byte{}
} }
rawJSON = bytes.TrimSpace(rawJSON[5:]) rawJSON = bytes.TrimSpace(rawJSON[5:])
@@ -71,24 +71,24 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
eventType := root.Get("type").String() eventType := root.Get("type").String()
// Base Gemini response template with default values // Base Gemini response template with default values
template := `{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}` template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
// Set model version // Set model version
if (*param).(*ConvertAnthropicResponseToGeminiParams).Model != "" { if (*param).(*ConvertAnthropicResponseToGeminiParams).Model != "" {
// Map Claude model names back to Gemini model names // Map Claude model names back to Gemini model names
template, _ = sjson.Set(template, "modelVersion", (*param).(*ConvertAnthropicResponseToGeminiParams).Model) template, _ = sjson.SetBytes(template, "modelVersion", (*param).(*ConvertAnthropicResponseToGeminiParams).Model)
} }
// Set response ID and creation time // Set response ID and creation time
if (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID != "" { if (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID != "" {
template, _ = sjson.Set(template, "responseId", (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID) template, _ = sjson.SetBytes(template, "responseId", (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID)
} }
// Set creation time to current time if not provided // Set creation time to current time if not provided
if (*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt == 0 { if (*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt == 0 {
(*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt = time.Now().Unix() (*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt = time.Now().Unix()
} }
template, _ = sjson.Set(template, "createTime", time.Unix((*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano)) template, _ = sjson.SetBytes(template, "createTime", time.Unix((*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
switch eventType { switch eventType {
case "message_start": case "message_start":
@@ -97,7 +97,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
(*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID = message.Get("id").String() (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID = message.Get("id").String()
(*param).(*ConvertAnthropicResponseToGeminiParams).Model = message.Get("model").String() (*param).(*ConvertAnthropicResponseToGeminiParams).Model = message.Get("model").String()
} }
return []string{} return [][]byte{}
case "content_block_start": case "content_block_start":
// Start of a content block - record tool_use name by index for functionCall assembly // Start of a content block - record tool_use name by index for functionCall assembly
@@ -112,7 +112,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
} }
} }
} }
return []string{} return [][]byte{}
case "content_block_delta": case "content_block_delta":
// Handle content delta (text, thinking, or tool use arguments) // Handle content delta (text, thinking, or tool use arguments)
@@ -123,16 +123,16 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
case "text_delta": case "text_delta":
// Regular text content delta for normal response text // Regular text content delta for normal response text
if text := delta.Get("text"); text.Exists() && text.String() != "" { if text := delta.Get("text"); text.Exists() && text.String() != "" {
textPart := `{"text":""}` textPart := []byte(`{"text":""}`)
textPart, _ = sjson.Set(textPart, "text", text.String()) textPart, _ = sjson.SetBytes(textPart, "text", text.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", textPart) template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", textPart)
} }
case "thinking_delta": case "thinking_delta":
// Thinking/reasoning content delta for models with reasoning capabilities // Thinking/reasoning content delta for models with reasoning capabilities
if text := delta.Get("thinking"); text.Exists() && text.String() != "" { if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
thinkingPart := `{"thought":true,"text":""}` thinkingPart := []byte(`{"thought":true,"text":""}`)
thinkingPart, _ = sjson.Set(thinkingPart, "text", text.String()) thinkingPart, _ = sjson.SetBytes(thinkingPart, "text", text.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", thinkingPart) template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", thinkingPart)
} }
case "input_json_delta": case "input_json_delta":
// Tool use input delta - accumulate partial_json by index for later assembly at content_block_stop // Tool use input delta - accumulate partial_json by index for later assembly at content_block_stop
@@ -149,10 +149,10 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
if pj := delta.Get("partial_json"); pj.Exists() { if pj := delta.Get("partial_json"); pj.Exists() {
b.WriteString(pj.String()) b.WriteString(pj.String())
} }
return []string{} return [][]byte{}
} }
} }
return []string{template} return [][]byte{template}
case "content_block_stop": case "content_block_stop":
// End of content block - finalize tool calls if any // End of content block - finalize tool calls if any
@@ -170,16 +170,16 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
} }
} }
if name != "" || argsTrim != "" { if name != "" || argsTrim != "" {
functionCall := `{"functionCall":{"name":"","args":{}}}` functionCall := []byte(`{"functionCall":{"name":"","args":{}}}`)
if name != "" { if name != "" {
functionCall, _ = sjson.Set(functionCall, "functionCall.name", name) functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", name)
} }
if argsTrim != "" { if argsTrim != "" {
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsTrim) functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsTrim))
} }
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", functionCall) template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall)
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP") template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
(*param).(*ConvertAnthropicResponseToGeminiParams).LastStorageOutput = template (*param).(*ConvertAnthropicResponseToGeminiParams).LastStorageOutput = append([]byte(nil), template...)
// cleanup used state for this index // cleanup used state for this index
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs != nil { if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs != nil {
delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs, idx) delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs, idx)
@@ -187,9 +187,9 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames != nil { if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames != nil {
delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames, idx) delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames, idx)
} }
return []string{template} return [][]byte{template}
} }
return []string{} return [][]byte{}
case "message_delta": case "message_delta":
// Handle message-level changes (like stop reason and usage information) // Handle message-level changes (like stop reason and usage information)
@@ -197,15 +197,15 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
if stopReason := delta.Get("stop_reason"); stopReason.Exists() { if stopReason := delta.Get("stop_reason"); stopReason.Exists() {
switch stopReason.String() { switch stopReason.String() {
case "end_turn": case "end_turn":
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP") template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
case "tool_use": case "tool_use":
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP") template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
case "max_tokens": case "max_tokens":
template, _ = sjson.Set(template, "candidates.0.finishReason", "MAX_TOKENS") template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "MAX_TOKENS")
case "stop_sequence": case "stop_sequence":
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP") template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
default: default:
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP") template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
} }
} }
} }
@@ -216,35 +216,35 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
outputTokens := usage.Get("output_tokens").Int() outputTokens := usage.Get("output_tokens").Int()
// Set basic usage metadata according to Gemini API specification // Set basic usage metadata according to Gemini API specification
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens) template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", inputTokens)
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens) template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", outputTokens)
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", inputTokens+outputTokens) template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", inputTokens+outputTokens)
// Add cache-related token counts if present (Claude Code API cache fields) // Add cache-related token counts if present (Claude Code API cache fields)
if cacheCreationTokens := usage.Get("cache_creation_input_tokens"); cacheCreationTokens.Exists() { if cacheCreationTokens := usage.Get("cache_creation_input_tokens"); cacheCreationTokens.Exists() {
template, _ = sjson.Set(template, "usageMetadata.cachedContentTokenCount", cacheCreationTokens.Int()) template, _ = sjson.SetBytes(template, "usageMetadata.cachedContentTokenCount", cacheCreationTokens.Int())
} }
if cacheReadTokens := usage.Get("cache_read_input_tokens"); cacheReadTokens.Exists() { if cacheReadTokens := usage.Get("cache_read_input_tokens"); cacheReadTokens.Exists() {
// Add cache read tokens to cached content count // Add cache read tokens to cached content count
existingCacheTokens := usage.Get("cache_creation_input_tokens").Int() existingCacheTokens := usage.Get("cache_creation_input_tokens").Int()
totalCacheTokens := existingCacheTokens + cacheReadTokens.Int() totalCacheTokens := existingCacheTokens + cacheReadTokens.Int()
template, _ = sjson.Set(template, "usageMetadata.cachedContentTokenCount", totalCacheTokens) template, _ = sjson.SetBytes(template, "usageMetadata.cachedContentTokenCount", totalCacheTokens)
} }
// Add thinking tokens if present (for models with reasoning capabilities) // Add thinking tokens if present (for models with reasoning capabilities)
if thinkingTokens := usage.Get("thinking_tokens"); thinkingTokens.Exists() { if thinkingTokens := usage.Get("thinking_tokens"); thinkingTokens.Exists() {
template, _ = sjson.Set(template, "usageMetadata.thoughtsTokenCount", thinkingTokens.Int()) template, _ = sjson.SetBytes(template, "usageMetadata.thoughtsTokenCount", thinkingTokens.Int())
} }
// Set traffic type (required by Gemini API) // Set traffic type (required by Gemini API)
template, _ = sjson.Set(template, "usageMetadata.trafficType", "PROVISIONED_THROUGHPUT") template, _ = sjson.SetBytes(template, "usageMetadata.trafficType", "PROVISIONED_THROUGHPUT")
} }
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP") template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
return []string{template} return [][]byte{template}
case "message_stop": case "message_stop":
// Final message with usage information - no additional output needed // Final message with usage information - no additional output needed
return []string{} return [][]byte{}
case "error": case "error":
// Handle error responses and convert to Gemini error format // Handle error responses and convert to Gemini error format
errorMsg := root.Get("error.message").String() errorMsg := root.Get("error.message").String()
@@ -253,13 +253,13 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
} }
// Create error response in Gemini format // Create error response in Gemini format
errorResponse := `{"error":{"code":400,"message":"","status":"INVALID_ARGUMENT"}}` errorResponse := []byte(`{"error":{"code":400,"message":"","status":"INVALID_ARGUMENT"}}`)
errorResponse, _ = sjson.Set(errorResponse, "error.message", errorMsg) errorResponse, _ = sjson.SetBytes(errorResponse, "error.message", errorMsg)
return []string{errorResponse} return [][]byte{errorResponse}
default: default:
// Unknown event type, return empty response // Unknown event type, return empty response
return []string{} return [][]byte{}
} }
} }
@@ -275,13 +275,13 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
// - param: A pointer to a parameter object for the conversion (unused in current implementation) // - param: A pointer to a parameter object for the conversion (unused in current implementation)
// //
// Returns: // Returns:
// - string: A Gemini-compatible JSON response containing all message content and metadata // - []byte: A Gemini-compatible JSON response containing all message content and metadata
func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
// Base Gemini response template for non-streaming with default values // Base Gemini response template for non-streaming with default values
template := `{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}` template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
// Set model version // Set model version
template, _ = sjson.Set(template, "modelVersion", modelName) template, _ = sjson.SetBytes(template, "modelVersion", modelName)
streamingEvents := make([][]byte, 0) streamingEvents := make([][]byte, 0)
@@ -304,15 +304,15 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
Model: modelName, Model: modelName,
CreatedAt: 0, CreatedAt: 0,
ResponseID: "", ResponseID: "",
LastStorageOutput: "", LastStorageOutput: nil,
IsStreaming: false, IsStreaming: false,
ToolUseNames: nil, ToolUseNames: nil,
ToolUseArgs: nil, ToolUseArgs: nil,
} }
// Process each streaming event and collect parts // Process each streaming event and collect parts
var allParts []string var allParts [][]byte
var finalUsageJSON string var finalUsageJSON []byte
var responseID string var responseID string
var createdAt int64 var createdAt int64
@@ -360,15 +360,15 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
case "text_delta": case "text_delta":
// Process regular text content // Process regular text content
if text := delta.Get("text"); text.Exists() && text.String() != "" { if text := delta.Get("text"); text.Exists() && text.String() != "" {
partJSON := `{"text":""}` partJSON := []byte(`{"text":""}`)
partJSON, _ = sjson.Set(partJSON, "text", text.String()) partJSON, _ = sjson.SetBytes(partJSON, "text", text.String())
allParts = append(allParts, partJSON) allParts = append(allParts, partJSON)
} }
case "thinking_delta": case "thinking_delta":
// Process reasoning/thinking content // Process reasoning/thinking content
if text := delta.Get("thinking"); text.Exists() && text.String() != "" { if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
partJSON := `{"thought":true,"text":""}` partJSON := []byte(`{"thought":true,"text":""}`)
partJSON, _ = sjson.Set(partJSON, "text", text.String()) partJSON, _ = sjson.SetBytes(partJSON, "text", text.String())
allParts = append(allParts, partJSON) allParts = append(allParts, partJSON)
} }
case "input_json_delta": case "input_json_delta":
@@ -402,12 +402,12 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
} }
} }
if name != "" || argsTrim != "" { if name != "" || argsTrim != "" {
functionCallJSON := `{"functionCall":{"name":"","args":{}}}` functionCallJSON := []byte(`{"functionCall":{"name":"","args":{}}}`)
if name != "" { if name != "" {
functionCallJSON, _ = sjson.Set(functionCallJSON, "functionCall.name", name) functionCallJSON, _ = sjson.SetBytes(functionCallJSON, "functionCall.name", name)
} }
if argsTrim != "" { if argsTrim != "" {
functionCallJSON, _ = sjson.SetRaw(functionCallJSON, "functionCall.args", argsTrim) functionCallJSON, _ = sjson.SetRawBytes(functionCallJSON, "functionCall.args", []byte(argsTrim))
} }
allParts = append(allParts, functionCallJSON) allParts = append(allParts, functionCallJSON)
// cleanup used state for this index // cleanup used state for this index
@@ -422,35 +422,35 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
case "message_delta": case "message_delta":
// Extract final usage information using sjson for token counts and metadata // Extract final usage information using sjson for token counts and metadata
if usage := root.Get("usage"); usage.Exists() { if usage := root.Get("usage"); usage.Exists() {
usageJSON := `{}` usageJSON := []byte(`{}`)
// Basic token counts for prompt and completion // Basic token counts for prompt and completion
inputTokens := usage.Get("input_tokens").Int() inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int() outputTokens := usage.Get("output_tokens").Int()
// Set basic usage metadata according to Gemini API specification // Set basic usage metadata according to Gemini API specification
usageJSON, _ = sjson.Set(usageJSON, "promptTokenCount", inputTokens) usageJSON, _ = sjson.SetBytes(usageJSON, "promptTokenCount", inputTokens)
usageJSON, _ = sjson.Set(usageJSON, "candidatesTokenCount", outputTokens) usageJSON, _ = sjson.SetBytes(usageJSON, "candidatesTokenCount", outputTokens)
usageJSON, _ = sjson.Set(usageJSON, "totalTokenCount", inputTokens+outputTokens) usageJSON, _ = sjson.SetBytes(usageJSON, "totalTokenCount", inputTokens+outputTokens)
// Add cache-related token counts if present (Claude Code API cache fields) // Add cache-related token counts if present (Claude Code API cache fields)
if cacheCreationTokens := usage.Get("cache_creation_input_tokens"); cacheCreationTokens.Exists() { if cacheCreationTokens := usage.Get("cache_creation_input_tokens"); cacheCreationTokens.Exists() {
usageJSON, _ = sjson.Set(usageJSON, "cachedContentTokenCount", cacheCreationTokens.Int()) usageJSON, _ = sjson.SetBytes(usageJSON, "cachedContentTokenCount", cacheCreationTokens.Int())
} }
if cacheReadTokens := usage.Get("cache_read_input_tokens"); cacheReadTokens.Exists() { if cacheReadTokens := usage.Get("cache_read_input_tokens"); cacheReadTokens.Exists() {
// Add cache read tokens to cached content count // Add cache read tokens to cached content count
existingCacheTokens := usage.Get("cache_creation_input_tokens").Int() existingCacheTokens := usage.Get("cache_creation_input_tokens").Int()
totalCacheTokens := existingCacheTokens + cacheReadTokens.Int() totalCacheTokens := existingCacheTokens + cacheReadTokens.Int()
usageJSON, _ = sjson.Set(usageJSON, "cachedContentTokenCount", totalCacheTokens) usageJSON, _ = sjson.SetBytes(usageJSON, "cachedContentTokenCount", totalCacheTokens)
} }
// Add thinking tokens if present (for models with reasoning capabilities) // Add thinking tokens if present (for models with reasoning capabilities)
if thinkingTokens := usage.Get("thinking_tokens"); thinkingTokens.Exists() { if thinkingTokens := usage.Get("thinking_tokens"); thinkingTokens.Exists() {
usageJSON, _ = sjson.Set(usageJSON, "thoughtsTokenCount", thinkingTokens.Int()) usageJSON, _ = sjson.SetBytes(usageJSON, "thoughtsTokenCount", thinkingTokens.Int())
} }
// Set traffic type (required by Gemini API) // Set traffic type (required by Gemini API)
usageJSON, _ = sjson.Set(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT") usageJSON, _ = sjson.SetBytes(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT")
finalUsageJSON = usageJSON finalUsageJSON = usageJSON
} }
@@ -459,10 +459,10 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
// Set response metadata // Set response metadata
if responseID != "" { if responseID != "" {
template, _ = sjson.Set(template, "responseId", responseID) template, _ = sjson.SetBytes(template, "responseId", responseID)
} }
if createdAt > 0 { if createdAt > 0 {
template, _ = sjson.Set(template, "createTime", time.Unix(createdAt, 0).Format(time.RFC3339Nano)) template, _ = sjson.SetBytes(template, "createTime", time.Unix(createdAt, 0).Format(time.RFC3339Nano))
} }
// Consolidate consecutive text parts and thinking parts for cleaner output // Consolidate consecutive text parts and thinking parts for cleaner output
@@ -470,35 +470,35 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
// Set the consolidated parts array // Set the consolidated parts array
if len(consolidatedParts) > 0 { if len(consolidatedParts) > 0 {
partsJSON := "[]" partsJSON := []byte(`[]`)
for _, partJSON := range consolidatedParts { for _, partJSON := range consolidatedParts {
partsJSON, _ = sjson.SetRaw(partsJSON, "-1", partJSON) partsJSON, _ = sjson.SetRawBytes(partsJSON, "-1", partJSON)
} }
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", partsJSON) template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts", partsJSON)
} }
// Set usage metadata // Set usage metadata
if finalUsageJSON != "" { if len(finalUsageJSON) > 0 {
template, _ = sjson.SetRaw(template, "usageMetadata", finalUsageJSON) template, _ = sjson.SetRawBytes(template, "usageMetadata", finalUsageJSON)
} }
return template return template
} }
func GeminiTokenCount(ctx context.Context, count int64) string { func GeminiTokenCount(ctx context.Context, count int64) []byte {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count) return translatorcommon.GeminiTokenCountJSON(count)
} }
// consolidateParts merges consecutive text parts and thinking parts to create a cleaner response. // consolidateParts merges consecutive text parts and thinking parts to create a cleaner response.
// This function processes the parts array to combine adjacent text elements and thinking elements // This function processes the parts array to combine adjacent text elements and thinking elements
// into single consolidated parts, which results in a more readable and efficient response structure. // into single consolidated parts, which results in a more readable and efficient response structure.
// Tool calls and other non-text parts are preserved as separate elements. // Tool calls and other non-text parts are preserved as separate elements.
func consolidateParts(parts []string) []string { func consolidateParts(parts [][]byte) [][]byte {
if len(parts) == 0 { if len(parts) == 0 {
return parts return parts
} }
var consolidated []string var consolidated [][]byte
var currentTextPart strings.Builder var currentTextPart strings.Builder
var currentThoughtPart strings.Builder var currentThoughtPart strings.Builder
var hasText, hasThought bool var hasText, hasThought bool
@@ -506,8 +506,8 @@ func consolidateParts(parts []string) []string {
flushText := func() { flushText := func() {
// Flush accumulated text content to the consolidated parts array // Flush accumulated text content to the consolidated parts array
if hasText && currentTextPart.Len() > 0 { if hasText && currentTextPart.Len() > 0 {
textPartJSON := `{"text":""}` textPartJSON := []byte(`{"text":""}`)
textPartJSON, _ = sjson.Set(textPartJSON, "text", currentTextPart.String()) textPartJSON, _ = sjson.SetBytes(textPartJSON, "text", currentTextPart.String())
consolidated = append(consolidated, textPartJSON) consolidated = append(consolidated, textPartJSON)
currentTextPart.Reset() currentTextPart.Reset()
hasText = false hasText = false
@@ -517,8 +517,8 @@ func consolidateParts(parts []string) []string {
flushThought := func() { flushThought := func() {
// Flush accumulated thinking content to the consolidated parts array // Flush accumulated thinking content to the consolidated parts array
if hasThought && currentThoughtPart.Len() > 0 { if hasThought && currentThoughtPart.Len() > 0 {
thoughtPartJSON := `{"thought":true,"text":""}` thoughtPartJSON := []byte(`{"thought":true,"text":""}`)
thoughtPartJSON, _ = sjson.Set(thoughtPartJSON, "text", currentThoughtPart.String()) thoughtPartJSON, _ = sjson.SetBytes(thoughtPartJSON, "text", currentThoughtPart.String())
consolidated = append(consolidated, thoughtPartJSON) consolidated = append(consolidated, thoughtPartJSON)
currentThoughtPart.Reset() currentThoughtPart.Reset()
hasThought = false hasThought = false
@@ -526,7 +526,7 @@ func consolidateParts(parts []string) []string {
} }
for _, partJSON := range parts { for _, partJSON := range parts {
part := gjson.Parse(partJSON) part := gjson.ParseBytes(partJSON)
if !part.Exists() || !part.IsObject() { if !part.Exists() || !part.IsObject() {
// Flush any pending parts and add this non-text part // Flush any pending parts and add this non-text part
flushText() flushText()

View File

@@ -61,7 +61,7 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session) userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
// Base Claude Code API template with default max_tokens value // Base Claude Code API template with default max_tokens value
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID) out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
@@ -79,20 +79,20 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
if supportsAdaptive { if supportsAdaptive {
switch effort { switch effort {
case "none": case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled") out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens") out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort") out, _ = sjson.DeleteBytes(out, "output_config.effort")
case "auto": case "auto":
out, _ = sjson.Set(out, "thinking.type", "adaptive") out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens") out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort") out, _ = sjson.DeleteBytes(out, "output_config.effort")
default: default:
if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok { if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
effort = mapped effort = mapped
} }
out, _ = sjson.Set(out, "thinking.type", "adaptive") out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens") out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.Set(out, "output_config.effort", effort) out, _ = sjson.SetBytes(out, "output_config.effort", effort)
} }
} else { } else {
// Legacy/manual thinking (budget_tokens). // Legacy/manual thinking (budget_tokens).
@@ -100,13 +100,13 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
if ok { if ok {
switch budget { switch budget {
case 0: case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled") out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
case -1: case -1:
out, _ = sjson.Set(out, "thinking.type", "enabled") out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
default: default:
if budget > 0 { if budget > 0 {
out, _ = sjson.Set(out, "thinking.type", "enabled") out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget) out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
} }
} }
} }
@@ -128,19 +128,19 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
} }
// Model mapping to specify which Claude Code model to use // Model mapping to specify which Claude Code model to use
out, _ = sjson.Set(out, "model", modelName) out, _ = sjson.SetBytes(out, "model", modelName)
// Max tokens configuration with fallback to default value // Max tokens configuration with fallback to default value
if maxTokens := root.Get("max_tokens"); maxTokens.Exists() { if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int()) out, _ = sjson.SetBytes(out, "max_tokens", maxTokens.Int())
} }
// Temperature setting for controlling response randomness // Temperature setting for controlling response randomness
if temp := root.Get("temperature"); temp.Exists() { if temp := root.Get("temperature"); temp.Exists() {
out, _ = sjson.Set(out, "temperature", temp.Float()) out, _ = sjson.SetBytes(out, "temperature", temp.Float())
} else if topP := root.Get("top_p"); topP.Exists() { } else if topP := root.Get("top_p"); topP.Exists() {
// Top P setting for nucleus sampling (filtered out if temperature is set) // Top P setting for nucleus sampling (filtered out if temperature is set)
out, _ = sjson.Set(out, "top_p", topP.Float()) out, _ = sjson.SetBytes(out, "top_p", topP.Float())
} }
// Stop sequences configuration for custom termination conditions // Stop sequences configuration for custom termination conditions
@@ -152,15 +152,15 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
return true return true
}) })
if len(stopSequences) > 0 { if len(stopSequences) > 0 {
out, _ = sjson.Set(out, "stop_sequences", stopSequences) out, _ = sjson.SetBytes(out, "stop_sequences", stopSequences)
} }
} else { } else {
out, _ = sjson.Set(out, "stop_sequences", []string{stop.String()}) out, _ = sjson.SetBytes(out, "stop_sequences", []string{stop.String()})
} }
} }
// Stream configuration to enable or disable streaming responses // Stream configuration to enable or disable streaming responses
out, _ = sjson.Set(out, "stream", stream) out, _ = sjson.SetBytes(out, "stream", stream)
// Process messages and transform them to Claude Code format // Process messages and transform them to Claude Code format
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() { if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
@@ -173,39 +173,39 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
switch role { switch role {
case "system": case "system":
if systemMessageIndex == -1 { if systemMessageIndex == -1 {
systemMsg := `{"role":"user","content":[]}` systemMsg := []byte(`{"role":"user","content":[]}`)
out, _ = sjson.SetRaw(out, "messages.-1", systemMsg) out, _ = sjson.SetRawBytes(out, "messages.-1", systemMsg)
systemMessageIndex = messageIndex systemMessageIndex = messageIndex
messageIndex++ messageIndex++
} }
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" { if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
textPart := `{"type":"text","text":""}` textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.Set(textPart, "text", contentResult.String()) textPart, _ = sjson.SetBytes(textPart, "text", contentResult.String())
out, _ = sjson.SetRaw(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart) out, _ = sjson.SetRawBytes(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
} else if contentResult.Exists() && contentResult.IsArray() { } else if contentResult.Exists() && contentResult.IsArray() {
contentResult.ForEach(func(_, part gjson.Result) bool { contentResult.ForEach(func(_, part gjson.Result) bool {
if part.Get("type").String() == "text" { if part.Get("type").String() == "text" {
textPart := `{"type":"text","text":""}` textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String()) textPart, _ = sjson.SetBytes(textPart, "text", part.Get("text").String())
out, _ = sjson.SetRaw(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart) out, _ = sjson.SetRawBytes(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
} }
return true return true
}) })
} }
case "user", "assistant": case "user", "assistant":
msg := `{"role":"","content":[]}` msg := []byte(`{"role":"","content":[]}`)
msg, _ = sjson.Set(msg, "role", role) msg, _ = sjson.SetBytes(msg, "role", role)
// Handle content based on its type (string or array) // Handle content based on its type (string or array)
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" { if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
part := `{"type":"text","text":""}` part := []byte(`{"type":"text","text":""}`)
part, _ = sjson.Set(part, "text", contentResult.String()) part, _ = sjson.SetBytes(part, "text", contentResult.String())
msg, _ = sjson.SetRaw(msg, "content.-1", part) msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
} else if contentResult.Exists() && contentResult.IsArray() { } else if contentResult.Exists() && contentResult.IsArray() {
contentResult.ForEach(func(_, part gjson.Result) bool { contentResult.ForEach(func(_, part gjson.Result) bool {
claudePart := convertOpenAIContentPartToClaudePart(part) claudePart := convertOpenAIContentPartToClaudePart(part)
if claudePart != "" { if claudePart != "" {
msg, _ = sjson.SetRaw(msg, "content.-1", claudePart) msg, _ = sjson.SetRawBytes(msg, "content.-1", []byte(claudePart))
} }
return true return true
}) })
@@ -221,9 +221,9 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
} }
function := toolCall.Get("function") function := toolCall.Get("function")
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}` toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolUse, _ = sjson.Set(toolUse, "id", toolCallID) toolUse, _ = sjson.SetBytes(toolUse, "id", toolCallID)
toolUse, _ = sjson.Set(toolUse, "name", function.Get("name").String()) toolUse, _ = sjson.SetBytes(toolUse, "name", function.Get("name").String())
// Parse arguments for the tool call // Parse arguments for the tool call
if args := function.Get("arguments"); args.Exists() { if args := function.Get("arguments"); args.Exists() {
@@ -231,24 +231,24 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
if argsStr != "" && gjson.Valid(argsStr) { if argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr) argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() { if argsJSON.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw) toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(argsJSON.Raw))
} else { } else {
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}") toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
} }
} else { } else {
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}") toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
} }
} else { } else {
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}") toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
} }
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse) msg, _ = sjson.SetRawBytes(msg, "content.-1", toolUse)
} }
return true return true
}) })
} }
out, _ = sjson.SetRaw(out, "messages.-1", msg) out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
messageIndex++ messageIndex++
case "tool": case "tool":
@@ -256,15 +256,15 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
toolCallID := message.Get("tool_call_id").String() toolCallID := message.Get("tool_call_id").String()
toolContentResult := message.Get("content") toolContentResult := message.Get("content")
msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}` msg := []byte(`{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`)
msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID) msg, _ = sjson.SetBytes(msg, "content.0.tool_use_id", toolCallID)
toolResultContent, toolResultContentRaw := convertOpenAIToolResultContent(toolContentResult) toolResultContent, toolResultContentRaw := convertOpenAIToolResultContent(toolContentResult)
if toolResultContentRaw { if toolResultContentRaw {
msg, _ = sjson.SetRaw(msg, "content.0.content", toolResultContent) msg, _ = sjson.SetRawBytes(msg, "content.0.content", []byte(toolResultContent))
} else { } else {
msg, _ = sjson.Set(msg, "content.0.content", toolResultContent) msg, _ = sjson.SetBytes(msg, "content.0.content", toolResultContent)
} }
out, _ = sjson.SetRaw(out, "messages.-1", msg) out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
messageIndex++ messageIndex++
} }
return true return true
@@ -277,25 +277,25 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
tools.ForEach(func(_, tool gjson.Result) bool { tools.ForEach(func(_, tool gjson.Result) bool {
if tool.Get("type").String() == "function" { if tool.Get("type").String() == "function" {
function := tool.Get("function") function := tool.Get("function")
anthropicTool := `{"name":"","description":""}` anthropicTool := []byte(`{"name":"","description":""}`)
anthropicTool, _ = sjson.Set(anthropicTool, "name", function.Get("name").String()) anthropicTool, _ = sjson.SetBytes(anthropicTool, "name", function.Get("name").String())
anthropicTool, _ = sjson.Set(anthropicTool, "description", function.Get("description").String()) anthropicTool, _ = sjson.SetBytes(anthropicTool, "description", function.Get("description").String())
// Convert parameters schema for the tool // Convert parameters schema for the tool
if parameters := function.Get("parameters"); parameters.Exists() { if parameters := function.Get("parameters"); parameters.Exists() {
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw) anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", []byte(parameters.Raw))
} else if parameters := function.Get("parametersJsonSchema"); parameters.Exists() { } else if parameters := function.Get("parametersJsonSchema"); parameters.Exists() {
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw) anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", []byte(parameters.Raw))
} }
out, _ = sjson.SetRaw(out, "tools.-1", anthropicTool) out, _ = sjson.SetRawBytes(out, "tools.-1", anthropicTool)
hasAnthropicTools = true hasAnthropicTools = true
} }
return true return true
}) })
if !hasAnthropicTools { if !hasAnthropicTools {
out, _ = sjson.Delete(out, "tools") out, _ = sjson.DeleteBytes(out, "tools")
} }
} }
@@ -308,31 +308,31 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
case "none": case "none":
// Don't set tool_choice, Claude Code will not use tools // Don't set tool_choice, Claude Code will not use tools
case "auto": case "auto":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`) out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
case "required": case "required":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`) out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
} }
case gjson.JSON: case gjson.JSON:
// Specific tool choice mapping // Specific tool choice mapping
if toolChoice.Get("type").String() == "function" { if toolChoice.Get("type").String() == "function" {
functionName := toolChoice.Get("function.name").String() functionName := toolChoice.Get("function.name").String()
toolChoiceJSON := `{"type":"tool","name":""}` toolChoiceJSON := []byte(`{"type":"tool","name":""}`)
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", functionName) toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", functionName)
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON) out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
} }
default: default:
} }
} }
return []byte(out) return out
} }
func convertOpenAIContentPartToClaudePart(part gjson.Result) string { func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
switch part.Get("type").String() { switch part.Get("type").String() {
case "text": case "text":
textPart := `{"type":"text","text":""}` textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String()) textPart, _ = sjson.SetBytes(textPart, "text", part.Get("text").String())
return textPart return string(textPart)
case "image_url": case "image_url":
return convertOpenAIImageURLToClaudePart(part.Get("image_url.url").String()) return convertOpenAIImageURLToClaudePart(part.Get("image_url.url").String())
@@ -345,10 +345,10 @@ func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx { if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:") mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
data := fileData[commaIdx+1:] data := fileData[commaIdx+1:]
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}` docPart := []byte(`{"type":"document","source":{"type":"base64","media_type":"","data":""}}`)
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType) docPart, _ = sjson.SetBytes(docPart, "source.media_type", mediaType)
docPart, _ = sjson.Set(docPart, "source.data", data) docPart, _ = sjson.SetBytes(docPart, "source.data", data)
return docPart return string(docPart)
} }
} }
} }
@@ -373,15 +373,15 @@ func convertOpenAIImageURLToClaudePart(imageURL string) string {
mediaType = "application/octet-stream" mediaType = "application/octet-stream"
} }
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}` imagePart := []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType) imagePart, _ = sjson.SetBytes(imagePart, "source.media_type", mediaType)
imagePart, _ = sjson.Set(imagePart, "source.data", parts[1]) imagePart, _ = sjson.SetBytes(imagePart, "source.data", parts[1])
return imagePart return string(imagePart)
} }
imagePart := `{"type":"image","source":{"type":"url","url":""}}` imagePart := []byte(`{"type":"image","source":{"type":"url","url":""}}`)
imagePart, _ = sjson.Set(imagePart, "source.url", imageURL) imagePart, _ = sjson.SetBytes(imagePart, "source.url", imageURL)
return imagePart return string(imagePart)
} }
func convertOpenAIToolResultContent(content gjson.Result) (string, bool) { func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
@@ -394,28 +394,28 @@ func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
} }
if content.IsArray() { if content.IsArray() {
claudeContent := "[]" claudeContent := []byte("[]")
partCount := 0 partCount := 0
content.ForEach(func(_, part gjson.Result) bool { content.ForEach(func(_, part gjson.Result) bool {
if part.Type == gjson.String { if part.Type == gjson.String {
textPart := `{"type":"text","text":""}` textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.Set(textPart, "text", part.String()) textPart, _ = sjson.SetBytes(textPart, "text", part.String())
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", textPart) claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", textPart)
partCount++ partCount++
return true return true
} }
claudePart := convertOpenAIContentPartToClaudePart(part) claudePart := convertOpenAIContentPartToClaudePart(part)
if claudePart != "" { if claudePart != "" {
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart) claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", []byte(claudePart))
partCount++ partCount++
} }
return true return true
}) })
if partCount > 0 || len(content.Array()) == 0 { if partCount > 0 || len(content.Array()) == 0 {
return claudeContent, true return string(claudeContent), true
} }
return content.Raw, false return content.Raw, false
@@ -424,9 +424,9 @@ func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
if content.IsObject() { if content.IsObject() {
claudePart := convertOpenAIContentPartToClaudePart(content) claudePart := convertOpenAIContentPartToClaudePart(content)
if claudePart != "" { if claudePart != "" {
claudeContent := "[]" claudeContent := []byte("[]")
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart) claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", []byte(claudePart))
return claudeContent, true return string(claudeContent), true
} }
return content.Raw, false return content.Raw, false
} }

View File

@@ -36,6 +36,18 @@ type ToolCallAccumulator struct {
Arguments strings.Builder Arguments strings.Builder
} }
func calculateClaudeUsageTokens(usage gjson.Result) (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
inputTokens := usage.Get("input_tokens").Int()
completionTokens = usage.Get("output_tokens").Int()
cachedTokens = usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
promptTokens = inputTokens + cacheCreationInputTokens + cachedTokens
totalTokens = promptTokens + completionTokens
return promptTokens, completionTokens, totalTokens, cachedTokens
}
// ConvertClaudeResponseToOpenAI converts Claude Code streaming response format to OpenAI Chat Completions format. // ConvertClaudeResponseToOpenAI converts Claude Code streaming response format to OpenAI Chat Completions format.
// This function processes various Claude Code event types and transforms them into OpenAI-compatible JSON responses. // This function processes various Claude Code event types and transforms them into OpenAI-compatible JSON responses.
// It handles text content, tool calls, reasoning content, and usage metadata, outputting responses that match // It handles text content, tool calls, reasoning content, and usage metadata, outputting responses that match
@@ -48,8 +60,8 @@ type ToolCallAccumulator struct {
// - param: A pointer to a parameter object for maintaining state between calls // - param: A pointer to a parameter object for maintaining state between calls
// //
// Returns: // Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response // - [][]byte: A slice of OpenAI-compatible JSON responses
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
var localParam any var localParam any
if param == nil { if param == nil {
param = &localParam param = &localParam
@@ -63,7 +75,7 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
} }
if !bytes.HasPrefix(rawJSON, dataTag) { if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{} return [][]byte{}
} }
rawJSON = bytes.TrimSpace(rawJSON[5:]) rawJSON = bytes.TrimSpace(rawJSON[5:])
@@ -71,19 +83,19 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
eventType := root.Get("type").String() eventType := root.Get("type").String()
// Base OpenAI streaming response template // Base OpenAI streaming response template
template := `{"id":"","object":"chat.completion.chunk","created":0,"model":"","choices":[{"index":0,"delta":{},"finish_reason":null}]}` template := []byte(`{"id":"","object":"chat.completion.chunk","created":0,"model":"","choices":[{"index":0,"delta":{},"finish_reason":null}]}`)
// Set model // Set model
if modelName != "" { if modelName != "" {
template, _ = sjson.Set(template, "model", modelName) template, _ = sjson.SetBytes(template, "model", modelName)
} }
// Set response ID and creation time // Set response ID and creation time
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID != "" { if (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID != "" {
template, _ = sjson.Set(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID) template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
} }
if (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt > 0 { if (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt > 0 {
template, _ = sjson.Set(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt) template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
} }
switch eventType { switch eventType {
@@ -93,19 +105,19 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
(*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID = message.Get("id").String() (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID = message.Get("id").String()
(*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt = time.Now().Unix() (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt = time.Now().Unix()
template, _ = sjson.Set(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID) template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
template, _ = sjson.Set(template, "model", modelName) template, _ = sjson.SetBytes(template, "model", modelName)
template, _ = sjson.Set(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt) template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
// Set initial role to assistant for the response // Set initial role to assistant for the response
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
// Initialize tool calls accumulator for tracking tool call progress // Initialize tool calls accumulator for tracking tool call progress
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator == nil { if (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator == nil {
(*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator) (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
} }
} }
return []string{template} return [][]byte{template}
case "content_block_start": case "content_block_start":
// Start of a content block (text, tool use, or reasoning) // Start of a content block (text, tool use, or reasoning)
@@ -128,10 +140,10 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
} }
// Don't output anything yet - wait for complete tool call // Don't output anything yet - wait for complete tool call
return []string{} return [][]byte{}
} }
} }
return []string{} return [][]byte{}
case "content_block_delta": case "content_block_delta":
// Handle content delta (text, tool use arguments, or reasoning content) // Handle content delta (text, tool use arguments, or reasoning content)
@@ -143,13 +155,13 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
case "text_delta": case "text_delta":
// Text content delta - send incremental text updates // Text content delta - send incremental text updates
if text := delta.Get("text"); text.Exists() { if text := delta.Get("text"); text.Exists() {
template, _ = sjson.Set(template, "choices.0.delta.content", text.String()) template, _ = sjson.SetBytes(template, "choices.0.delta.content", text.String())
hasContent = true hasContent = true
} }
case "thinking_delta": case "thinking_delta":
// Accumulate reasoning/thinking content // Accumulate reasoning/thinking content
if thinking := delta.Get("thinking"); thinking.Exists() { if thinking := delta.Get("thinking"); thinking.Exists() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", thinking.String()) template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", thinking.String())
hasContent = true hasContent = true
} }
case "input_json_delta": case "input_json_delta":
@@ -163,13 +175,13 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
} }
} }
// Don't output anything yet - wait for complete tool call // Don't output anything yet - wait for complete tool call
return []string{} return [][]byte{}
} }
} }
if hasContent { if hasContent {
return []string{template} return [][]byte{template}
} else { } else {
return []string{} return [][]byte{}
} }
case "content_block_stop": case "content_block_stop":
@@ -182,63 +194,60 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
if arguments == "" { if arguments == "" {
arguments = "{}" arguments = "{}"
} }
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.index", index) template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.index", index)
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.id", accumulator.ID) template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.id", accumulator.ID)
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.type", "function") template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.type", "function")
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name) template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name)
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.arguments", arguments) template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.function.arguments", arguments)
// Clean up the accumulator for this index // Clean up the accumulator for this index
delete((*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator, index) delete((*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator, index)
return []string{template} return [][]byte{template}
} }
} }
return []string{} return [][]byte{}
case "message_delta": case "message_delta":
// Handle message-level changes including stop reason and usage // Handle message-level changes including stop reason and usage
if delta := root.Get("delta"); delta.Exists() { if delta := root.Get("delta"); delta.Exists() {
if stopReason := delta.Get("stop_reason"); stopReason.Exists() { if stopReason := delta.Get("stop_reason"); stopReason.Exists() {
(*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason = mapAnthropicStopReasonToOpenAI(stopReason.String()) (*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason = mapAnthropicStopReasonToOpenAI(stopReason.String())
template, _ = sjson.Set(template, "choices.0.finish_reason", (*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason) template, _ = sjson.SetBytes(template, "choices.0.finish_reason", (*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason)
} }
} }
// Handle usage information for token counts // Handle usage information for token counts
if usage := root.Get("usage"); usage.Exists() { if usage := root.Get("usage"); usage.Exists() {
inputTokens := usage.Get("input_tokens").Int() promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
outputTokens := usage.Get("output_tokens").Int() template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokens)
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int() template, _ = sjson.SetBytes(template, "usage.completion_tokens", completionTokens)
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int() template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokens)
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens) template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
} }
return []string{template} return [][]byte{template}
case "message_stop": case "message_stop":
// Final message event - no additional output needed // Final message event - no additional output needed
return []string{} return [][]byte{}
case "ping": case "ping":
// Ping events for keeping connection alive - no output needed // Ping events for keeping connection alive - no output needed
return []string{} return [][]byte{}
case "error": case "error":
// Error event - format and return error response // Error event - format and return error response
if errorData := root.Get("error"); errorData.Exists() { if errorData := root.Get("error"); errorData.Exists() {
errorJSON := `{"error":{"message":"","type":""}}` errorJSON := []byte(`{"error":{"message":"","type":""}}`)
errorJSON, _ = sjson.Set(errorJSON, "error.message", errorData.Get("message").String()) errorJSON, _ = sjson.SetBytes(errorJSON, "error.message", errorData.Get("message").String())
errorJSON, _ = sjson.Set(errorJSON, "error.type", errorData.Get("type").String()) errorJSON, _ = sjson.SetBytes(errorJSON, "error.type", errorData.Get("type").String())
return []string{errorJSON} return [][]byte{errorJSON}
} }
return []string{} return [][]byte{}
default: default:
// Unknown event type - ignore // Unknown event type - ignore
return []string{} return [][]byte{}
} }
} }
@@ -270,8 +279,8 @@ func mapAnthropicStopReasonToOpenAI(anthropicReason string) string {
// - param: A pointer to a parameter object for the conversion (unused in current implementation) // - param: A pointer to a parameter object for the conversion (unused in current implementation)
// //
// Returns: // Returns:
// - string: An OpenAI-compatible JSON response containing all message content and metadata // - []byte: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
chunks := make([][]byte, 0) chunks := make([][]byte, 0)
lines := bytes.Split(rawJSON, []byte("\n")) lines := bytes.Split(rawJSON, []byte("\n"))
@@ -283,7 +292,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
} }
// Base OpenAI non-streaming response template // Base OpenAI non-streaming response template
out := `{"id":"","object":"chat.completion","created":0,"model":"","choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` out := []byte(`{"id":"","object":"chat.completion","created":0,"model":"","choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`)
var messageID string var messageID string
var model string var model string
@@ -366,32 +375,29 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
} }
} }
if usage := root.Get("usage"); usage.Exists() { if usage := root.Get("usage"); usage.Exists() {
inputTokens := usage.Get("input_tokens").Int() promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
outputTokens := usage.Get("output_tokens").Int() out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int() out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int() out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens) out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", inputTokens+outputTokens)
out, _ = sjson.Set(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
} }
} }
} }
// Set basic response fields including message ID, creation time, and model // Set basic response fields including message ID, creation time, and model
out, _ = sjson.Set(out, "id", messageID) out, _ = sjson.SetBytes(out, "id", messageID)
out, _ = sjson.Set(out, "created", createdAt) out, _ = sjson.SetBytes(out, "created", createdAt)
out, _ = sjson.Set(out, "model", model) out, _ = sjson.SetBytes(out, "model", model)
// Set message content by combining all text parts // Set message content by combining all text parts
messageContent := strings.Join(contentParts, "") messageContent := strings.Join(contentParts, "")
out, _ = sjson.Set(out, "choices.0.message.content", messageContent) out, _ = sjson.SetBytes(out, "choices.0.message.content", messageContent)
// Add reasoning content if available (following OpenAI reasoning format) // Add reasoning content if available (following OpenAI reasoning format)
if len(reasoningParts) > 0 { if len(reasoningParts) > 0 {
reasoningContent := strings.Join(reasoningParts, "") reasoningContent := strings.Join(reasoningParts, "")
// Add reasoning as a separate field in the message // Add reasoning as a separate field in the message
out, _ = sjson.Set(out, "choices.0.message.reasoning", reasoningContent) out, _ = sjson.SetBytes(out, "choices.0.message.reasoning", reasoningContent)
} }
// Set tool calls if any were accumulated during processing // Set tool calls if any were accumulated during processing
@@ -417,19 +423,19 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
namePath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.name", toolCallsCount) namePath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.name", toolCallsCount)
argumentsPath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.arguments", toolCallsCount) argumentsPath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.arguments", toolCallsCount)
out, _ = sjson.Set(out, idPath, accumulator.ID) out, _ = sjson.SetBytes(out, idPath, accumulator.ID)
out, _ = sjson.Set(out, typePath, "function") out, _ = sjson.SetBytes(out, typePath, "function")
out, _ = sjson.Set(out, namePath, accumulator.Name) out, _ = sjson.SetBytes(out, namePath, accumulator.Name)
out, _ = sjson.Set(out, argumentsPath, arguments) out, _ = sjson.SetBytes(out, argumentsPath, arguments)
toolCallsCount++ toolCallsCount++
} }
if toolCallsCount > 0 { if toolCallsCount > 0 {
out, _ = sjson.Set(out, "choices.0.finish_reason", "tool_calls") out, _ = sjson.SetBytes(out, "choices.0.finish_reason", "tool_calls")
} else { } else {
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason)) out, _ = sjson.SetBytes(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
} }
} else { } else {
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason)) out, _ = sjson.SetBytes(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
} }
return out return out

View File

@@ -0,0 +1,58 @@
package chat_completions
import (
"context"
"testing"
"github.com/tidwall/gjson"
)
func TestConvertClaudeResponseToOpenAI_StreamUsageIncludesCachedTokens(t *testing.T) {
ctx := context.Background()
var param any
out := ConvertClaudeResponseToOpenAI(
ctx,
"claude-opus-4-6",
nil,
nil,
[]byte(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":13,"output_tokens":4,"cache_read_input_tokens":22000,"cache_creation_input_tokens":31}}`),
&param,
)
if len(out) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(out))
}
if gotPromptTokens := gjson.GetBytes(out[0], "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
}
if gotCompletionTokens := gjson.GetBytes(out[0], "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
}
if gotTotalTokens := gjson.GetBytes(out[0], "usage.total_tokens").Int(); gotTotalTokens != 22048 {
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
}
if gotCachedTokens := gjson.GetBytes(out[0], "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
}
}
func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *testing.T) {
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\"}}\n" +
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":13,\"output_tokens\":4,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}\n")
out := ConvertClaudeResponseToOpenAINonStream(context.Background(), "", nil, nil, rawJSON, nil)
if gotPromptTokens := gjson.GetBytes(out, "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
}
if gotCompletionTokens := gjson.GetBytes(out, "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
}
if gotTotalTokens := gjson.GetBytes(out, "usage.total_tokens").Int(); gotTotalTokens != 22048 {
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
}
if gotCachedTokens := gjson.GetBytes(out, "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
}
}

View File

@@ -49,7 +49,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session) userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
// Base Claude message payload // Base Claude message payload
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID) out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
@@ -67,20 +67,20 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if supportsAdaptive { if supportsAdaptive {
switch effort { switch effort {
case "none": case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled") out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens") out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort") out, _ = sjson.DeleteBytes(out, "output_config.effort")
case "auto": case "auto":
out, _ = sjson.Set(out, "thinking.type", "adaptive") out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens") out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort") out, _ = sjson.DeleteBytes(out, "output_config.effort")
default: default:
if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok { if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
effort = mapped effort = mapped
} }
out, _ = sjson.Set(out, "thinking.type", "adaptive") out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens") out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.Set(out, "output_config.effort", effort) out, _ = sjson.SetBytes(out, "output_config.effort", effort)
} }
} else { } else {
// Legacy/manual thinking (budget_tokens). // Legacy/manual thinking (budget_tokens).
@@ -88,13 +88,13 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if ok { if ok {
switch budget { switch budget {
case 0: case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled") out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
case -1: case -1:
out, _ = sjson.Set(out, "thinking.type", "enabled") out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
default: default:
if budget > 0 { if budget > 0 {
out, _ = sjson.Set(out, "thinking.type", "enabled") out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget) out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
} }
} }
} }
@@ -114,15 +114,15 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
} }
// Model // Model
out, _ = sjson.Set(out, "model", modelName) out, _ = sjson.SetBytes(out, "model", modelName)
// Max tokens // Max tokens
if mot := root.Get("max_output_tokens"); mot.Exists() { if mot := root.Get("max_output_tokens"); mot.Exists() {
out, _ = sjson.Set(out, "max_tokens", mot.Int()) out, _ = sjson.SetBytes(out, "max_tokens", mot.Int())
} }
// Stream // Stream
out, _ = sjson.Set(out, "stream", stream) out, _ = sjson.SetBytes(out, "stream", stream)
// instructions -> as a leading message (use role user for Claude API compatibility) // instructions -> as a leading message (use role user for Claude API compatibility)
instructionsText := "" instructionsText := ""
@@ -130,9 +130,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if instr := root.Get("instructions"); instr.Exists() && instr.Type == gjson.String { if instr := root.Get("instructions"); instr.Exists() && instr.Type == gjson.String {
instructionsText = instr.String() instructionsText = instr.String()
if instructionsText != "" { if instructionsText != "" {
sysMsg := `{"role":"user","content":""}` sysMsg := []byte(`{"role":"user","content":""}`)
sysMsg, _ = sjson.Set(sysMsg, "content", instructionsText) sysMsg, _ = sjson.SetBytes(sysMsg, "content", instructionsText)
out, _ = sjson.SetRaw(out, "messages.-1", sysMsg) out, _ = sjson.SetRawBytes(out, "messages.-1", sysMsg)
} }
} }
@@ -156,9 +156,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
} }
instructionsText = builder.String() instructionsText = builder.String()
if instructionsText != "" { if instructionsText != "" {
sysMsg := `{"role":"user","content":""}` sysMsg := []byte(`{"role":"user","content":""}`)
sysMsg, _ = sjson.Set(sysMsg, "content", instructionsText) sysMsg, _ = sjson.SetBytes(sysMsg, "content", instructionsText)
out, _ = sjson.SetRaw(out, "messages.-1", sysMsg) out, _ = sjson.SetRawBytes(out, "messages.-1", sysMsg)
extractedFromSystem = true extractedFromSystem = true
} }
} }
@@ -193,9 +193,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if t := part.Get("text"); t.Exists() { if t := part.Get("text"); t.Exists() {
txt := t.String() txt := t.String()
textAggregate.WriteString(txt) textAggregate.WriteString(txt)
contentPart := `{"type":"text","text":""}` contentPart := []byte(`{"type":"text","text":""}`)
contentPart, _ = sjson.Set(contentPart, "text", txt) contentPart, _ = sjson.SetBytes(contentPart, "text", txt)
partsJSON = append(partsJSON, contentPart) partsJSON = append(partsJSON, string(contentPart))
} }
if ptype == "input_text" { if ptype == "input_text" {
role = "user" role = "user"
@@ -208,7 +208,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
url = part.Get("url").String() url = part.Get("url").String()
} }
if url != "" { if url != "" {
var contentPart string var contentPart []byte
if strings.HasPrefix(url, "data:") { if strings.HasPrefix(url, "data:") {
trimmed := strings.TrimPrefix(url, "data:") trimmed := strings.TrimPrefix(url, "data:")
mediaAndData := strings.SplitN(trimmed, ";base64,", 2) mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
@@ -221,16 +221,16 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
data = mediaAndData[1] data = mediaAndData[1]
} }
if data != "" { if data != "" {
contentPart = `{"type":"image","source":{"type":"base64","media_type":"","data":""}}` contentPart = []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType) contentPart, _ = sjson.SetBytes(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.Set(contentPart, "source.data", data) contentPart, _ = sjson.SetBytes(contentPart, "source.data", data)
} }
} else { } else {
contentPart = `{"type":"image","source":{"type":"url","url":""}}` contentPart = []byte(`{"type":"image","source":{"type":"url","url":""}}`)
contentPart, _ = sjson.Set(contentPart, "source.url", url) contentPart, _ = sjson.SetBytes(contentPart, "source.url", url)
} }
if contentPart != "" { if len(contentPart) > 0 {
partsJSON = append(partsJSON, contentPart) partsJSON = append(partsJSON, string(contentPart))
if role == "" { if role == "" {
role = "user" role = "user"
} }
@@ -252,10 +252,10 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
data = mediaAndData[1] data = mediaAndData[1]
} }
} }
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}` contentPart := []byte(`{"type":"document","source":{"type":"base64","media_type":"","data":""}}`)
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType) contentPart, _ = sjson.SetBytes(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.Set(contentPart, "source.data", data) contentPart, _ = sjson.SetBytes(contentPart, "source.data", data)
partsJSON = append(partsJSON, contentPart) partsJSON = append(partsJSON, string(contentPart))
if role == "" { if role == "" {
role = "user" role = "user"
} }
@@ -280,24 +280,24 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
} }
if len(partsJSON) > 0 { if len(partsJSON) > 0 {
msg := `{"role":"","content":[]}` msg := []byte(`{"role":"","content":[]}`)
msg, _ = sjson.Set(msg, "role", role) msg, _ = sjson.SetBytes(msg, "role", role)
if len(partsJSON) == 1 && !hasImage && !hasFile { if len(partsJSON) == 1 && !hasImage && !hasFile {
// Preserve legacy behavior for single text content // Preserve legacy behavior for single text content
msg, _ = sjson.Delete(msg, "content") msg, _ = sjson.DeleteBytes(msg, "content")
textPart := gjson.Parse(partsJSON[0]) textPart := gjson.Parse(partsJSON[0])
msg, _ = sjson.Set(msg, "content", textPart.Get("text").String()) msg, _ = sjson.SetBytes(msg, "content", textPart.Get("text").String())
} else { } else {
for _, partJSON := range partsJSON { for _, partJSON := range partsJSON {
msg, _ = sjson.SetRaw(msg, "content.-1", partJSON) msg, _ = sjson.SetRawBytes(msg, "content.-1", []byte(partJSON))
} }
} }
out, _ = sjson.SetRaw(out, "messages.-1", msg) out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
} else if textAggregate.Len() > 0 || role == "system" { } else if textAggregate.Len() > 0 || role == "system" {
msg := `{"role":"","content":""}` msg := []byte(`{"role":"","content":""}`)
msg, _ = sjson.Set(msg, "role", role) msg, _ = sjson.SetBytes(msg, "role", role)
msg, _ = sjson.Set(msg, "content", textAggregate.String()) msg, _ = sjson.SetBytes(msg, "content", textAggregate.String())
out, _ = sjson.SetRaw(out, "messages.-1", msg) out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
} }
case "function_call": case "function_call":
@@ -309,31 +309,31 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
name := item.Get("name").String() name := item.Get("name").String()
argsStr := item.Get("arguments").String() argsStr := item.Get("arguments").String()
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}` toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolUse, _ = sjson.Set(toolUse, "id", callID) toolUse, _ = sjson.SetBytes(toolUse, "id", callID)
toolUse, _ = sjson.Set(toolUse, "name", name) toolUse, _ = sjson.SetBytes(toolUse, "name", name)
if argsStr != "" && gjson.Valid(argsStr) { if argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr) argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() { if argsJSON.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw) toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(argsJSON.Raw))
} }
} }
asst := `{"role":"assistant","content":[]}` asst := []byte(`{"role":"assistant","content":[]}`)
asst, _ = sjson.SetRaw(asst, "content.-1", toolUse) asst, _ = sjson.SetRawBytes(asst, "content.-1", toolUse)
out, _ = sjson.SetRaw(out, "messages.-1", asst) out, _ = sjson.SetRawBytes(out, "messages.-1", asst)
case "function_call_output": case "function_call_output":
// Map to user tool_result // Map to user tool_result
callID := item.Get("call_id").String() callID := item.Get("call_id").String()
outputStr := item.Get("output").String() outputStr := item.Get("output").String()
toolResult := `{"type":"tool_result","tool_use_id":"","content":""}` toolResult := []byte(`{"type":"tool_result","tool_use_id":"","content":""}`)
toolResult, _ = sjson.Set(toolResult, "tool_use_id", callID) toolResult, _ = sjson.SetBytes(toolResult, "tool_use_id", callID)
toolResult, _ = sjson.Set(toolResult, "content", outputStr) toolResult, _ = sjson.SetBytes(toolResult, "content", outputStr)
usr := `{"role":"user","content":[]}` usr := []byte(`{"role":"user","content":[]}`)
usr, _ = sjson.SetRaw(usr, "content.-1", toolResult) usr, _ = sjson.SetRawBytes(usr, "content.-1", toolResult)
out, _ = sjson.SetRaw(out, "messages.-1", usr) out, _ = sjson.SetRawBytes(out, "messages.-1", usr)
} }
return true return true
}) })
@@ -341,27 +341,27 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
// tools mapping: parameters -> input_schema // tools mapping: parameters -> input_schema
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() { if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
toolsJSON := "[]" toolsJSON := []byte("[]")
tools.ForEach(func(_, tool gjson.Result) bool { tools.ForEach(func(_, tool gjson.Result) bool {
tJSON := `{"name":"","description":"","input_schema":{}}` tJSON := []byte(`{"name":"","description":"","input_schema":{}}`)
if n := tool.Get("name"); n.Exists() { if n := tool.Get("name"); n.Exists() {
tJSON, _ = sjson.Set(tJSON, "name", n.String()) tJSON, _ = sjson.SetBytes(tJSON, "name", n.String())
} }
if d := tool.Get("description"); d.Exists() { if d := tool.Get("description"); d.Exists() {
tJSON, _ = sjson.Set(tJSON, "description", d.String()) tJSON, _ = sjson.SetBytes(tJSON, "description", d.String())
} }
if params := tool.Get("parameters"); params.Exists() { if params := tool.Get("parameters"); params.Exists() {
tJSON, _ = sjson.SetRaw(tJSON, "input_schema", params.Raw) tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
} else if params = tool.Get("parametersJsonSchema"); params.Exists() { } else if params = tool.Get("parametersJsonSchema"); params.Exists() {
tJSON, _ = sjson.SetRaw(tJSON, "input_schema", params.Raw) tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
} }
toolsJSON, _ = sjson.SetRaw(toolsJSON, "-1", tJSON) toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "-1", tJSON)
return true return true
}) })
if gjson.Parse(toolsJSON).IsArray() && len(gjson.Parse(toolsJSON).Array()) > 0 { if parsedTools := gjson.ParseBytes(toolsJSON); parsedTools.IsArray() && len(parsedTools.Array()) > 0 {
out, _ = sjson.SetRaw(out, "tools", toolsJSON) out, _ = sjson.SetRawBytes(out, "tools", toolsJSON)
} }
} }
@@ -371,23 +371,23 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
case gjson.String: case gjson.String:
switch toolChoice.String() { switch toolChoice.String() {
case "auto": case "auto":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`) out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
case "none": case "none":
// Leave unset; implies no tools // Leave unset; implies no tools
case "required": case "required":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`) out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
} }
case gjson.JSON: case gjson.JSON:
if toolChoice.Get("type").String() == "function" { if toolChoice.Get("type").String() == "function" {
fn := toolChoice.Get("function.name").String() fn := toolChoice.Get("function.name").String()
toolChoiceJSON := `{"name":"","type":"tool"}` toolChoiceJSON := []byte(`{"name":"","type":"tool"}`)
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", fn) toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", fn)
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON) out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
} }
default: default:
} }
} }
return []byte(out) return out
} }

View File

@@ -8,6 +8,7 @@ import (
"strings" "strings"
"time" "time"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -50,12 +51,12 @@ func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte {
return nil return nil
} }
func emitEvent(event string, payload string) string { func emitEvent(event string, payload []byte) []byte {
return fmt.Sprintf("event: %s\ndata: %s", event, payload) return translatorcommon.SSEEventData(event, payload)
} }
// ConvertClaudeResponseToOpenAIResponses converts Claude SSE to OpenAI Responses SSE events. // ConvertClaudeResponseToOpenAIResponses converts Claude SSE to OpenAI Responses SSE events.
func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil { if *param == nil {
*param = &claudeToResponsesState{FuncArgsBuf: make(map[int]*strings.Builder), FuncNames: make(map[int]string), FuncCallIDs: make(map[int]string)} *param = &claudeToResponsesState{FuncArgsBuf: make(map[int]*strings.Builder), FuncNames: make(map[int]string), FuncCallIDs: make(map[int]string)}
} }
@@ -63,12 +64,12 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
// Expect `data: {..}` from Claude clients // Expect `data: {..}` from Claude clients
if !bytes.HasPrefix(rawJSON, dataTag) { if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{} return [][]byte{}
} }
rawJSON = bytes.TrimSpace(rawJSON[5:]) rawJSON = bytes.TrimSpace(rawJSON[5:])
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
ev := root.Get("type").String() ev := root.Get("type").String()
var out []string var out [][]byte
nextSeq := func() int { st.Seq++; return st.Seq } nextSeq := func() int { st.Seq++; return st.Seq }
@@ -105,16 +106,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
} }
} }
// response.created // response.created
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}` created := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
created, _ = sjson.Set(created, "sequence_number", nextSeq()) created, _ = sjson.SetBytes(created, "sequence_number", nextSeq())
created, _ = sjson.Set(created, "response.id", st.ResponseID) created, _ = sjson.SetBytes(created, "response.id", st.ResponseID)
created, _ = sjson.Set(created, "response.created_at", st.CreatedAt) created, _ = sjson.SetBytes(created, "response.created_at", st.CreatedAt)
out = append(out, emitEvent("response.created", created)) out = append(out, emitEvent("response.created", created))
// response.in_progress // response.in_progress
inprog := `{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}` inprog := []byte(`{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}`)
inprog, _ = sjson.Set(inprog, "sequence_number", nextSeq()) inprog, _ = sjson.SetBytes(inprog, "sequence_number", nextSeq())
inprog, _ = sjson.Set(inprog, "response.id", st.ResponseID) inprog, _ = sjson.SetBytes(inprog, "response.id", st.ResponseID)
inprog, _ = sjson.Set(inprog, "response.created_at", st.CreatedAt) inprog, _ = sjson.SetBytes(inprog, "response.created_at", st.CreatedAt)
out = append(out, emitEvent("response.in_progress", inprog)) out = append(out, emitEvent("response.in_progress", inprog))
} }
case "content_block_start": case "content_block_start":
@@ -128,25 +129,25 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
// open message item + content part // open message item + content part
st.InTextBlock = true st.InTextBlock = true
st.CurrentMsgID = fmt.Sprintf("msg_%s_0", st.ResponseID) st.CurrentMsgID = fmt.Sprintf("msg_%s_0", st.ResponseID)
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}` item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
item, _ = sjson.Set(item, "sequence_number", nextSeq()) item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
item, _ = sjson.Set(item, "item.id", st.CurrentMsgID) item, _ = sjson.SetBytes(item, "item.id", st.CurrentMsgID)
out = append(out, emitEvent("response.output_item.added", item)) out = append(out, emitEvent("response.output_item.added", item))
part := `{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}` part := []byte(`{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
part, _ = sjson.Set(part, "sequence_number", nextSeq()) part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
part, _ = sjson.Set(part, "item_id", st.CurrentMsgID) part, _ = sjson.SetBytes(part, "item_id", st.CurrentMsgID)
out = append(out, emitEvent("response.content_part.added", part)) out = append(out, emitEvent("response.content_part.added", part))
} else if typ == "tool_use" { } else if typ == "tool_use" {
st.InFuncBlock = true st.InFuncBlock = true
st.CurrentFCID = cb.Get("id").String() st.CurrentFCID = cb.Get("id").String()
name := cb.Get("name").String() name := cb.Get("name").String()
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}` item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
item, _ = sjson.Set(item, "sequence_number", nextSeq()) item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
item, _ = sjson.Set(item, "output_index", idx) item, _ = sjson.SetBytes(item, "output_index", idx)
item, _ = sjson.Set(item, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID)) item, _ = sjson.SetBytes(item, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
item, _ = sjson.Set(item, "item.call_id", st.CurrentFCID) item, _ = sjson.SetBytes(item, "item.call_id", st.CurrentFCID)
item, _ = sjson.Set(item, "item.name", name) item, _ = sjson.SetBytes(item, "item.name", name)
out = append(out, emitEvent("response.output_item.added", item)) out = append(out, emitEvent("response.output_item.added", item))
if st.FuncArgsBuf[idx] == nil { if st.FuncArgsBuf[idx] == nil {
st.FuncArgsBuf[idx] = &strings.Builder{} st.FuncArgsBuf[idx] = &strings.Builder{}
@@ -160,16 +161,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
st.ReasoningIndex = idx st.ReasoningIndex = idx
st.ReasoningBuf.Reset() st.ReasoningBuf.Reset()
st.ReasoningItemID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx) st.ReasoningItemID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}` item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`)
item, _ = sjson.Set(item, "sequence_number", nextSeq()) item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
item, _ = sjson.Set(item, "output_index", idx) item, _ = sjson.SetBytes(item, "output_index", idx)
item, _ = sjson.Set(item, "item.id", st.ReasoningItemID) item, _ = sjson.SetBytes(item, "item.id", st.ReasoningItemID)
out = append(out, emitEvent("response.output_item.added", item)) out = append(out, emitEvent("response.output_item.added", item))
// add a summary part placeholder // add a summary part placeholder
part := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}` part := []byte(`{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
part, _ = sjson.Set(part, "sequence_number", nextSeq()) part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
part, _ = sjson.Set(part, "item_id", st.ReasoningItemID) part, _ = sjson.SetBytes(part, "item_id", st.ReasoningItemID)
part, _ = sjson.Set(part, "output_index", idx) part, _ = sjson.SetBytes(part, "output_index", idx)
out = append(out, emitEvent("response.reasoning_summary_part.added", part)) out = append(out, emitEvent("response.reasoning_summary_part.added", part))
st.ReasoningPartAdded = true st.ReasoningPartAdded = true
} }
@@ -181,10 +182,10 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
dt := d.Get("type").String() dt := d.Get("type").String()
if dt == "text_delta" { if dt == "text_delta" {
if t := d.Get("text"); t.Exists() { if t := d.Get("text"); t.Exists() {
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}` msg := []byte(`{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`)
msg, _ = sjson.Set(msg, "sequence_number", nextSeq()) msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", st.CurrentMsgID) msg, _ = sjson.SetBytes(msg, "item_id", st.CurrentMsgID)
msg, _ = sjson.Set(msg, "delta", t.String()) msg, _ = sjson.SetBytes(msg, "delta", t.String())
out = append(out, emitEvent("response.output_text.delta", msg)) out = append(out, emitEvent("response.output_text.delta", msg))
// aggregate text for response.output // aggregate text for response.output
st.TextBuf.WriteString(t.String()) st.TextBuf.WriteString(t.String())
@@ -196,22 +197,22 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
st.FuncArgsBuf[idx] = &strings.Builder{} st.FuncArgsBuf[idx] = &strings.Builder{}
} }
st.FuncArgsBuf[idx].WriteString(pj.String()) st.FuncArgsBuf[idx].WriteString(pj.String())
msg := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}` msg := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
msg, _ = sjson.Set(msg, "sequence_number", nextSeq()) msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID)) msg, _ = sjson.SetBytes(msg, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
msg, _ = sjson.Set(msg, "output_index", idx) msg, _ = sjson.SetBytes(msg, "output_index", idx)
msg, _ = sjson.Set(msg, "delta", pj.String()) msg, _ = sjson.SetBytes(msg, "delta", pj.String())
out = append(out, emitEvent("response.function_call_arguments.delta", msg)) out = append(out, emitEvent("response.function_call_arguments.delta", msg))
} }
} else if dt == "thinking_delta" { } else if dt == "thinking_delta" {
if st.ReasoningActive { if st.ReasoningActive {
if t := d.Get("thinking"); t.Exists() { if t := d.Get("thinking"); t.Exists() {
st.ReasoningBuf.WriteString(t.String()) st.ReasoningBuf.WriteString(t.String())
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}` msg := []byte(`{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`)
msg, _ = sjson.Set(msg, "sequence_number", nextSeq()) msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID) msg, _ = sjson.SetBytes(msg, "item_id", st.ReasoningItemID)
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex) msg, _ = sjson.SetBytes(msg, "output_index", st.ReasoningIndex)
msg, _ = sjson.Set(msg, "delta", t.String()) msg, _ = sjson.SetBytes(msg, "delta", t.String())
out = append(out, emitEvent("response.reasoning_summary_text.delta", msg)) out = append(out, emitEvent("response.reasoning_summary_text.delta", msg))
} }
} }
@@ -219,17 +220,17 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
case "content_block_stop": case "content_block_stop":
idx := int(root.Get("index").Int()) idx := int(root.Get("index").Int())
if st.InTextBlock { if st.InTextBlock {
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}` done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
done, _ = sjson.Set(done, "sequence_number", nextSeq()) done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID) done, _ = sjson.SetBytes(done, "item_id", st.CurrentMsgID)
out = append(out, emitEvent("response.output_text.done", done)) out = append(out, emitEvent("response.output_text.done", done))
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}` partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq()) partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID) partDone, _ = sjson.SetBytes(partDone, "item_id", st.CurrentMsgID)
out = append(out, emitEvent("response.content_part.done", partDone)) out = append(out, emitEvent("response.content_part.done", partDone))
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}` final := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`)
final, _ = sjson.Set(final, "sequence_number", nextSeq()) final, _ = sjson.SetBytes(final, "sequence_number", nextSeq())
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID) final, _ = sjson.SetBytes(final, "item.id", st.CurrentMsgID)
out = append(out, emitEvent("response.output_item.done", final)) out = append(out, emitEvent("response.output_item.done", final))
st.InTextBlock = false st.InTextBlock = false
} else if st.InFuncBlock { } else if st.InFuncBlock {
@@ -239,34 +240,34 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
args = buf.String() args = buf.String()
} }
} }
fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}` fcDone := []byte(`{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`)
fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq()) fcDone, _ = sjson.SetBytes(fcDone, "sequence_number", nextSeq())
fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID)) fcDone, _ = sjson.SetBytes(fcDone, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
fcDone, _ = sjson.Set(fcDone, "output_index", idx) fcDone, _ = sjson.SetBytes(fcDone, "output_index", idx)
fcDone, _ = sjson.Set(fcDone, "arguments", args) fcDone, _ = sjson.SetBytes(fcDone, "arguments", args)
out = append(out, emitEvent("response.function_call_arguments.done", fcDone)) out = append(out, emitEvent("response.function_call_arguments.done", fcDone))
itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}` itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`)
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq()) itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.Set(itemDone, "output_index", idx) itemDone, _ = sjson.SetBytes(itemDone, "output_index", idx)
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID)) itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
itemDone, _ = sjson.Set(itemDone, "item.arguments", args) itemDone, _ = sjson.SetBytes(itemDone, "item.arguments", args)
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.CurrentFCID) itemDone, _ = sjson.SetBytes(itemDone, "item.call_id", st.CurrentFCID)
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx]) itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[idx])
out = append(out, emitEvent("response.output_item.done", itemDone)) out = append(out, emitEvent("response.output_item.done", itemDone))
st.InFuncBlock = false st.InFuncBlock = false
} else if st.ReasoningActive { } else if st.ReasoningActive {
full := st.ReasoningBuf.String() full := st.ReasoningBuf.String()
textDone := `{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}` textDone := []byte(`{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`)
textDone, _ = sjson.Set(textDone, "sequence_number", nextSeq()) textDone, _ = sjson.SetBytes(textDone, "sequence_number", nextSeq())
textDone, _ = sjson.Set(textDone, "item_id", st.ReasoningItemID) textDone, _ = sjson.SetBytes(textDone, "item_id", st.ReasoningItemID)
textDone, _ = sjson.Set(textDone, "output_index", st.ReasoningIndex) textDone, _ = sjson.SetBytes(textDone, "output_index", st.ReasoningIndex)
textDone, _ = sjson.Set(textDone, "text", full) textDone, _ = sjson.SetBytes(textDone, "text", full)
out = append(out, emitEvent("response.reasoning_summary_text.done", textDone)) out = append(out, emitEvent("response.reasoning_summary_text.done", textDone))
partDone := `{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}` partDone := []byte(`{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq()) partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", st.ReasoningItemID) partDone, _ = sjson.SetBytes(partDone, "item_id", st.ReasoningItemID)
partDone, _ = sjson.Set(partDone, "output_index", st.ReasoningIndex) partDone, _ = sjson.SetBytes(partDone, "output_index", st.ReasoningIndex)
partDone, _ = sjson.Set(partDone, "part.text", full) partDone, _ = sjson.SetBytes(partDone, "part.text", full)
out = append(out, emitEvent("response.reasoning_summary_part.done", partDone)) out = append(out, emitEvent("response.reasoning_summary_part.done", partDone))
st.ReasoningActive = false st.ReasoningActive = false
st.ReasoningPartAdded = false st.ReasoningPartAdded = false
@@ -284,92 +285,92 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
} }
case "message_stop": case "message_stop":
completed := `{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}` completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
completed, _ = sjson.Set(completed, "sequence_number", nextSeq()) completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
completed, _ = sjson.Set(completed, "response.id", st.ResponseID) completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt) completed, _ = sjson.SetBytes(completed, "response.created_at", st.CreatedAt)
// Inject original request fields into response as per docs/response.completed.json // Inject original request fields into response as per docs/response.completed.json
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON) reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
if len(reqBytes) > 0 { if len(reqBytes) > 0 {
req := gjson.ParseBytes(reqBytes) req := gjson.ParseBytes(reqBytes)
if v := req.Get("instructions"); v.Exists() { if v := req.Get("instructions"); v.Exists() {
completed, _ = sjson.Set(completed, "response.instructions", v.String()) completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
} }
if v := req.Get("max_output_tokens"); v.Exists() { if v := req.Get("max_output_tokens"); v.Exists() {
completed, _ = sjson.Set(completed, "response.max_output_tokens", v.Int()) completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
} }
if v := req.Get("max_tool_calls"); v.Exists() { if v := req.Get("max_tool_calls"); v.Exists() {
completed, _ = sjson.Set(completed, "response.max_tool_calls", v.Int()) completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
} }
if v := req.Get("model"); v.Exists() { if v := req.Get("model"); v.Exists() {
completed, _ = sjson.Set(completed, "response.model", v.String()) completed, _ = sjson.SetBytes(completed, "response.model", v.String())
} }
if v := req.Get("parallel_tool_calls"); v.Exists() { if v := req.Get("parallel_tool_calls"); v.Exists() {
completed, _ = sjson.Set(completed, "response.parallel_tool_calls", v.Bool()) completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
} }
if v := req.Get("previous_response_id"); v.Exists() { if v := req.Get("previous_response_id"); v.Exists() {
completed, _ = sjson.Set(completed, "response.previous_response_id", v.String()) completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
} }
if v := req.Get("prompt_cache_key"); v.Exists() { if v := req.Get("prompt_cache_key"); v.Exists() {
completed, _ = sjson.Set(completed, "response.prompt_cache_key", v.String()) completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
} }
if v := req.Get("reasoning"); v.Exists() { if v := req.Get("reasoning"); v.Exists() {
completed, _ = sjson.Set(completed, "response.reasoning", v.Value()) completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
} }
if v := req.Get("safety_identifier"); v.Exists() { if v := req.Get("safety_identifier"); v.Exists() {
completed, _ = sjson.Set(completed, "response.safety_identifier", v.String()) completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
} }
if v := req.Get("service_tier"); v.Exists() { if v := req.Get("service_tier"); v.Exists() {
completed, _ = sjson.Set(completed, "response.service_tier", v.String()) completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
} }
if v := req.Get("store"); v.Exists() { if v := req.Get("store"); v.Exists() {
completed, _ = sjson.Set(completed, "response.store", v.Bool()) completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
} }
if v := req.Get("temperature"); v.Exists() { if v := req.Get("temperature"); v.Exists() {
completed, _ = sjson.Set(completed, "response.temperature", v.Float()) completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
} }
if v := req.Get("text"); v.Exists() { if v := req.Get("text"); v.Exists() {
completed, _ = sjson.Set(completed, "response.text", v.Value()) completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
} }
if v := req.Get("tool_choice"); v.Exists() { if v := req.Get("tool_choice"); v.Exists() {
completed, _ = sjson.Set(completed, "response.tool_choice", v.Value()) completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
} }
if v := req.Get("tools"); v.Exists() { if v := req.Get("tools"); v.Exists() {
completed, _ = sjson.Set(completed, "response.tools", v.Value()) completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
} }
if v := req.Get("top_logprobs"); v.Exists() { if v := req.Get("top_logprobs"); v.Exists() {
completed, _ = sjson.Set(completed, "response.top_logprobs", v.Int()) completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
} }
if v := req.Get("top_p"); v.Exists() { if v := req.Get("top_p"); v.Exists() {
completed, _ = sjson.Set(completed, "response.top_p", v.Float()) completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
} }
if v := req.Get("truncation"); v.Exists() { if v := req.Get("truncation"); v.Exists() {
completed, _ = sjson.Set(completed, "response.truncation", v.String()) completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
} }
if v := req.Get("user"); v.Exists() { if v := req.Get("user"); v.Exists() {
completed, _ = sjson.Set(completed, "response.user", v.Value()) completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
} }
if v := req.Get("metadata"); v.Exists() { if v := req.Get("metadata"); v.Exists() {
completed, _ = sjson.Set(completed, "response.metadata", v.Value()) completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
} }
} }
// Build response.output from aggregated state // Build response.output from aggregated state
outputsWrapper := `{"arr":[]}` outputsWrapper := []byte(`{"arr":[]}`)
// reasoning item (if any) // reasoning item (if any)
if st.ReasoningBuf.Len() > 0 || st.ReasoningPartAdded { if st.ReasoningBuf.Len() > 0 || st.ReasoningPartAdded {
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}` item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
item, _ = sjson.Set(item, "id", st.ReasoningItemID) item, _ = sjson.SetBytes(item, "id", st.ReasoningItemID)
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String()) item, _ = sjson.SetBytes(item, "summary.0.text", st.ReasoningBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
} }
// assistant message item (if any text) // assistant message item (if any text)
if st.TextBuf.Len() > 0 || st.InTextBlock || st.CurrentMsgID != "" { if st.TextBuf.Len() > 0 || st.InTextBlock || st.CurrentMsgID != "" {
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}` item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
item, _ = sjson.Set(item, "id", st.CurrentMsgID) item, _ = sjson.SetBytes(item, "id", st.CurrentMsgID)
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String()) item, _ = sjson.SetBytes(item, "content.0.text", st.TextBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
} }
// function_call items (in ascending index order for determinism) // function_call items (in ascending index order for determinism)
if len(st.FuncArgsBuf) > 0 { if len(st.FuncArgsBuf) > 0 {
@@ -396,16 +397,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
if callID == "" && st.CurrentFCID != "" { if callID == "" && st.CurrentFCID != "" {
callID = st.CurrentFCID callID = st.CurrentFCID
} }
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}` item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID)) item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.Set(item, "arguments", args) item, _ = sjson.SetBytes(item, "arguments", args)
item, _ = sjson.Set(item, "call_id", callID) item, _ = sjson.SetBytes(item, "call_id", callID)
item, _ = sjson.Set(item, "name", name) item, _ = sjson.SetBytes(item, "name", name)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
} }
} }
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 { if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw) completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
} }
reasoningTokens := int64(0) reasoningTokens := int64(0)
@@ -414,15 +415,15 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
} }
usagePresent := st.UsageSeen || reasoningTokens > 0 usagePresent := st.UsageSeen || reasoningTokens > 0
if usagePresent { if usagePresent {
completed, _ = sjson.Set(completed, "response.usage.input_tokens", st.InputTokens) completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.InputTokens)
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", 0) completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", 0)
completed, _ = sjson.Set(completed, "response.usage.output_tokens", st.OutputTokens) completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.OutputTokens)
if reasoningTokens > 0 { if reasoningTokens > 0 {
completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", reasoningTokens) completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", reasoningTokens)
} }
total := st.InputTokens + st.OutputTokens total := st.InputTokens + st.OutputTokens
if total > 0 || st.UsageSeen { if total > 0 || st.UsageSeen {
completed, _ = sjson.Set(completed, "response.usage.total_tokens", total) completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
} }
} }
out = append(out, emitEvent("response.completed", completed)) out = append(out, emitEvent("response.completed", completed))
@@ -432,7 +433,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
} }
// ConvertClaudeResponseToOpenAIResponsesNonStream aggregates Claude SSE into a single OpenAI Responses JSON. // ConvertClaudeResponseToOpenAIResponsesNonStream aggregates Claude SSE into a single OpenAI Responses JSON.
func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
// Aggregate Claude SSE lines into a single OpenAI Responses JSON (non-stream) // Aggregate Claude SSE lines into a single OpenAI Responses JSON (non-stream)
// We follow the same aggregation logic as the streaming variant but produce // We follow the same aggregation logic as the streaming variant but produce
// one final object matching docs/out.json structure. // one final object matching docs/out.json structure.
@@ -455,7 +456,7 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
} }
// Base OpenAI Responses (non-stream) object // Base OpenAI Responses (non-stream) object
out := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null,"output":[],"usage":{"input_tokens":0,"input_tokens_details":{"cached_tokens":0},"output_tokens":0,"output_tokens_details":{},"total_tokens":0}}` out := []byte(`{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null,"output":[],"usage":{"input_tokens":0,"input_tokens_details":{"cached_tokens":0},"output_tokens":0,"output_tokens_details":{},"total_tokens":0}}`)
// Aggregation state // Aggregation state
var ( var (
@@ -557,88 +558,88 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
} }
// Populate base fields // Populate base fields
out, _ = sjson.Set(out, "id", responseID) out, _ = sjson.SetBytes(out, "id", responseID)
out, _ = sjson.Set(out, "created_at", createdAt) out, _ = sjson.SetBytes(out, "created_at", createdAt)
// Inject request echo fields as top-level (similar to streaming variant) // Inject request echo fields as top-level (similar to streaming variant)
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON) reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
if len(reqBytes) > 0 { if len(reqBytes) > 0 {
req := gjson.ParseBytes(reqBytes) req := gjson.ParseBytes(reqBytes)
if v := req.Get("instructions"); v.Exists() { if v := req.Get("instructions"); v.Exists() {
out, _ = sjson.Set(out, "instructions", v.String()) out, _ = sjson.SetBytes(out, "instructions", v.String())
} }
if v := req.Get("max_output_tokens"); v.Exists() { if v := req.Get("max_output_tokens"); v.Exists() {
out, _ = sjson.Set(out, "max_output_tokens", v.Int()) out, _ = sjson.SetBytes(out, "max_output_tokens", v.Int())
} }
if v := req.Get("max_tool_calls"); v.Exists() { if v := req.Get("max_tool_calls"); v.Exists() {
out, _ = sjson.Set(out, "max_tool_calls", v.Int()) out, _ = sjson.SetBytes(out, "max_tool_calls", v.Int())
} }
if v := req.Get("model"); v.Exists() { if v := req.Get("model"); v.Exists() {
out, _ = sjson.Set(out, "model", v.String()) out, _ = sjson.SetBytes(out, "model", v.String())
} }
if v := req.Get("parallel_tool_calls"); v.Exists() { if v := req.Get("parallel_tool_calls"); v.Exists() {
out, _ = sjson.Set(out, "parallel_tool_calls", v.Bool()) out, _ = sjson.SetBytes(out, "parallel_tool_calls", v.Bool())
} }
if v := req.Get("previous_response_id"); v.Exists() { if v := req.Get("previous_response_id"); v.Exists() {
out, _ = sjson.Set(out, "previous_response_id", v.String()) out, _ = sjson.SetBytes(out, "previous_response_id", v.String())
} }
if v := req.Get("prompt_cache_key"); v.Exists() { if v := req.Get("prompt_cache_key"); v.Exists() {
out, _ = sjson.Set(out, "prompt_cache_key", v.String()) out, _ = sjson.SetBytes(out, "prompt_cache_key", v.String())
} }
if v := req.Get("reasoning"); v.Exists() { if v := req.Get("reasoning"); v.Exists() {
out, _ = sjson.Set(out, "reasoning", v.Value()) out, _ = sjson.SetBytes(out, "reasoning", v.Value())
} }
if v := req.Get("safety_identifier"); v.Exists() { if v := req.Get("safety_identifier"); v.Exists() {
out, _ = sjson.Set(out, "safety_identifier", v.String()) out, _ = sjson.SetBytes(out, "safety_identifier", v.String())
} }
if v := req.Get("service_tier"); v.Exists() { if v := req.Get("service_tier"); v.Exists() {
out, _ = sjson.Set(out, "service_tier", v.String()) out, _ = sjson.SetBytes(out, "service_tier", v.String())
} }
if v := req.Get("store"); v.Exists() { if v := req.Get("store"); v.Exists() {
out, _ = sjson.Set(out, "store", v.Bool()) out, _ = sjson.SetBytes(out, "store", v.Bool())
} }
if v := req.Get("temperature"); v.Exists() { if v := req.Get("temperature"); v.Exists() {
out, _ = sjson.Set(out, "temperature", v.Float()) out, _ = sjson.SetBytes(out, "temperature", v.Float())
} }
if v := req.Get("text"); v.Exists() { if v := req.Get("text"); v.Exists() {
out, _ = sjson.Set(out, "text", v.Value()) out, _ = sjson.SetBytes(out, "text", v.Value())
} }
if v := req.Get("tool_choice"); v.Exists() { if v := req.Get("tool_choice"); v.Exists() {
out, _ = sjson.Set(out, "tool_choice", v.Value()) out, _ = sjson.SetBytes(out, "tool_choice", v.Value())
} }
if v := req.Get("tools"); v.Exists() { if v := req.Get("tools"); v.Exists() {
out, _ = sjson.Set(out, "tools", v.Value()) out, _ = sjson.SetBytes(out, "tools", v.Value())
} }
if v := req.Get("top_logprobs"); v.Exists() { if v := req.Get("top_logprobs"); v.Exists() {
out, _ = sjson.Set(out, "top_logprobs", v.Int()) out, _ = sjson.SetBytes(out, "top_logprobs", v.Int())
} }
if v := req.Get("top_p"); v.Exists() { if v := req.Get("top_p"); v.Exists() {
out, _ = sjson.Set(out, "top_p", v.Float()) out, _ = sjson.SetBytes(out, "top_p", v.Float())
} }
if v := req.Get("truncation"); v.Exists() { if v := req.Get("truncation"); v.Exists() {
out, _ = sjson.Set(out, "truncation", v.String()) out, _ = sjson.SetBytes(out, "truncation", v.String())
} }
if v := req.Get("user"); v.Exists() { if v := req.Get("user"); v.Exists() {
out, _ = sjson.Set(out, "user", v.Value()) out, _ = sjson.SetBytes(out, "user", v.Value())
} }
if v := req.Get("metadata"); v.Exists() { if v := req.Get("metadata"); v.Exists() {
out, _ = sjson.Set(out, "metadata", v.Value()) out, _ = sjson.SetBytes(out, "metadata", v.Value())
} }
} }
// Build output array // Build output array
outputsWrapper := `{"arr":[]}` outputsWrapper := []byte(`{"arr":[]}`)
if reasoningBuf.Len() > 0 { if reasoningBuf.Len() > 0 {
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}` item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
item, _ = sjson.Set(item, "id", reasoningItemID) item, _ = sjson.SetBytes(item, "id", reasoningItemID)
item, _ = sjson.Set(item, "summary.0.text", reasoningBuf.String()) item, _ = sjson.SetBytes(item, "summary.0.text", reasoningBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
} }
if currentMsgID != "" || textBuf.Len() > 0 { if currentMsgID != "" || textBuf.Len() > 0 {
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}` item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
item, _ = sjson.Set(item, "id", currentMsgID) item, _ = sjson.SetBytes(item, "id", currentMsgID)
item, _ = sjson.Set(item, "content.0.text", textBuf.String()) item, _ = sjson.SetBytes(item, "content.0.text", textBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
} }
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
// Preserve index order // Preserve index order
@@ -659,28 +660,28 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
if args == "" { if args == "" {
args = "{}" args = "{}"
} }
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}` item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.id)) item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", st.id))
item, _ = sjson.Set(item, "arguments", args) item, _ = sjson.SetBytes(item, "arguments", args)
item, _ = sjson.Set(item, "call_id", st.id) item, _ = sjson.SetBytes(item, "call_id", st.id)
item, _ = sjson.Set(item, "name", st.name) item, _ = sjson.SetBytes(item, "name", st.name)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
} }
} }
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 { if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
out, _ = sjson.SetRaw(out, "output", gjson.Get(outputsWrapper, "arr").Raw) out, _ = sjson.SetRawBytes(out, "output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
} }
// Usage // Usage
total := inputTokens + outputTokens total := inputTokens + outputTokens
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens) out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens) out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", total) out, _ = sjson.SetBytes(out, "usage.total_tokens", total)
if reasoningBuf.Len() > 0 { if reasoningBuf.Len() > 0 {
// Rough estimate similar to chat completions // Rough estimate similar to chat completions
reasoningTokens := int64(len(reasoningBuf.String()) / 4) reasoningTokens := int64(len(reasoningBuf.String()) / 4)
if reasoningTokens > 0 { if reasoningTokens > 0 {
out, _ = sjson.Set(out, "usage.output_tokens_details.reasoning_tokens", reasoningTokens) out, _ = sjson.SetBytes(out, "usage.output_tokens_details.reasoning_tokens", reasoningTokens)
} }
} }

View File

@@ -36,32 +36,41 @@ import (
func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte { func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON rawJSON := inputRawJSON
template := `{"model":"","instructions":"","input":[]}` template := []byte(`{"model":"","instructions":"","input":[]}`)
rootResult := gjson.ParseBytes(rawJSON) rootResult := gjson.ParseBytes(rawJSON)
template, _ = sjson.Set(template, "model", modelName) template, _ = sjson.SetBytes(template, "model", modelName)
// Process system messages and convert them to input content format. // Process system messages and convert them to input content format.
systemsResult := rootResult.Get("system") systemsResult := rootResult.Get("system")
if systemsResult.IsArray() { if systemsResult.Exists() {
systemResults := systemsResult.Array() message := []byte(`{"type":"message","role":"developer","content":[]}`)
message := `{"type":"message","role":"developer","content":[]}`
contentIndex := 0 contentIndex := 0
for i := 0; i < len(systemResults); i++ {
systemResult := systemResults[i] appendSystemText := func(text string) {
systemTypeResult := systemResult.Get("type") if text == "" || strings.HasPrefix(text, "x-anthropic-billing-header: ") {
if systemTypeResult.String() == "text" { return
text := systemResult.Get("text").String() }
if strings.HasPrefix(text, "x-anthropic-billing-header: ") {
continue message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.text", contentIndex), text)
contentIndex++
}
if systemsResult.Type == gjson.String {
appendSystemText(systemsResult.String())
} else if systemsResult.IsArray() {
systemResults := systemsResult.Array()
for i := 0; i < len(systemResults); i++ {
systemResult := systemResults[i]
if systemResult.Get("type").String() == "text" {
appendSystemText(systemResult.Get("text").String())
} }
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
contentIndex++
} }
} }
if contentIndex > 0 { if contentIndex > 0 {
template, _ = sjson.SetRaw(template, "input.-1", message) template, _ = sjson.SetRawBytes(template, "input.-1", message)
} }
} }
@@ -74,9 +83,9 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
messageResult := messageResults[i] messageResult := messageResults[i]
messageRole := messageResult.Get("role").String() messageRole := messageResult.Get("role").String()
newMessage := func() string { newMessage := func() []byte {
msg := `{"type": "message","role":"","content":[]}` msg := []byte(`{"type":"message","role":"","content":[]}`)
msg, _ = sjson.Set(msg, "role", messageRole) msg, _ = sjson.SetBytes(msg, "role", messageRole)
return msg return msg
} }
@@ -86,7 +95,7 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
flushMessage := func() { flushMessage := func() {
if hasContent { if hasContent {
template, _ = sjson.SetRaw(template, "input.-1", message) template, _ = sjson.SetRawBytes(template, "input.-1", message)
message = newMessage() message = newMessage()
contentIndex = 0 contentIndex = 0
hasContent = false hasContent = false
@@ -98,15 +107,15 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
if messageRole == "assistant" { if messageRole == "assistant" {
partType = "output_text" partType = "output_text"
} }
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), partType) message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), partType)
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text) message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.text", contentIndex), text)
contentIndex++ contentIndex++
hasContent = true hasContent = true
} }
appendImageContent := func(dataURL string) { appendImageContent := func(dataURL string) {
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_image") message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), "input_image")
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.image_url", contentIndex), dataURL) message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.image_url", contentIndex), dataURL)
contentIndex++ contentIndex++
hasContent = true hasContent = true
} }
@@ -142,8 +151,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} }
case "tool_use": case "tool_use":
flushMessage() flushMessage()
functionCallMessage := `{"type":"function_call"}` functionCallMessage := []byte(`{"type":"function_call"}`)
functionCallMessage, _ = sjson.Set(functionCallMessage, "call_id", messageContentResult.Get("id").String()) functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "call_id", messageContentResult.Get("id").String())
{ {
name := messageContentResult.Get("name").String() name := messageContentResult.Get("name").String()
toolMap := buildReverseMapFromClaudeOriginalToShort(rawJSON) toolMap := buildReverseMapFromClaudeOriginalToShort(rawJSON)
@@ -152,19 +161,19 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else { } else {
name = shortenNameIfNeeded(name) name = shortenNameIfNeeded(name)
} }
functionCallMessage, _ = sjson.Set(functionCallMessage, "name", name) functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "name", name)
} }
functionCallMessage, _ = sjson.Set(functionCallMessage, "arguments", messageContentResult.Get("input").Raw) functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "arguments", messageContentResult.Get("input").Raw)
template, _ = sjson.SetRaw(template, "input.-1", functionCallMessage) template, _ = sjson.SetRawBytes(template, "input.-1", functionCallMessage)
case "tool_result": case "tool_result":
flushMessage() flushMessage()
functionCallOutputMessage := `{"type":"function_call_output"}` functionCallOutputMessage := []byte(`{"type":"function_call_output"}`)
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String()) functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
contentResult := messageContentResult.Get("content") contentResult := messageContentResult.Get("content")
if contentResult.IsArray() { if contentResult.IsArray() {
toolResultContentIndex := 0 toolResultContentIndex := 0
toolResultContent := `[]` toolResultContent := []byte(`[]`)
contentResults := contentResult.Array() contentResults := contentResult.Array()
for k := 0; k < len(contentResults); k++ { for k := 0; k < len(contentResults); k++ {
toolResultContentType := contentResults[k].Get("type").String() toolResultContentType := contentResults[k].Get("type").String()
@@ -185,27 +194,27 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} }
dataURL := fmt.Sprintf("data:%s;base64,%s", mediaType, data) dataURL := fmt.Sprintf("data:%s;base64,%s", mediaType, data)
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_image") toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_image")
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.image_url", toolResultContentIndex), dataURL) toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.image_url", toolResultContentIndex), dataURL)
toolResultContentIndex++ toolResultContentIndex++
} }
} }
} else if toolResultContentType == "text" { } else if toolResultContentType == "text" {
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_text") toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_text")
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.text", toolResultContentIndex), contentResults[k].Get("text").String()) toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.text", toolResultContentIndex), contentResults[k].Get("text").String())
toolResultContentIndex++ toolResultContentIndex++
} }
} }
if toolResultContent != `[]` { if toolResultContentIndex > 0 {
functionCallOutputMessage, _ = sjson.SetRaw(functionCallOutputMessage, "output", toolResultContent) functionCallOutputMessage, _ = sjson.SetRawBytes(functionCallOutputMessage, "output", toolResultContent)
} else { } else {
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String()) functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
} }
} else { } else {
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String()) functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
} }
template, _ = sjson.SetRaw(template, "input.-1", functionCallOutputMessage) template, _ = sjson.SetRawBytes(template, "input.-1", functionCallOutputMessage)
} }
} }
flushMessage() flushMessage()
@@ -220,8 +229,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
// Convert tools declarations to the expected format for the Codex API. // Convert tools declarations to the expected format for the Codex API.
toolsResult := rootResult.Get("tools") toolsResult := rootResult.Get("tools")
if toolsResult.IsArray() { if toolsResult.IsArray() {
template, _ = sjson.SetRaw(template, "tools", `[]`) template, _ = sjson.SetRawBytes(template, "tools", []byte(`[]`))
template, _ = sjson.Set(template, "tool_choice", `auto`) template, _ = sjson.SetBytes(template, "tool_choice", `auto`)
toolResults := toolsResult.Array() toolResults := toolsResult.Array()
// Build short name map from declared tools // Build short name map from declared tools
var names []string var names []string
@@ -237,11 +246,11 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
// Special handling: map Claude web search tool to Codex web_search // Special handling: map Claude web search tool to Codex web_search
if toolResult.Get("type").String() == "web_search_20250305" { if toolResult.Get("type").String() == "web_search_20250305" {
// Replace the tool content entirely with {"type":"web_search"} // Replace the tool content entirely with {"type":"web_search"}
template, _ = sjson.SetRaw(template, "tools.-1", `{"type":"web_search"}`) template, _ = sjson.SetRawBytes(template, "tools.-1", []byte(`{"type":"web_search"}`))
continue continue
} }
tool := toolResult.Raw tool := []byte(toolResult.Raw)
tool, _ = sjson.Set(tool, "type", "function") tool, _ = sjson.SetBytes(tool, "type", "function")
// Apply shortened name if needed // Apply shortened name if needed
if v := toolResult.Get("name"); v.Exists() { if v := toolResult.Get("name"); v.Exists() {
name := v.String() name := v.String()
@@ -250,20 +259,26 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else { } else {
name = shortenNameIfNeeded(name) name = shortenNameIfNeeded(name)
} }
tool, _ = sjson.Set(tool, "name", name) tool, _ = sjson.SetBytes(tool, "name", name)
} }
tool, _ = sjson.SetRaw(tool, "parameters", normalizeToolParameters(toolResult.Get("input_schema").Raw)) tool, _ = sjson.SetRawBytes(tool, "parameters", []byte(normalizeToolParameters(toolResult.Get("input_schema").Raw)))
tool, _ = sjson.Delete(tool, "input_schema") tool, _ = sjson.DeleteBytes(tool, "input_schema")
tool, _ = sjson.Delete(tool, "parameters.$schema") tool, _ = sjson.DeleteBytes(tool, "parameters.$schema")
tool, _ = sjson.Delete(tool, "cache_control") tool, _ = sjson.DeleteBytes(tool, "cache_control")
tool, _ = sjson.Delete(tool, "defer_loading") tool, _ = sjson.DeleteBytes(tool, "defer_loading")
tool, _ = sjson.Set(tool, "strict", false) tool, _ = sjson.SetBytes(tool, "strict", false)
template, _ = sjson.SetRaw(template, "tools.-1", tool) template, _ = sjson.SetRawBytes(template, "tools.-1", tool)
} }
} }
// Default to parallel tool calls unless tool_choice explicitly disables them.
parallelToolCalls := true
if disableParallelToolUse := rootResult.Get("tool_choice.disable_parallel_tool_use"); disableParallelToolUse.Exists() {
parallelToolCalls = !disableParallelToolUse.Bool()
}
// Add additional configuration parameters for the Codex API. // Add additional configuration parameters for the Codex API.
template, _ = sjson.Set(template, "parallel_tool_calls", true) template, _ = sjson.SetBytes(template, "parallel_tool_calls", parallelToolCalls)
// Convert thinking.budget_tokens to reasoning.effort. // Convert thinking.budget_tokens to reasoning.effort.
reasoningEffort := "medium" reasoningEffort := "medium"
@@ -294,13 +309,13 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} }
} }
} }
template, _ = sjson.Set(template, "reasoning.effort", reasoningEffort) template, _ = sjson.SetBytes(template, "reasoning.effort", reasoningEffort)
template, _ = sjson.Set(template, "reasoning.summary", "auto") template, _ = sjson.SetBytes(template, "reasoning.summary", "auto")
template, _ = sjson.Set(template, "stream", true) template, _ = sjson.SetBytes(template, "stream", true)
template, _ = sjson.Set(template, "store", false) template, _ = sjson.SetBytes(template, "store", false)
template, _ = sjson.Set(template, "include", []string{"reasoning.encrypted_content"}) template, _ = sjson.SetBytes(template, "include", []string{"reasoning.encrypted_content"})
return []byte(template) return template
} }
// shortenNameIfNeeded applies a simple shortening rule for a single name. // shortenNameIfNeeded applies a simple shortening rule for a single name.
@@ -403,15 +418,15 @@ func normalizeToolParameters(raw string) string {
if raw == "" || raw == "null" || !gjson.Valid(raw) { if raw == "" || raw == "null" || !gjson.Valid(raw) {
return `{"type":"object","properties":{}}` return `{"type":"object","properties":{}}`
} }
schema := raw
result := gjson.Parse(raw) result := gjson.Parse(raw)
schema := []byte(raw)
schemaType := result.Get("type").String() schemaType := result.Get("type").String()
if schemaType == "" { if schemaType == "" {
schema, _ = sjson.Set(schema, "type", "object") schema, _ = sjson.SetBytes(schema, "type", "object")
schemaType = "object" schemaType = "object"
} }
if schemaType == "object" && !result.Get("properties").Exists() { if schemaType == "object" && !result.Get("properties").Exists() {
schema, _ = sjson.SetRaw(schema, "properties", `{}`) schema, _ = sjson.SetRawBytes(schema, "properties", []byte(`{}`))
} }
return schema return string(schema)
} }

View File

@@ -0,0 +1,135 @@
package claude
import (
"testing"
"github.com/tidwall/gjson"
)
func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) {
tests := []struct {
name string
inputJSON string
wantHasDeveloper bool
wantTexts []string
}{
{
name: "No system field",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: false,
},
{
name: "Empty string system field",
inputJSON: `{
"model": "claude-3-opus",
"system": "",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: false,
},
{
name: "String system field",
inputJSON: `{
"model": "claude-3-opus",
"system": "Be helpful",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: true,
wantTexts: []string{"Be helpful"},
},
{
name: "Array system field with filtered billing header",
inputJSON: `{
"model": "claude-3-opus",
"system": [
{"type": "text", "text": "x-anthropic-billing-header: tenant-123"},
{"type": "text", "text": "Block 1"},
{"type": "text", "text": "Block 2"}
],
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: true,
wantTexts: []string{"Block 1", "Block 2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false)
resultJSON := gjson.ParseBytes(result)
inputs := resultJSON.Get("input").Array()
hasDeveloper := len(inputs) > 0 && inputs[0].Get("role").String() == "developer"
if hasDeveloper != tt.wantHasDeveloper {
t.Fatalf("got hasDeveloper = %v, want %v. Output: %s", hasDeveloper, tt.wantHasDeveloper, resultJSON.Get("input").Raw)
}
if !tt.wantHasDeveloper {
return
}
content := inputs[0].Get("content").Array()
if len(content) != len(tt.wantTexts) {
t.Fatalf("got %d system content items, want %d. Content: %s", len(content), len(tt.wantTexts), inputs[0].Get("content").Raw)
}
for i, wantText := range tt.wantTexts {
if gotType := content[i].Get("type").String(); gotType != "input_text" {
t.Fatalf("content[%d] type = %q, want %q", i, gotType, "input_text")
}
if gotText := content[i].Get("text").String(); gotText != wantText {
t.Fatalf("content[%d] text = %q, want %q", i, gotText, wantText)
}
}
})
}
}
func TestConvertClaudeRequestToCodex_ParallelToolCalls(t *testing.T) {
tests := []struct {
name string
inputJSON string
wantParallelToolCalls bool
}{
{
name: "Default to true when tool_choice.disable_parallel_tool_use is absent",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantParallelToolCalls: true,
},
{
name: "Disable parallel tool calls when client opts out",
inputJSON: `{
"model": "claude-3-opus",
"tool_choice": {"disable_parallel_tool_use": true},
"messages": [{"role": "user", "content": "hello"}]
}`,
wantParallelToolCalls: false,
},
{
name: "Keep parallel tool calls enabled when client explicitly allows them",
inputJSON: `{
"model": "claude-3-opus",
"tool_choice": {"disable_parallel_tool_use": false},
"messages": [{"role": "user", "content": "hello"}]
}`,
wantParallelToolCalls: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false)
resultJSON := gjson.ParseBytes(result)
if got := resultJSON.Get("parallel_tool_calls").Bool(); got != tt.wantParallelToolCalls {
t.Fatalf("parallel_tool_calls = %v, want %v. Output: %s", got, tt.wantParallelToolCalls, string(result))
}
})
}
}

View File

@@ -9,9 +9,9 @@ package claude
import ( import (
"bytes" "bytes"
"context" "context"
"fmt"
"strings" "strings"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
@@ -43,8 +43,8 @@ type ConvertCodexResponseToClaudeParams struct {
// - param: A pointer to a parameter object for maintaining state between calls // - param: A pointer to a parameter object for maintaining state between calls
// //
// Returns: // Returns:
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response // - [][]byte: A slice of Claude Code-compatible JSON responses
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil { if *param == nil {
*param = &ConvertCodexResponseToClaudeParams{ *param = &ConvertCodexResponseToClaudeParams{
HasToolCall: false, HasToolCall: false,
@@ -54,95 +54,85 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
// log.Debugf("rawJSON: %s", string(rawJSON)) // log.Debugf("rawJSON: %s", string(rawJSON))
if !bytes.HasPrefix(rawJSON, dataTag) { if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{} return [][]byte{}
} }
rawJSON = bytes.TrimSpace(rawJSON[5:]) rawJSON = bytes.TrimSpace(rawJSON[5:])
output := "" output := make([]byte, 0, 512)
rootResult := gjson.ParseBytes(rawJSON) rootResult := gjson.ParseBytes(rawJSON)
typeResult := rootResult.Get("type") typeResult := rootResult.Get("type")
typeStr := typeResult.String() typeStr := typeResult.String()
template := "" var template []byte
if typeStr == "response.created" { if typeStr == "response.created" {
template = `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}` template = []byte(`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`)
template, _ = sjson.Set(template, "message.model", rootResult.Get("response.model").String()) template, _ = sjson.SetBytes(template, "message.model", rootResult.Get("response.model").String())
template, _ = sjson.Set(template, "message.id", rootResult.Get("response.id").String()) template, _ = sjson.SetBytes(template, "message.id", rootResult.Get("response.id").String())
output = "event: message_start\n" output = translatorcommon.AppendSSEEventBytes(output, "message_start", template, 2)
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.reasoning_summary_part.added" { } else if typeStr == "response.reasoning_summary_part.added" {
template = `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}` template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`)
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output = "event: content_block_start\n" output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.reasoning_summary_text.delta" { } else if typeStr == "response.reasoning_summary_text.delta" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}` template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`)
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.thinking", rootResult.Get("delta").String()) template, _ = sjson.SetBytes(template, "delta.thinking", rootResult.Get("delta").String())
output = "event: content_block_delta\n" output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.reasoning_summary_part.done" { } else if typeStr == "response.reasoning_summary_part.done" {
template = `{"type":"content_block_stop","index":0}` template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++ (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n" output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.content_part.added" { } else if typeStr == "response.content_part.added" {
template = `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`)
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output = "event: content_block_start\n" output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.output_text.delta" { } else if typeStr == "response.output_text.delta" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}` template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`)
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.text", rootResult.Get("delta").String()) template, _ = sjson.SetBytes(template, "delta.text", rootResult.Get("delta").String())
output = "event: content_block_delta\n" output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.content_part.done" { } else if typeStr == "response.content_part.done" {
template = `{"type":"content_block_stop","index":0}` template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++ (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n" output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.completed" { } else if typeStr == "response.completed" {
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
stopReason := rootResult.Get("response.stop_reason").String() stopReason := rootResult.Get("response.stop_reason").String()
if p { if p {
template, _ = sjson.Set(template, "delta.stop_reason", "tool_use") template, _ = sjson.SetBytes(template, "delta.stop_reason", "tool_use")
} else if stopReason == "max_tokens" || stopReason == "stop" { } else if stopReason == "max_tokens" || stopReason == "stop" {
template, _ = sjson.Set(template, "delta.stop_reason", stopReason) template, _ = sjson.SetBytes(template, "delta.stop_reason", stopReason)
} else { } else {
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn") template, _ = sjson.SetBytes(template, "delta.stop_reason", "end_turn")
} }
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(rootResult.Get("response.usage")) inputTokens, outputTokens, cachedTokens := extractResponsesUsage(rootResult.Get("response.usage"))
template, _ = sjson.Set(template, "usage.input_tokens", inputTokens) template, _ = sjson.SetBytes(template, "usage.input_tokens", inputTokens)
template, _ = sjson.Set(template, "usage.output_tokens", outputTokens) template, _ = sjson.SetBytes(template, "usage.output_tokens", outputTokens)
if cachedTokens > 0 { if cachedTokens > 0 {
template, _ = sjson.Set(template, "usage.cache_read_input_tokens", cachedTokens) template, _ = sjson.SetBytes(template, "usage.cache_read_input_tokens", cachedTokens)
} }
output = "event: message_delta\n" output = translatorcommon.AppendSSEEventBytes(output, "message_delta", template, 2)
output += fmt.Sprintf("data: %s\n\n", template) output = translatorcommon.AppendSSEEventBytes(output, "message_stop", []byte(`{"type":"message_stop"}`), 2)
output += "event: message_stop\n"
output += `data: {"type":"message_stop"}`
output += "\n\n"
} else if typeStr == "response.output_item.added" { } else if typeStr == "response.output_item.added" {
itemResult := rootResult.Get("item") itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String() itemType := itemResult.Get("type").String()
if itemType == "function_call" { if itemType == "function_call" {
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false (*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}` template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String())) template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
{ {
// Restore original tool name if shortened // Restore original tool name if shortened
name := itemResult.Get("name").String() name := itemResult.Get("name").String()
@@ -150,37 +140,33 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
if orig, ok := rev[name]; ok { if orig, ok := rev[name]; ok {
name = orig name = orig
} }
template, _ = sjson.Set(template, "content_block.name", name) template, _ = sjson.SetBytes(template, "content_block.name", name)
} }
output = "event: content_block_start\n" output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
output += fmt.Sprintf("data: %s\n\n", template)
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}` template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output += "event: content_block_delta\n" output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
output += fmt.Sprintf("data: %s\n\n", template)
} }
} else if typeStr == "response.output_item.done" { } else if typeStr == "response.output_item.done" {
itemResult := rootResult.Get("item") itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String() itemType := itemResult.Get("type").String()
if itemType == "function_call" { if itemType == "function_call" {
template = `{"type":"content_block_stop","index":0}` template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++ (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n" output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
output += fmt.Sprintf("data: %s\n\n", template)
} }
} else if typeStr == "response.function_call_arguments.delta" { } else if typeStr == "response.function_call_arguments.delta" {
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true (*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}` template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String()) template, _ = sjson.SetBytes(template, "delta.partial_json", rootResult.Get("delta").String())
output += "event: content_block_delta\n" output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.function_call_arguments.done" { } else if typeStr == "response.function_call_arguments.done" {
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments // Some models (e.g. gpt-5.3-codex-spark) send function call arguments
// in a single "done" event without preceding "delta" events. // in a single "done" event without preceding "delta" events.
@@ -189,17 +175,16 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
// When delta events were already received, skip to avoid duplicating arguments. // When delta events were already received, skip to avoid duplicating arguments.
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta { if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
if args := rootResult.Get("arguments").String(); args != "" { if args := rootResult.Get("arguments").String(); args != "" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}` template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", args) template, _ = sjson.SetBytes(template, "delta.partial_json", args)
output += "event: content_block_delta\n" output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
output += fmt.Sprintf("data: %s\n\n", template)
} }
} }
} }
return []string{output} return [][]byte{output}
} }
// ConvertCodexResponseToClaudeNonStream converts a non-streaming Codex response to a non-streaming Claude Code response. // ConvertCodexResponseToClaudeNonStream converts a non-streaming Codex response to a non-streaming Claude Code response.
@@ -214,28 +199,28 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
// - param: A pointer to a parameter object for the conversion (unused in current implementation) // - param: A pointer to a parameter object for the conversion (unused in current implementation)
// //
// Returns: // Returns:
// - string: A Claude Code-compatible JSON response containing all message content and metadata // - []byte: A Claude Code-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) string { func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) []byte {
revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON) revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
rootResult := gjson.ParseBytes(rawJSON) rootResult := gjson.ParseBytes(rawJSON)
if rootResult.Get("type").String() != "response.completed" { if rootResult.Get("type").String() != "response.completed" {
return "" return []byte{}
} }
responseData := rootResult.Get("response") responseData := rootResult.Get("response")
if !responseData.Exists() { if !responseData.Exists() {
return "" return []byte{}
} }
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` out := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
out, _ = sjson.Set(out, "id", responseData.Get("id").String()) out, _ = sjson.SetBytes(out, "id", responseData.Get("id").String())
out, _ = sjson.Set(out, "model", responseData.Get("model").String()) out, _ = sjson.SetBytes(out, "model", responseData.Get("model").String())
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(responseData.Get("usage")) inputTokens, outputTokens, cachedTokens := extractResponsesUsage(responseData.Get("usage"))
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens) out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens) out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens)
if cachedTokens > 0 { if cachedTokens > 0 {
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens) out, _ = sjson.SetBytes(out, "usage.cache_read_input_tokens", cachedTokens)
} }
hasToolCall := false hasToolCall := false
@@ -276,9 +261,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
} }
} }
if thinkingBuilder.Len() > 0 { if thinkingBuilder.Len() > 0 {
block := `{"type":"thinking","thinking":""}` block := []byte(`{"type":"thinking","thinking":""}`)
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String()) block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block) out, _ = sjson.SetRawBytes(out, "content.-1", block)
} }
case "message": case "message":
if content := item.Get("content"); content.Exists() { if content := item.Get("content"); content.Exists() {
@@ -287,9 +272,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
if part.Get("type").String() == "output_text" { if part.Get("type").String() == "output_text" {
text := part.Get("text").String() text := part.Get("text").String()
if text != "" { if text != "" {
block := `{"type":"text","text":""}` block := []byte(`{"type":"text","text":""}`)
block, _ = sjson.Set(block, "text", text) block, _ = sjson.SetBytes(block, "text", text)
out, _ = sjson.SetRaw(out, "content.-1", block) out, _ = sjson.SetRawBytes(out, "content.-1", block)
} }
} }
return true return true
@@ -297,9 +282,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
} else { } else {
text := content.String() text := content.String()
if text != "" { if text != "" {
block := `{"type":"text","text":""}` block := []byte(`{"type":"text","text":""}`)
block, _ = sjson.Set(block, "text", text) block, _ = sjson.SetBytes(block, "text", text)
out, _ = sjson.SetRaw(out, "content.-1", block) out, _ = sjson.SetRawBytes(out, "content.-1", block)
} }
} }
} }
@@ -310,9 +295,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
name = original name = original
} }
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}` toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String())) toolBlock, _ = sjson.SetBytes(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
toolBlock, _ = sjson.Set(toolBlock, "name", name) toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
inputRaw := "{}" inputRaw := "{}"
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) { if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr) argsJSON := gjson.Parse(argsStr)
@@ -320,23 +305,23 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
inputRaw = argsJSON.Raw inputRaw = argsJSON.Raw
} }
} }
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw) toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(inputRaw))
out, _ = sjson.SetRaw(out, "content.-1", toolBlock) out, _ = sjson.SetRawBytes(out, "content.-1", toolBlock)
} }
return true return true
}) })
} }
if stopReason := responseData.Get("stop_reason"); stopReason.Exists() && stopReason.String() != "" { if stopReason := responseData.Get("stop_reason"); stopReason.Exists() && stopReason.String() != "" {
out, _ = sjson.Set(out, "stop_reason", stopReason.String()) out, _ = sjson.SetBytes(out, "stop_reason", stopReason.String())
} else if hasToolCall { } else if hasToolCall {
out, _ = sjson.Set(out, "stop_reason", "tool_use") out, _ = sjson.SetBytes(out, "stop_reason", "tool_use")
} else { } else {
out, _ = sjson.Set(out, "stop_reason", "end_turn") out, _ = sjson.SetBytes(out, "stop_reason", "end_turn")
} }
if stopSequence := responseData.Get("stop_sequence"); stopSequence.Exists() && stopSequence.String() != "" { if stopSequence := responseData.Get("stop_sequence"); stopSequence.Exists() && stopSequence.String() != "" {
out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw) out, _ = sjson.SetRawBytes(out, "stop_sequence", []byte(stopSequence.Raw))
} }
return out return out
@@ -386,6 +371,6 @@ func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[strin
return rev return rev
} }
func ClaudeTokenCount(ctx context.Context, count int64) string { func ClaudeTokenCount(ctx context.Context, count int64) []byte {
return fmt.Sprintf(`{"input_tokens":%d}`, count) return translatorcommon.ClaudeInputTokensJSON(count)
} }

View File

@@ -6,10 +6,9 @@ package geminiCLI
import ( import (
"context" "context"
"fmt"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini" . "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini"
"github.com/tidwall/sjson" translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
) )
// ConvertCodexResponseToGeminiCLI converts Codex streaming response format to Gemini CLI format. // ConvertCodexResponseToGeminiCLI converts Codex streaming response format to Gemini CLI format.
@@ -24,14 +23,12 @@ import (
// - param: A pointer to a parameter object for maintaining state between calls // - param: A pointer to a parameter object for maintaining state between calls
// //
// Returns: // Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response wrapped in a response object // - [][]byte: A slice of Gemini-compatible JSON responses wrapped in a response object
func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
outputs := ConvertCodexResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param) outputs := ConvertCodexResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
newOutputs := make([]string, 0) newOutputs := make([][]byte, 0, len(outputs))
for i := 0; i < len(outputs); i++ { for i := 0; i < len(outputs); i++ {
json := `{"response": {}}` newOutputs = append(newOutputs, translatorcommon.WrapGeminiCLIResponse(outputs[i]))
output, _ := sjson.SetRaw(json, "response", outputs[i])
newOutputs = append(newOutputs, output)
} }
return newOutputs return newOutputs
} }
@@ -47,15 +44,12 @@ func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, orig
// - param: A pointer to a parameter object for the conversion // - param: A pointer to a parameter object for the conversion
// //
// Returns: // Returns:
// - string: A Gemini-compatible JSON response wrapped in a response object // - []byte: A Gemini-compatible JSON response wrapped in a response object
func ConvertCodexResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string { func ConvertCodexResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
// log.Debug(string(rawJSON)) out := ConvertCodexResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
strJSON := ConvertCodexResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param) return translatorcommon.WrapGeminiCLIResponse(out)
json := `{"response": {}}`
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
return strJSON
} }
func GeminiCLITokenCount(ctx context.Context, count int64) string { func GeminiCLITokenCount(ctx context.Context, count int64) []byte {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count) return translatorcommon.GeminiTokenCountJSON(count)
} }

View File

@@ -38,7 +38,7 @@ import (
func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte { func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON rawJSON := inputRawJSON
// Base template // Base template
out := `{"model":"","instructions":"","input":[]}` out := []byte(`{"model":"","instructions":"","input":[]}`)
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
@@ -82,24 +82,24 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} }
// Model // Model
out, _ = sjson.Set(out, "model", modelName) out, _ = sjson.SetBytes(out, "model", modelName)
// System instruction -> as a user message with input_text parts // System instruction -> as a user message with input_text parts
sysParts := root.Get("system_instruction.parts") sysParts := root.Get("system_instruction.parts")
if sysParts.IsArray() { if sysParts.IsArray() {
msg := `{"type":"message","role":"developer","content":[]}` msg := []byte(`{"type":"message","role":"developer","content":[]}`)
arr := sysParts.Array() arr := sysParts.Array()
for i := 0; i < len(arr); i++ { for i := 0; i < len(arr); i++ {
p := arr[i] p := arr[i]
if t := p.Get("text"); t.Exists() { if t := p.Get("text"); t.Exists() {
part := `{}` part := []byte(`{}`)
part, _ = sjson.Set(part, "type", "input_text") part, _ = sjson.SetBytes(part, "type", "input_text")
part, _ = sjson.Set(part, "text", t.String()) part, _ = sjson.SetBytes(part, "text", t.String())
msg, _ = sjson.SetRaw(msg, "content.-1", part) msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
} }
} }
if len(gjson.Get(msg, "content").Array()) > 0 { if len(gjson.GetBytes(msg, "content").Array()) > 0 {
out, _ = sjson.SetRaw(out, "input.-1", msg) out, _ = sjson.SetRawBytes(out, "input.-1", msg)
} }
} }
@@ -123,23 +123,23 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
p := parr[j] p := parr[j]
// text part // text part
if t := p.Get("text"); t.Exists() { if t := p.Get("text"); t.Exists() {
msg := `{"type":"message","role":"","content":[]}` msg := []byte(`{"type":"message","role":"","content":[]}`)
msg, _ = sjson.Set(msg, "role", role) msg, _ = sjson.SetBytes(msg, "role", role)
partType := "input_text" partType := "input_text"
if role == "assistant" { if role == "assistant" {
partType = "output_text" partType = "output_text"
} }
part := `{}` part := []byte(`{}`)
part, _ = sjson.Set(part, "type", partType) part, _ = sjson.SetBytes(part, "type", partType)
part, _ = sjson.Set(part, "text", t.String()) part, _ = sjson.SetBytes(part, "text", t.String())
msg, _ = sjson.SetRaw(msg, "content.-1", part) msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
out, _ = sjson.SetRaw(out, "input.-1", msg) out, _ = sjson.SetRawBytes(out, "input.-1", msg)
continue continue
} }
// function call from model // function call from model
if fc := p.Get("functionCall"); fc.Exists() { if fc := p.Get("functionCall"); fc.Exists() {
fn := `{"type":"function_call"}` fn := []byte(`{"type":"function_call"}`)
if name := fc.Get("name"); name.Exists() { if name := fc.Get("name"); name.Exists() {
n := name.String() n := name.String()
if short, ok := shortMap[n]; ok { if short, ok := shortMap[n]; ok {
@@ -147,31 +147,31 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else { } else {
n = shortenNameIfNeeded(n) n = shortenNameIfNeeded(n)
} }
fn, _ = sjson.Set(fn, "name", n) fn, _ = sjson.SetBytes(fn, "name", n)
} }
if args := fc.Get("args"); args.Exists() { if args := fc.Get("args"); args.Exists() {
fn, _ = sjson.Set(fn, "arguments", args.Raw) fn, _ = sjson.SetBytes(fn, "arguments", args.Raw)
} }
// generate a paired random call_id and enqueue it so the // generate a paired random call_id and enqueue it so the
// corresponding functionResponse can pop the earliest id // corresponding functionResponse can pop the earliest id
// to preserve ordering when multiple calls are present. // to preserve ordering when multiple calls are present.
id := genCallID() id := genCallID()
fn, _ = sjson.Set(fn, "call_id", id) fn, _ = sjson.SetBytes(fn, "call_id", id)
pendingCallIDs = append(pendingCallIDs, id) pendingCallIDs = append(pendingCallIDs, id)
out, _ = sjson.SetRaw(out, "input.-1", fn) out, _ = sjson.SetRawBytes(out, "input.-1", fn)
continue continue
} }
// function response from user // function response from user
if fr := p.Get("functionResponse"); fr.Exists() { if fr := p.Get("functionResponse"); fr.Exists() {
fno := `{"type":"function_call_output"}` fno := []byte(`{"type":"function_call_output"}`)
// Prefer a string result if present; otherwise embed the raw response as a string // Prefer a string result if present; otherwise embed the raw response as a string
if res := fr.Get("response.result"); res.Exists() { if res := fr.Get("response.result"); res.Exists() {
fno, _ = sjson.Set(fno, "output", res.String()) fno, _ = sjson.SetBytes(fno, "output", res.String())
} else if resp := fr.Get("response"); resp.Exists() { } else if resp := fr.Get("response"); resp.Exists() {
fno, _ = sjson.Set(fno, "output", resp.Raw) fno, _ = sjson.SetBytes(fno, "output", resp.Raw)
} }
// fno, _ = sjson.Set(fno, "call_id", "call_W6nRJzFXyPM2LFBbfo98qAbq") // fno, _ = sjson.SetBytes(fno, "call_id", "call_W6nRJzFXyPM2LFBbfo98qAbq")
// attach the oldest queued call_id to pair the response // attach the oldest queued call_id to pair the response
// with its call. If the queue is empty, generate a new id. // with its call. If the queue is empty, generate a new id.
var id string var id string
@@ -182,8 +182,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else { } else {
id = genCallID() id = genCallID()
} }
fno, _ = sjson.Set(fno, "call_id", id) fno, _ = sjson.SetBytes(fno, "call_id", id)
out, _ = sjson.SetRaw(out, "input.-1", fno) out, _ = sjson.SetRawBytes(out, "input.-1", fno)
continue continue
} }
} }
@@ -193,8 +193,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
// Tools mapping: Gemini functionDeclarations -> Codex tools // Tools mapping: Gemini functionDeclarations -> Codex tools
tools := root.Get("tools") tools := root.Get("tools")
if tools.IsArray() { if tools.IsArray() {
out, _ = sjson.SetRaw(out, "tools", `[]`) out, _ = sjson.SetRawBytes(out, "tools", []byte(`[]`))
out, _ = sjson.Set(out, "tool_choice", "auto") out, _ = sjson.SetBytes(out, "tool_choice", "auto")
tarr := tools.Array() tarr := tools.Array()
for i := 0; i < len(tarr); i++ { for i := 0; i < len(tarr); i++ {
td := tarr[i] td := tarr[i]
@@ -205,8 +205,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
farr := fns.Array() farr := fns.Array()
for j := 0; j < len(farr); j++ { for j := 0; j < len(farr); j++ {
fn := farr[j] fn := farr[j]
tool := `{}` tool := []byte(`{}`)
tool, _ = sjson.Set(tool, "type", "function") tool, _ = sjson.SetBytes(tool, "type", "function")
if v := fn.Get("name"); v.Exists() { if v := fn.Get("name"); v.Exists() {
name := v.String() name := v.String()
if short, ok := shortMap[name]; ok { if short, ok := shortMap[name]; ok {
@@ -214,32 +214,32 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else { } else {
name = shortenNameIfNeeded(name) name = shortenNameIfNeeded(name)
} }
tool, _ = sjson.Set(tool, "name", name) tool, _ = sjson.SetBytes(tool, "name", name)
} }
if v := fn.Get("description"); v.Exists() { if v := fn.Get("description"); v.Exists() {
tool, _ = sjson.Set(tool, "description", v.String()) tool, _ = sjson.SetBytes(tool, "description", v.String())
} }
if prm := fn.Get("parameters"); prm.Exists() { if prm := fn.Get("parameters"); prm.Exists() {
// Remove optional $schema field if present // Remove optional $schema field if present
cleaned := prm.Raw cleaned := []byte(prm.Raw)
cleaned, _ = sjson.Delete(cleaned, "$schema") cleaned, _ = sjson.DeleteBytes(cleaned, "$schema")
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false) cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
tool, _ = sjson.SetRaw(tool, "parameters", cleaned) tool, _ = sjson.SetRawBytes(tool, "parameters", cleaned)
} else if prm = fn.Get("parametersJsonSchema"); prm.Exists() { } else if prm = fn.Get("parametersJsonSchema"); prm.Exists() {
// Remove optional $schema field if present // Remove optional $schema field if present
cleaned := prm.Raw cleaned := []byte(prm.Raw)
cleaned, _ = sjson.Delete(cleaned, "$schema") cleaned, _ = sjson.DeleteBytes(cleaned, "$schema")
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false) cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
tool, _ = sjson.SetRaw(tool, "parameters", cleaned) tool, _ = sjson.SetRawBytes(tool, "parameters", cleaned)
} }
tool, _ = sjson.Set(tool, "strict", false) tool, _ = sjson.SetBytes(tool, "strict", false)
out, _ = sjson.SetRaw(out, "tools.-1", tool) out, _ = sjson.SetRawBytes(out, "tools.-1", tool)
} }
} }
} }
// Fixed flags aligning with Codex expectations // Fixed flags aligning with Codex expectations
out, _ = sjson.Set(out, "parallel_tool_calls", true) out, _ = sjson.SetBytes(out, "parallel_tool_calls", true)
// Convert Gemini thinkingConfig to Codex reasoning.effort. // Convert Gemini thinkingConfig to Codex reasoning.effort.
// Note: Google official Python SDK sends snake_case fields (thinking_level/thinking_budget). // Note: Google official Python SDK sends snake_case fields (thinking_level/thinking_budget).
@@ -253,7 +253,7 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
if thinkingLevel.Exists() { if thinkingLevel.Exists() {
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String())) effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
if effort != "" { if effort != "" {
out, _ = sjson.Set(out, "reasoning.effort", effort) out, _ = sjson.SetBytes(out, "reasoning.effort", effort)
effortSet = true effortSet = true
} }
} else { } else {
@@ -263,7 +263,7 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} }
if thinkingBudget.Exists() { if thinkingBudget.Exists() {
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok { if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
out, _ = sjson.Set(out, "reasoning.effort", effort) out, _ = sjson.SetBytes(out, "reasoning.effort", effort)
effortSet = true effortSet = true
} }
} }
@@ -272,22 +272,22 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} }
if !effortSet { if !effortSet {
// No thinking config, set default effort // No thinking config, set default effort
out, _ = sjson.Set(out, "reasoning.effort", "medium") out, _ = sjson.SetBytes(out, "reasoning.effort", "medium")
} }
out, _ = sjson.Set(out, "reasoning.summary", "auto") out, _ = sjson.SetBytes(out, "reasoning.summary", "auto")
out, _ = sjson.Set(out, "stream", true) out, _ = sjson.SetBytes(out, "stream", true)
out, _ = sjson.Set(out, "store", false) out, _ = sjson.SetBytes(out, "store", false)
out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"}) out, _ = sjson.SetBytes(out, "include", []string{"reasoning.encrypted_content"})
var pathsToLower []string var pathsToLower []string
toolsResult := gjson.Get(out, "tools") toolsResult := gjson.GetBytes(out, "tools")
util.Walk(toolsResult, "", "type", &pathsToLower) util.Walk(toolsResult, "", "type", &pathsToLower)
for _, p := range pathsToLower { for _, p := range pathsToLower {
fullPath := fmt.Sprintf("tools.%s", p) fullPath := fmt.Sprintf("tools.%s", p)
out, _ = sjson.Set(out, fullPath, strings.ToLower(gjson.Get(out, fullPath).String())) out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(gjson.GetBytes(out, fullPath).String()))
} }
return []byte(out) return out
} }
// shortenNameIfNeeded applies the simple shortening rule for a single name. // shortenNameIfNeeded applies the simple shortening rule for a single name.

View File

@@ -7,9 +7,9 @@ package gemini
import ( import (
"bytes" "bytes"
"context" "context"
"fmt"
"time" "time"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -23,7 +23,7 @@ type ConvertCodexResponseToGeminiParams struct {
Model string Model string
CreatedAt int64 CreatedAt int64
ResponseID string ResponseID string
LastStorageOutput string LastStorageOutput []byte
} }
// ConvertCodexResponseToGemini converts Codex streaming response format to Gemini format. // ConvertCodexResponseToGemini converts Codex streaming response format to Gemini format.
@@ -38,19 +38,19 @@ type ConvertCodexResponseToGeminiParams struct {
// - param: A pointer to a parameter object for maintaining state between calls // - param: A pointer to a parameter object for maintaining state between calls
// //
// Returns: // Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response // - [][]byte: A slice of Gemini-compatible JSON responses
func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil { if *param == nil {
*param = &ConvertCodexResponseToGeminiParams{ *param = &ConvertCodexResponseToGeminiParams{
Model: modelName, Model: modelName,
CreatedAt: 0, CreatedAt: 0,
ResponseID: "", ResponseID: "",
LastStorageOutput: "", LastStorageOutput: nil,
} }
} }
if !bytes.HasPrefix(rawJSON, dataTag) { if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{} return [][]byte{}
} }
rawJSON = bytes.TrimSpace(rawJSON[5:]) rawJSON = bytes.TrimSpace(rawJSON[5:])
@@ -59,17 +59,17 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
typeStr := typeResult.String() typeStr := typeResult.String()
// Base Gemini response template // Base Gemini response template
template := `{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}` template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`)
if (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput != "" && typeStr == "response.output_item.done" { if len((*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput) > 0 && typeStr == "response.output_item.done" {
template = (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput template = append([]byte(nil), (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput...)
} else { } else {
template, _ = sjson.Set(template, "modelVersion", (*param).(*ConvertCodexResponseToGeminiParams).Model) template, _ = sjson.SetBytes(template, "modelVersion", (*param).(*ConvertCodexResponseToGeminiParams).Model)
createdAtResult := rootResult.Get("response.created_at") createdAtResult := rootResult.Get("response.created_at")
if createdAtResult.Exists() { if createdAtResult.Exists() {
(*param).(*ConvertCodexResponseToGeminiParams).CreatedAt = createdAtResult.Int() (*param).(*ConvertCodexResponseToGeminiParams).CreatedAt = createdAtResult.Int()
template, _ = sjson.Set(template, "createTime", time.Unix((*param).(*ConvertCodexResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano)) template, _ = sjson.SetBytes(template, "createTime", time.Unix((*param).(*ConvertCodexResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
} }
template, _ = sjson.Set(template, "responseId", (*param).(*ConvertCodexResponseToGeminiParams).ResponseID) template, _ = sjson.SetBytes(template, "responseId", (*param).(*ConvertCodexResponseToGeminiParams).ResponseID)
} }
// Handle function call completion // Handle function call completion
@@ -78,7 +78,7 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
itemType := itemResult.Get("type").String() itemType := itemResult.Get("type").String()
if itemType == "function_call" { if itemType == "function_call" {
// Create function call part // Create function call part
functionCall := `{"functionCall":{"name":"","args":{}}}` functionCall := []byte(`{"functionCall":{"name":"","args":{}}}`)
{ {
// Restore original tool name if shortened // Restore original tool name if shortened
n := itemResult.Get("name").String() n := itemResult.Get("name").String()
@@ -86,7 +86,7 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
if orig, ok := rev[n]; ok { if orig, ok := rev[n]; ok {
n = orig n = orig
} }
functionCall, _ = sjson.Set(functionCall, "functionCall.name", n) functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", n)
} }
// Parse and set arguments // Parse and set arguments
@@ -94,47 +94,48 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
if argsStr != "" { if argsStr != "" {
argsResult := gjson.Parse(argsStr) argsResult := gjson.Parse(argsStr)
if argsResult.IsObject() { if argsResult.IsObject() {
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr) functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsStr))
} }
} }
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", functionCall) template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall)
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP") template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput = template (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput = append([]byte(nil), template...)
// Use this return to storage message // Use this return to storage message
return []string{} return [][]byte{}
} }
} }
if typeStr == "response.created" { // Handle response creation - set model and response ID if typeStr == "response.created" { // Handle response creation - set model and response ID
template, _ = sjson.Set(template, "modelVersion", rootResult.Get("response.model").String()) template, _ = sjson.SetBytes(template, "modelVersion", rootResult.Get("response.model").String())
template, _ = sjson.Set(template, "responseId", rootResult.Get("response.id").String()) template, _ = sjson.SetBytes(template, "responseId", rootResult.Get("response.id").String())
(*param).(*ConvertCodexResponseToGeminiParams).ResponseID = rootResult.Get("response.id").String() (*param).(*ConvertCodexResponseToGeminiParams).ResponseID = rootResult.Get("response.id").String()
} else if typeStr == "response.reasoning_summary_text.delta" { // Handle reasoning/thinking content delta } else if typeStr == "response.reasoning_summary_text.delta" { // Handle reasoning/thinking content delta
part := `{"thought":true,"text":""}` part := []byte(`{"thought":true,"text":""}`)
part, _ = sjson.Set(part, "text", rootResult.Get("delta").String()) part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part) template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
} else if typeStr == "response.output_text.delta" { // Handle regular text content delta } else if typeStr == "response.output_text.delta" { // Handle regular text content delta
part := `{"text":""}` part := []byte(`{"text":""}`)
part, _ = sjson.Set(part, "text", rootResult.Get("delta").String()) part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part) template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
} else if typeStr == "response.completed" { // Handle response completion with usage metadata } else if typeStr == "response.completed" { // Handle response completion with usage metadata
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int()) template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int())
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int()) template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int())
totalTokens := rootResult.Get("response.usage.input_tokens").Int() + rootResult.Get("response.usage.output_tokens").Int() totalTokens := rootResult.Get("response.usage.input_tokens").Int() + rootResult.Get("response.usage.output_tokens").Int()
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", totalTokens) template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", totalTokens)
} else { } else {
return []string{} return [][]byte{}
} }
if (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput != "" { if len((*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput) > 0 {
return []string{(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput, template} return [][]byte{
} else { append([]byte(nil), (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput...),
return []string{template} template,
}
} }
return [][]byte{template}
} }
// ConvertCodexResponseToGeminiNonStream converts a non-streaming Codex response to a non-streaming Gemini response. // ConvertCodexResponseToGeminiNonStream converts a non-streaming Codex response to a non-streaming Gemini response.
@@ -149,32 +150,32 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
// - param: A pointer to a parameter object for the conversion (unused in current implementation) // - param: A pointer to a parameter object for the conversion (unused in current implementation)
// //
// Returns: // Returns:
// - string: A Gemini-compatible JSON response containing all message content and metadata // - []byte: A Gemini-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
rootResult := gjson.ParseBytes(rawJSON) rootResult := gjson.ParseBytes(rawJSON)
// Verify this is a response.completed event // Verify this is a response.completed event
if rootResult.Get("type").String() != "response.completed" { if rootResult.Get("type").String() != "response.completed" {
return "" return []byte{}
} }
// Base Gemini response template for non-streaming // Base Gemini response template for non-streaming
template := `{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}` template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
// Set model version // Set model version
template, _ = sjson.Set(template, "modelVersion", modelName) template, _ = sjson.SetBytes(template, "modelVersion", modelName)
// Set response metadata from the completed response // Set response metadata from the completed response
responseData := rootResult.Get("response") responseData := rootResult.Get("response")
if responseData.Exists() { if responseData.Exists() {
// Set response ID // Set response ID
if responseId := responseData.Get("id"); responseId.Exists() { if responseId := responseData.Get("id"); responseId.Exists() {
template, _ = sjson.Set(template, "responseId", responseId.String()) template, _ = sjson.SetBytes(template, "responseId", responseId.String())
} }
// Set creation time // Set creation time
if createdAt := responseData.Get("created_at"); createdAt.Exists() { if createdAt := responseData.Get("created_at"); createdAt.Exists() {
template, _ = sjson.Set(template, "createTime", time.Unix(createdAt.Int(), 0).Format(time.RFC3339Nano)) template, _ = sjson.SetBytes(template, "createTime", time.Unix(createdAt.Int(), 0).Format(time.RFC3339Nano))
} }
// Set usage metadata // Set usage metadata
@@ -183,14 +184,14 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
outputTokens := usage.Get("output_tokens").Int() outputTokens := usage.Get("output_tokens").Int()
totalTokens := inputTokens + outputTokens totalTokens := inputTokens + outputTokens
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens) template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", inputTokens)
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens) template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", outputTokens)
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", totalTokens) template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", totalTokens)
} }
// Process output content to build parts array // Process output content to build parts array
hasToolCall := false hasToolCall := false
var pendingFunctionCalls []string var pendingFunctionCalls [][]byte
flushPendingFunctionCalls := func() { flushPendingFunctionCalls := func() {
if len(pendingFunctionCalls) == 0 { if len(pendingFunctionCalls) == 0 {
@@ -199,7 +200,7 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
// Add all pending function calls as individual parts // Add all pending function calls as individual parts
// This maintains the original Gemini API format while ensuring consecutive calls are grouped together // This maintains the original Gemini API format while ensuring consecutive calls are grouped together
for _, fc := range pendingFunctionCalls { for _, fc := range pendingFunctionCalls {
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", fc) template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", fc)
} }
pendingFunctionCalls = nil pendingFunctionCalls = nil
} }
@@ -215,9 +216,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
// Add thinking content // Add thinking content
if content := value.Get("content"); content.Exists() { if content := value.Get("content"); content.Exists() {
part := `{"text":"","thought":true}` part := []byte(`{"text":"","thought":true}`)
part, _ = sjson.Set(part, "text", content.String()) part, _ = sjson.SetBytes(part, "text", content.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part) template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
} }
case "message": case "message":
@@ -229,9 +230,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
content.ForEach(func(_, contentItem gjson.Result) bool { content.ForEach(func(_, contentItem gjson.Result) bool {
if contentItem.Get("type").String() == "output_text" { if contentItem.Get("type").String() == "output_text" {
if text := contentItem.Get("text"); text.Exists() { if text := contentItem.Get("text"); text.Exists() {
part := `{"text":""}` part := []byte(`{"text":""}`)
part, _ = sjson.Set(part, "text", text.String()) part, _ = sjson.SetBytes(part, "text", text.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part) template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
} }
} }
return true return true
@@ -241,21 +242,21 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
case "function_call": case "function_call":
// Collect function call for potential merging with consecutive ones // Collect function call for potential merging with consecutive ones
hasToolCall = true hasToolCall = true
functionCall := `{"functionCall":{"args":{},"name":""}}` functionCall := []byte(`{"functionCall":{"args":{},"name":""}}`)
{ {
n := value.Get("name").String() n := value.Get("name").String()
rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON) rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON)
if orig, ok := rev[n]; ok { if orig, ok := rev[n]; ok {
n = orig n = orig
} }
functionCall, _ = sjson.Set(functionCall, "functionCall.name", n) functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", n)
} }
// Parse and set arguments // Parse and set arguments
if argsStr := value.Get("arguments").String(); argsStr != "" { if argsStr := value.Get("arguments").String(); argsStr != "" {
argsResult := gjson.Parse(argsStr) argsResult := gjson.Parse(argsStr)
if argsResult.IsObject() { if argsResult.IsObject() {
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr) functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsStr))
} }
} }
@@ -270,9 +271,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
// Set finish reason based on whether there were tool calls // Set finish reason based on whether there were tool calls
if hasToolCall { if hasToolCall {
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP") template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
} else { } else {
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP") template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
} }
} }
return template return template
@@ -307,6 +308,6 @@ func buildReverseMapFromGeminiOriginal(original []byte) map[string]string {
return rev return rev
} }
func GeminiTokenCount(ctx context.Context, count int64) string { func GeminiTokenCount(ctx context.Context, count int64) []byte {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count) return translatorcommon.GeminiTokenCountJSON(count)
} }

View File

@@ -29,42 +29,42 @@ import (
func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte { func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
rawJSON := inputRawJSON rawJSON := inputRawJSON
// Start with empty JSON object // Start with empty JSON object
out := `{"instructions":""}` out := []byte(`{"instructions":""}`)
// Stream must be set to true // Stream must be set to true
out, _ = sjson.Set(out, "stream", stream) out, _ = sjson.SetBytes(out, "stream", stream)
// Codex not support temperature, top_p, top_k, max_output_tokens, so comment them // Codex not support temperature, top_p, top_k, max_output_tokens, so comment them
// if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() { // if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() {
// out, _ = sjson.Set(out, "temperature", v.Value()) // out, _ = sjson.SetBytes(out, "temperature", v.Value())
// } // }
// if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() { // if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() {
// out, _ = sjson.Set(out, "top_p", v.Value()) // out, _ = sjson.SetBytes(out, "top_p", v.Value())
// } // }
// if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() { // if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() {
// out, _ = sjson.Set(out, "top_k", v.Value()) // out, _ = sjson.SetBytes(out, "top_k", v.Value())
// } // }
// Map token limits // Map token limits
// if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() { // if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() {
// out, _ = sjson.Set(out, "max_output_tokens", v.Value()) // out, _ = sjson.SetBytes(out, "max_output_tokens", v.Value())
// } // }
// if v := gjson.GetBytes(rawJSON, "max_completion_tokens"); v.Exists() { // if v := gjson.GetBytes(rawJSON, "max_completion_tokens"); v.Exists() {
// out, _ = sjson.Set(out, "max_output_tokens", v.Value()) // out, _ = sjson.SetBytes(out, "max_output_tokens", v.Value())
// } // }
// Map reasoning effort // Map reasoning effort
if v := gjson.GetBytes(rawJSON, "reasoning_effort"); v.Exists() { if v := gjson.GetBytes(rawJSON, "reasoning_effort"); v.Exists() {
out, _ = sjson.Set(out, "reasoning.effort", v.Value()) out, _ = sjson.SetBytes(out, "reasoning.effort", v.Value())
} else { } else {
out, _ = sjson.Set(out, "reasoning.effort", "medium") out, _ = sjson.SetBytes(out, "reasoning.effort", "medium")
} }
out, _ = sjson.Set(out, "parallel_tool_calls", true) out, _ = sjson.SetBytes(out, "parallel_tool_calls", true)
out, _ = sjson.Set(out, "reasoning.summary", "auto") out, _ = sjson.SetBytes(out, "reasoning.summary", "auto")
out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"}) out, _ = sjson.SetBytes(out, "include", []string{"reasoning.encrypted_content"})
// Model // Model
out, _ = sjson.Set(out, "model", modelName) out, _ = sjson.SetBytes(out, "model", modelName)
// Build tool name shortening map from original tools (if any) // Build tool name shortening map from original tools (if any)
originalToolNameMap := map[string]string{} originalToolNameMap := map[string]string{}
@@ -100,9 +100,9 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
// if m.Get("role").String() == "system" { // if m.Get("role").String() == "system" {
// c := m.Get("content") // c := m.Get("content")
// if c.Type == gjson.String { // if c.Type == gjson.String {
// out, _ = sjson.Set(out, "instructions", c.String()) // out, _ = sjson.SetBytes(out, "instructions", c.String())
// } else if c.IsObject() && c.Get("type").String() == "text" { // } else if c.IsObject() && c.Get("type").String() == "text" {
// out, _ = sjson.Set(out, "instructions", c.Get("text").String()) // out, _ = sjson.SetBytes(out, "instructions", c.Get("text").String())
// } // }
// break // break
// } // }
@@ -110,7 +110,7 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
// } // }
// Build input from messages, handling all message types including tool calls // Build input from messages, handling all message types including tool calls
out, _ = sjson.SetRaw(out, "input", `[]`) out, _ = sjson.SetRawBytes(out, "input", []byte(`[]`))
if messages.IsArray() { if messages.IsArray() {
arr := messages.Array() arr := messages.Array()
for i := 0; i < len(arr); i++ { for i := 0; i < len(arr); i++ {
@@ -124,23 +124,23 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
content := m.Get("content").String() content := m.Get("content").String()
// Create function_call_output object // Create function_call_output object
funcOutput := `{}` funcOutput := []byte(`{}`)
funcOutput, _ = sjson.Set(funcOutput, "type", "function_call_output") funcOutput, _ = sjson.SetBytes(funcOutput, "type", "function_call_output")
funcOutput, _ = sjson.Set(funcOutput, "call_id", toolCallID) funcOutput, _ = sjson.SetBytes(funcOutput, "call_id", toolCallID)
funcOutput, _ = sjson.Set(funcOutput, "output", content) funcOutput, _ = sjson.SetBytes(funcOutput, "output", content)
out, _ = sjson.SetRaw(out, "input.-1", funcOutput) out, _ = sjson.SetRawBytes(out, "input.-1", funcOutput)
default: default:
// Handle regular messages // Handle regular messages
msg := `{}` msg := []byte(`{}`)
msg, _ = sjson.Set(msg, "type", "message") msg, _ = sjson.SetBytes(msg, "type", "message")
if role == "system" { if role == "system" {
msg, _ = sjson.Set(msg, "role", "developer") msg, _ = sjson.SetBytes(msg, "role", "developer")
} else { } else {
msg, _ = sjson.Set(msg, "role", role) msg, _ = sjson.SetBytes(msg, "role", role)
} }
msg, _ = sjson.SetRaw(msg, "content", `[]`) msg, _ = sjson.SetRawBytes(msg, "content", []byte(`[]`))
// Handle regular content // Handle regular content
c := m.Get("content") c := m.Get("content")
@@ -150,10 +150,10 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
if role == "assistant" { if role == "assistant" {
partType = "output_text" partType = "output_text"
} }
part := `{}` part := []byte(`{}`)
part, _ = sjson.Set(part, "type", partType) part, _ = sjson.SetBytes(part, "type", partType)
part, _ = sjson.Set(part, "text", c.String()) part, _ = sjson.SetBytes(part, "text", c.String())
msg, _ = sjson.SetRaw(msg, "content.-1", part) msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
} else if c.Exists() && c.IsArray() { } else if c.Exists() && c.IsArray() {
items := c.Array() items := c.Array()
for j := 0; j < len(items); j++ { for j := 0; j < len(items); j++ {
@@ -165,39 +165,44 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
if role == "assistant" { if role == "assistant" {
partType = "output_text" partType = "output_text"
} }
part := `{}` part := []byte(`{}`)
part, _ = sjson.Set(part, "type", partType) part, _ = sjson.SetBytes(part, "type", partType)
part, _ = sjson.Set(part, "text", it.Get("text").String()) part, _ = sjson.SetBytes(part, "text", it.Get("text").String())
msg, _ = sjson.SetRaw(msg, "content.-1", part) msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
case "image_url": case "image_url":
// Map image inputs to input_image for Responses API // Map image inputs to input_image for Responses API
if role == "user" { if role == "user" {
part := `{}` part := []byte(`{}`)
part, _ = sjson.Set(part, "type", "input_image") part, _ = sjson.SetBytes(part, "type", "input_image")
if u := it.Get("image_url.url"); u.Exists() { if u := it.Get("image_url.url"); u.Exists() {
part, _ = sjson.Set(part, "image_url", u.String()) part, _ = sjson.SetBytes(part, "image_url", u.String())
} }
msg, _ = sjson.SetRaw(msg, "content.-1", part) msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
} }
case "file": case "file":
if role == "user" { if role == "user" {
fileData := it.Get("file.file_data").String() fileData := it.Get("file.file_data").String()
filename := it.Get("file.filename").String() filename := it.Get("file.filename").String()
if fileData != "" { if fileData != "" {
part := `{}` part := []byte(`{}`)
part, _ = sjson.Set(part, "type", "input_file") part, _ = sjson.SetBytes(part, "type", "input_file")
part, _ = sjson.Set(part, "file_data", fileData) part, _ = sjson.SetBytes(part, "file_data", fileData)
if filename != "" { if filename != "" {
part, _ = sjson.Set(part, "filename", filename) part, _ = sjson.SetBytes(part, "filename", filename)
} }
msg, _ = sjson.SetRaw(msg, "content.-1", part) msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
} }
} }
} }
} }
} }
out, _ = sjson.SetRaw(out, "input.-1", msg) // Don't emit empty assistant messages when only tool_calls
// are present — Responses API needs function_call items
// directly, otherwise call_id matching fails (#2132).
if role != "assistant" || len(gjson.GetBytes(msg, "content").Array()) > 0 {
out, _ = sjson.SetRawBytes(out, "input.-1", msg)
}
// Handle tool calls for assistant messages as separate top-level objects // Handle tool calls for assistant messages as separate top-level objects
if role == "assistant" { if role == "assistant" {
@@ -208,9 +213,9 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
tc := toolCallsArr[j] tc := toolCallsArr[j]
if tc.Get("type").String() == "function" { if tc.Get("type").String() == "function" {
// Create function_call as top-level object // Create function_call as top-level object
funcCall := `{}` funcCall := []byte(`{}`)
funcCall, _ = sjson.Set(funcCall, "type", "function_call") funcCall, _ = sjson.SetBytes(funcCall, "type", "function_call")
funcCall, _ = sjson.Set(funcCall, "call_id", tc.Get("id").String()) funcCall, _ = sjson.SetBytes(funcCall, "call_id", tc.Get("id").String())
{ {
name := tc.Get("function.name").String() name := tc.Get("function.name").String()
if short, ok := originalToolNameMap[name]; ok { if short, ok := originalToolNameMap[name]; ok {
@@ -218,10 +223,10 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
} else { } else {
name = shortenNameIfNeeded(name) name = shortenNameIfNeeded(name)
} }
funcCall, _ = sjson.Set(funcCall, "name", name) funcCall, _ = sjson.SetBytes(funcCall, "name", name)
} }
funcCall, _ = sjson.Set(funcCall, "arguments", tc.Get("function.arguments").String()) funcCall, _ = sjson.SetBytes(funcCall, "arguments", tc.Get("function.arguments").String())
out, _ = sjson.SetRaw(out, "input.-1", funcCall) out, _ = sjson.SetRawBytes(out, "input.-1", funcCall)
} }
} }
} }
@@ -235,26 +240,26 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
text := gjson.GetBytes(rawJSON, "text") text := gjson.GetBytes(rawJSON, "text")
if rf.Exists() { if rf.Exists() {
// Always create text object when response_format provided // Always create text object when response_format provided
if !gjson.Get(out, "text").Exists() { if !gjson.GetBytes(out, "text").Exists() {
out, _ = sjson.SetRaw(out, "text", `{}`) out, _ = sjson.SetRawBytes(out, "text", []byte(`{}`))
} }
rft := rf.Get("type").String() rft := rf.Get("type").String()
switch rft { switch rft {
case "text": case "text":
out, _ = sjson.Set(out, "text.format.type", "text") out, _ = sjson.SetBytes(out, "text.format.type", "text")
case "json_schema": case "json_schema":
js := rf.Get("json_schema") js := rf.Get("json_schema")
if js.Exists() { if js.Exists() {
out, _ = sjson.Set(out, "text.format.type", "json_schema") out, _ = sjson.SetBytes(out, "text.format.type", "json_schema")
if v := js.Get("name"); v.Exists() { if v := js.Get("name"); v.Exists() {
out, _ = sjson.Set(out, "text.format.name", v.Value()) out, _ = sjson.SetBytes(out, "text.format.name", v.Value())
} }
if v := js.Get("strict"); v.Exists() { if v := js.Get("strict"); v.Exists() {
out, _ = sjson.Set(out, "text.format.strict", v.Value()) out, _ = sjson.SetBytes(out, "text.format.strict", v.Value())
} }
if v := js.Get("schema"); v.Exists() { if v := js.Get("schema"); v.Exists() {
out, _ = sjson.SetRaw(out, "text.format.schema", v.Raw) out, _ = sjson.SetRawBytes(out, "text.format.schema", []byte(v.Raw))
} }
} }
} }
@@ -262,23 +267,23 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
// Map verbosity if provided // Map verbosity if provided
if text.Exists() { if text.Exists() {
if v := text.Get("verbosity"); v.Exists() { if v := text.Get("verbosity"); v.Exists() {
out, _ = sjson.Set(out, "text.verbosity", v.Value()) out, _ = sjson.SetBytes(out, "text.verbosity", v.Value())
} }
} }
} else if text.Exists() { } else if text.Exists() {
// If only text.verbosity present (no response_format), map verbosity // If only text.verbosity present (no response_format), map verbosity
if v := text.Get("verbosity"); v.Exists() { if v := text.Get("verbosity"); v.Exists() {
if !gjson.Get(out, "text").Exists() { if !gjson.GetBytes(out, "text").Exists() {
out, _ = sjson.SetRaw(out, "text", `{}`) out, _ = sjson.SetRawBytes(out, "text", []byte(`{}`))
} }
out, _ = sjson.Set(out, "text.verbosity", v.Value()) out, _ = sjson.SetBytes(out, "text.verbosity", v.Value())
} }
} }
// Map tools (flatten function fields) // Map tools (flatten function fields)
tools := gjson.GetBytes(rawJSON, "tools") tools := gjson.GetBytes(rawJSON, "tools")
if tools.IsArray() && len(tools.Array()) > 0 { if tools.IsArray() && len(tools.Array()) > 0 {
out, _ = sjson.SetRaw(out, "tools", `[]`) out, _ = sjson.SetRawBytes(out, "tools", []byte(`[]`))
arr := tools.Array() arr := tools.Array()
for i := 0; i < len(arr); i++ { for i := 0; i < len(arr); i++ {
t := arr[i] t := arr[i]
@@ -286,13 +291,13 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
// Pass through built-in tools (e.g. {"type":"web_search"}) directly for the Responses API. // Pass through built-in tools (e.g. {"type":"web_search"}) directly for the Responses API.
// Only "function" needs structural conversion because Chat Completions nests details under "function". // Only "function" needs structural conversion because Chat Completions nests details under "function".
if toolType != "" && toolType != "function" && t.IsObject() { if toolType != "" && toolType != "function" && t.IsObject() {
out, _ = sjson.SetRaw(out, "tools.-1", t.Raw) out, _ = sjson.SetRawBytes(out, "tools.-1", []byte(t.Raw))
continue continue
} }
if toolType == "function" { if toolType == "function" {
item := `{}` item := []byte(`{}`)
item, _ = sjson.Set(item, "type", "function") item, _ = sjson.SetBytes(item, "type", "function")
fn := t.Get("function") fn := t.Get("function")
if fn.Exists() { if fn.Exists() {
if v := fn.Get("name"); v.Exists() { if v := fn.Get("name"); v.Exists() {
@@ -302,19 +307,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
} else { } else {
name = shortenNameIfNeeded(name) name = shortenNameIfNeeded(name)
} }
item, _ = sjson.Set(item, "name", name) item, _ = sjson.SetBytes(item, "name", name)
} }
if v := fn.Get("description"); v.Exists() { if v := fn.Get("description"); v.Exists() {
item, _ = sjson.Set(item, "description", v.Value()) item, _ = sjson.SetBytes(item, "description", v.Value())
} }
if v := fn.Get("parameters"); v.Exists() { if v := fn.Get("parameters"); v.Exists() {
item, _ = sjson.SetRaw(item, "parameters", v.Raw) item, _ = sjson.SetRawBytes(item, "parameters", []byte(v.Raw))
} }
if v := fn.Get("strict"); v.Exists() { if v := fn.Get("strict"); v.Exists() {
item, _ = sjson.Set(item, "strict", v.Value()) item, _ = sjson.SetBytes(item, "strict", v.Value())
} }
} }
out, _ = sjson.SetRaw(out, "tools.-1", item) out, _ = sjson.SetRawBytes(out, "tools.-1", item)
} }
} }
} }
@@ -325,7 +330,7 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
if tc := gjson.GetBytes(rawJSON, "tool_choice"); tc.Exists() { if tc := gjson.GetBytes(rawJSON, "tool_choice"); tc.Exists() {
switch { switch {
case tc.Type == gjson.String: case tc.Type == gjson.String:
out, _ = sjson.Set(out, "tool_choice", tc.String()) out, _ = sjson.SetBytes(out, "tool_choice", tc.String())
case tc.IsObject(): case tc.IsObject():
tcType := tc.Get("type").String() tcType := tc.Get("type").String()
if tcType == "function" { if tcType == "function" {
@@ -337,21 +342,21 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
name = shortenNameIfNeeded(name) name = shortenNameIfNeeded(name)
} }
} }
choice := `{}` choice := []byte(`{}`)
choice, _ = sjson.Set(choice, "type", "function") choice, _ = sjson.SetBytes(choice, "type", "function")
if name != "" { if name != "" {
choice, _ = sjson.Set(choice, "name", name) choice, _ = sjson.SetBytes(choice, "name", name)
} }
out, _ = sjson.SetRaw(out, "tool_choice", choice) out, _ = sjson.SetRawBytes(out, "tool_choice", choice)
} else if tcType != "" { } else if tcType != "" {
// Built-in tool choices (e.g. {"type":"web_search"}) are already Responses-compatible. // Built-in tool choices (e.g. {"type":"web_search"}) are already Responses-compatible.
out, _ = sjson.SetRaw(out, "tool_choice", tc.Raw) out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(tc.Raw))
} }
} }
} }
out, _ = sjson.Set(out, "store", false) out, _ = sjson.SetBytes(out, "store", false)
return []byte(out) return out
} }
// shortenNameIfNeeded applies the simple shortening rule for a single name. // shortenNameIfNeeded applies the simple shortening rule for a single name.

View File

@@ -0,0 +1,635 @@
package chat_completions
import (
"testing"
"github.com/tidwall/gjson"
)
// Basic tool-call: system + user + assistant(tool_calls, no content) + tool result.
// Expects developer msg + user msg + function_call + function_call_output.
// No empty assistant message should appear between user and function_call.
func TestToolCallSimple(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the weather in Paris?"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"Paris\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": "sunny, 22C"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather for a city",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
if len(items) != 4 {
t.Fatalf("expected 4 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
// system -> developer
if items[0].Get("type").String() != "message" {
t.Errorf("item 0: expected type 'message', got '%s'", items[0].Get("type").String())
}
if items[0].Get("role").String() != "developer" {
t.Errorf("item 0: expected role 'developer', got '%s'", items[0].Get("role").String())
}
// user
if items[1].Get("type").String() != "message" {
t.Errorf("item 1: expected type 'message', got '%s'", items[1].Get("type").String())
}
if items[1].Get("role").String() != "user" {
t.Errorf("item 1: expected role 'user', got '%s'", items[1].Get("role").String())
}
// function_call, not an empty assistant msg
if items[2].Get("type").String() != "function_call" {
t.Errorf("item 2: expected type 'function_call', got '%s'", items[2].Get("type").String())
}
if items[2].Get("call_id").String() != "call_1" {
t.Errorf("item 2: expected call_id 'call_1', got '%s'", items[2].Get("call_id").String())
}
if items[2].Get("name").String() != "get_weather" {
t.Errorf("item 2: expected name 'get_weather', got '%s'", items[2].Get("name").String())
}
if items[2].Get("arguments").String() != `{"city":"Paris"}` {
t.Errorf("item 2: unexpected arguments: %s", items[2].Get("arguments").String())
}
// function_call_output
if items[3].Get("type").String() != "function_call_output" {
t.Errorf("item 3: expected type 'function_call_output', got '%s'", items[3].Get("type").String())
}
if items[3].Get("call_id").String() != "call_1" {
t.Errorf("item 3: expected call_id 'call_1', got '%s'", items[3].Get("call_id").String())
}
if items[3].Get("output").String() != "sunny, 22C" {
t.Errorf("item 3: expected output 'sunny, 22C', got '%s'", items[3].Get("output").String())
}
}
// Assistant has both text content and tool_calls — the message should
// be emitted (non-empty content), followed by function_call items.
func TestToolCallWithContent(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "What is the weather?"},
{
"role": "assistant",
"content": "Let me check the weather for you.",
"tool_calls": [
{
"id": "call_abc",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_abc",
"content": "rainy, 15C"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// user + assistant(with content) + function_call + function_call_output
if len(items) != 4 {
t.Fatalf("expected 4 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
if items[0].Get("role").String() != "user" {
t.Errorf("item 0: expected role 'user', got '%s'", items[0].Get("role").String())
}
// assistant with content — should be kept
if items[1].Get("type").String() != "message" {
t.Errorf("item 1: expected type 'message', got '%s'", items[1].Get("type").String())
}
if items[1].Get("role").String() != "assistant" {
t.Errorf("item 1: expected role 'assistant', got '%s'", items[1].Get("role").String())
}
contentParts := items[1].Get("content").Array()
if len(contentParts) == 0 {
t.Errorf("item 1: assistant message should have content parts")
}
if items[2].Get("type").String() != "function_call" {
t.Errorf("item 2: expected type 'function_call', got '%s'", items[2].Get("type").String())
}
if items[2].Get("call_id").String() != "call_abc" {
t.Errorf("item 2: expected call_id 'call_abc', got '%s'", items[2].Get("call_id").String())
}
if items[3].Get("type").String() != "function_call_output" {
t.Errorf("item 3: expected type 'function_call_output', got '%s'", items[3].Get("type").String())
}
if items[3].Get("call_id").String() != "call_abc" {
t.Errorf("item 3: expected call_id 'call_abc', got '%s'", items[3].Get("call_id").String())
}
}
// Parallel tool calls: assistant invokes 3 tools at once, all call_ids
// and outputs must be translated and paired correctly.
func TestMultipleToolCalls(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Compare weather in Paris, London and Tokyo"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_paris",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"Paris\"}"
}
},
{
"id": "call_london",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"London\"}"
}
},
{
"id": "call_tokyo",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"Tokyo\"}"
}
}
]
},
{"role": "tool", "tool_call_id": "call_paris", "content": "sunny, 22C"},
{"role": "tool", "tool_call_id": "call_london", "content": "cloudy, 14C"},
{"role": "tool", "tool_call_id": "call_tokyo", "content": "humid, 28C"}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// user + 3 function_call + 3 function_call_output = 7
if len(items) != 7 {
t.Fatalf("expected 7 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
if items[0].Get("role").String() != "user" {
t.Errorf("item 0: expected role 'user', got '%s'", items[0].Get("role").String())
}
expectedCallIDs := []string{"call_paris", "call_london", "call_tokyo"}
for i, expectedID := range expectedCallIDs {
idx := i + 1
if items[idx].Get("type").String() != "function_call" {
t.Errorf("item %d: expected type 'function_call', got '%s'", idx, items[idx].Get("type").String())
}
if items[idx].Get("call_id").String() != expectedID {
t.Errorf("item %d: expected call_id '%s', got '%s'", idx, expectedID, items[idx].Get("call_id").String())
}
}
expectedOutputs := []string{"sunny, 22C", "cloudy, 14C", "humid, 28C"}
for i, expectedOutput := range expectedOutputs {
idx := i + 4
if items[idx].Get("type").String() != "function_call_output" {
t.Errorf("item %d: expected type 'function_call_output', got '%s'", idx, items[idx].Get("type").String())
}
if items[idx].Get("call_id").String() != expectedCallIDs[i] {
t.Errorf("item %d: expected call_id '%s', got '%s'", idx, expectedCallIDs[i], items[idx].Get("call_id").String())
}
if items[idx].Get("output").String() != expectedOutput {
t.Errorf("item %d: expected output '%s', got '%s'", idx, expectedOutput, items[idx].Get("output").String())
}
}
}
// Regression test for #2132: tool-call-only assistant messages (content:null)
// must not produce an empty message item in the translated output.
func TestNoSpuriousEmptyAssistantMessage(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Call a tool"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_x",
"type": "function",
"function": {"name": "do_thing", "arguments": "{}"}
}
]
},
{"role": "tool", "tool_call_id": "call_x", "content": "done"}
],
"tools": [
{
"type": "function",
"function": {
"name": "do_thing",
"description": "Do a thing",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
for i, item := range items {
typ := item.Get("type").String()
role := item.Get("role").String()
if typ == "message" && role == "assistant" {
contentArr := item.Get("content").Array()
if len(contentArr) == 0 {
t.Errorf("item %d: empty assistant message breaks call_id matching. item: %s", i, item.Raw)
}
}
}
// should be exactly: user + function_call + function_call_output
if len(items) != 3 {
t.Fatalf("expected 3 input items (user + function_call + function_call_output), got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
if items[0].Get("type").String() != "message" || items[0].Get("role").String() != "user" {
t.Errorf("item 0: expected user message")
}
if items[1].Get("type").String() != "function_call" {
t.Errorf("item 1: expected function_call, got %s", items[1].Get("type").String())
}
if items[2].Get("type").String() != "function_call_output" {
t.Errorf("item 2: expected function_call_output, got %s", items[2].Get("type").String())
}
}
// Two rounds of tool calling in one conversation, with a text reply in between.
func TestMultiTurnToolCalling(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Weather in Paris?"},
{
"role": "assistant",
"content": null,
"tool_calls": [{"id": "call_r1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}}]
},
{"role": "tool", "tool_call_id": "call_r1", "content": "sunny"},
{"role": "assistant", "content": "It is sunny in Paris."},
{"role": "user", "content": "And London?"},
{
"role": "assistant",
"content": null,
"tool_calls": [{"id": "call_r2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"London\"}"}}]
},
{"role": "tool", "tool_call_id": "call_r2", "content": "rainy"}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// user, func_call(r1), func_output(r1), assistant text, user, func_call(r2), func_output(r2)
if len(items) != 7 {
t.Fatalf("expected 7 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
for i, item := range items {
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
if len(item.Get("content").Array()) == 0 {
t.Errorf("item %d: unexpected empty assistant message", i)
}
}
}
// round 1
if items[1].Get("type").String() != "function_call" {
t.Errorf("item 1: expected function_call, got %s", items[1].Get("type").String())
}
if items[1].Get("call_id").String() != "call_r1" {
t.Errorf("item 1: expected call_id 'call_r1', got '%s'", items[1].Get("call_id").String())
}
if items[2].Get("type").String() != "function_call_output" {
t.Errorf("item 2: expected function_call_output, got %s", items[2].Get("type").String())
}
// text reply between rounds
if items[3].Get("type").String() != "message" || items[3].Get("role").String() != "assistant" {
t.Errorf("item 3: expected assistant message, got type=%s role=%s", items[3].Get("type").String(), items[3].Get("role").String())
}
// round 2
if items[5].Get("type").String() != "function_call" {
t.Errorf("item 5: expected function_call, got %s", items[5].Get("type").String())
}
if items[5].Get("call_id").String() != "call_r2" {
t.Errorf("item 5: expected call_id 'call_r2', got '%s'", items[5].Get("call_id").String())
}
if items[6].Get("type").String() != "function_call_output" {
t.Errorf("item 6: expected function_call_output, got %s", items[6].Get("type").String())
}
}
// Tool names over 64 chars get shortened, call_id stays the same.
func TestToolNameShortening(t *testing.T) {
longName := "a_very_long_tool_name_that_exceeds_sixty_four_characters_limit_here_test"
if len(longName) <= 64 {
t.Fatalf("test setup error: name must be > 64 chars, got %d", len(longName))
}
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Do it"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_long",
"type": "function",
"function": {
"name": "` + longName + `",
"arguments": "{}"
}
}
]
},
{"role": "tool", "tool_call_id": "call_long", "content": "ok"}
],
"tools": [
{
"type": "function",
"function": {
"name": "` + longName + `",
"description": "A tool with a very long name",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// find function_call
var funcCallItem gjson.Result
for _, item := range items {
if item.Get("type").String() == "function_call" {
funcCallItem = item
break
}
}
if !funcCallItem.Exists() {
t.Fatal("no function_call item found in output")
}
// call_id unchanged
if funcCallItem.Get("call_id").String() != "call_long" {
t.Errorf("call_id changed: expected 'call_long', got '%s'", funcCallItem.Get("call_id").String())
}
// name must be truncated
translatedName := funcCallItem.Get("name").String()
if translatedName == longName {
t.Errorf("tool name was NOT shortened: still '%s'", translatedName)
}
if len(translatedName) > 64 {
t.Errorf("shortened name still > 64 chars: len=%d name='%s'", len(translatedName), translatedName)
}
}
// content:"" (empty string, not null) should be treated the same as null.
func TestEmptyStringContent(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Do something"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_empty",
"type": "function",
"function": {"name": "action", "arguments": "{}"}
}
]
},
{"role": "tool", "tool_call_id": "call_empty", "content": "result"}
],
"tools": [
{
"type": "function",
"function": {
"name": "action",
"description": "An action",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
for i, item := range items {
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
if len(item.Get("content").Array()) == 0 {
t.Errorf("item %d: empty assistant message from content:\"\"", i)
}
}
}
// user + function_call + function_call_output
if len(items) != 3 {
t.Errorf("expected 3 input items, got %d", len(items))
}
}
// Every function_call_output must have a matching function_call by call_id.
func TestCallIDsMatchBetweenCallAndOutput(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Multi-tool"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{"id": "id_a", "type": "function", "function": {"name": "tool_a", "arguments": "{}"}},
{"id": "id_b", "type": "function", "function": {"name": "tool_b", "arguments": "{}"}}
]
},
{"role": "tool", "tool_call_id": "id_a", "content": "res_a"},
{"role": "tool", "tool_call_id": "id_b", "content": "res_b"}
],
"tools": [
{"type": "function", "function": {"name": "tool_a", "description": "A", "parameters": {"type": "object", "properties": {}}}},
{"type": "function", "function": {"name": "tool_b", "description": "B", "parameters": {"type": "object", "properties": {}}}}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// collect call_ids from function_call items
callIDs := make(map[string]bool)
for _, item := range items {
if item.Get("type").String() == "function_call" {
callIDs[item.Get("call_id").String()] = true
}
}
for i, item := range items {
if item.Get("type").String() == "function_call_output" {
outID := item.Get("call_id").String()
if !callIDs[outID] {
t.Errorf("item %d: function_call_output has call_id '%s' with no matching function_call", i, outID)
}
}
}
// 2 calls, 2 outputs
funcCallCount := 0
funcOutputCount := 0
for _, item := range items {
switch item.Get("type").String() {
case "function_call":
funcCallCount++
case "function_call_output":
funcOutputCount++
}
}
if funcCallCount != 2 {
t.Errorf("expected 2 function_calls, got %d", funcCallCount)
}
if funcOutputCount != 2 {
t.Errorf("expected 2 function_call_outputs, got %d", funcOutputCount)
}
}
// Tools array should carry over to the Responses format output.
func TestToolsDefinitionTranslated(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Hi"}
],
"tools": [
{
"type": "function",
"function": {
"name": "search",
"description": "Search the web",
"parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
tools := gjson.Get(result, "tools").Array()
if len(tools) == 0 {
t.Fatal("no tools found in output")
}
found := false
for _, tool := range tools {
if tool.Get("name").String() == "search" {
found = true
break
}
}
if !found {
t.Errorf("tool 'search' not found in output tools: %s", gjson.Get(result, "tools").Raw)
}
}

View File

@@ -41,8 +41,8 @@ type ConvertCliToOpenAIParams struct {
// - param: A pointer to a parameter object for maintaining state between calls // - param: A pointer to a parameter object for maintaining state between calls
// //
// Returns: // Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response // - [][]byte: A slice of OpenAI-compatible JSON responses
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil { if *param == nil {
*param = &ConvertCliToOpenAIParams{ *param = &ConvertCliToOpenAIParams{
Model: modelName, Model: modelName,
@@ -55,12 +55,12 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
} }
if !bytes.HasPrefix(rawJSON, dataTag) { if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{} return [][]byte{}
} }
rawJSON = bytes.TrimSpace(rawJSON[5:]) rawJSON = bytes.TrimSpace(rawJSON[5:])
// Initialize the OpenAI SSE template. // Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{},"finish_reason":null,"native_finish_reason":null}]}`)
rootResult := gjson.ParseBytes(rawJSON) rootResult := gjson.ParseBytes(rawJSON)
@@ -70,67 +70,67 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
(*param).(*ConvertCliToOpenAIParams).ResponseID = rootResult.Get("response.id").String() (*param).(*ConvertCliToOpenAIParams).ResponseID = rootResult.Get("response.id").String()
(*param).(*ConvertCliToOpenAIParams).CreatedAt = rootResult.Get("response.created_at").Int() (*param).(*ConvertCliToOpenAIParams).CreatedAt = rootResult.Get("response.created_at").Int()
(*param).(*ConvertCliToOpenAIParams).Model = rootResult.Get("response.model").String() (*param).(*ConvertCliToOpenAIParams).Model = rootResult.Get("response.model").String()
return []string{} return [][]byte{}
} }
// Extract and set the model version. // Extract and set the model version.
cachedModel := (*param).(*ConvertCliToOpenAIParams).Model cachedModel := (*param).(*ConvertCliToOpenAIParams).Model
if modelResult := gjson.GetBytes(rawJSON, "model"); modelResult.Exists() { if modelResult := gjson.GetBytes(rawJSON, "model"); modelResult.Exists() {
template, _ = sjson.Set(template, "model", modelResult.String()) template, _ = sjson.SetBytes(template, "model", modelResult.String())
} else if cachedModel != "" { } else if cachedModel != "" {
template, _ = sjson.Set(template, "model", cachedModel) template, _ = sjson.SetBytes(template, "model", cachedModel)
} else if modelName != "" { } else if modelName != "" {
template, _ = sjson.Set(template, "model", modelName) template, _ = sjson.SetBytes(template, "model", modelName)
} }
template, _ = sjson.Set(template, "created", (*param).(*ConvertCliToOpenAIParams).CreatedAt) template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertCliToOpenAIParams).CreatedAt)
// Extract and set the response ID. // Extract and set the response ID.
template, _ = sjson.Set(template, "id", (*param).(*ConvertCliToOpenAIParams).ResponseID) template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertCliToOpenAIParams).ResponseID)
// Extract and set usage metadata (token counts). // Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "response.usage"); usageResult.Exists() { if usageResult := gjson.GetBytes(rawJSON, "response.usage"); usageResult.Exists() {
if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() { if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int()) template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokensResult.Int())
} }
if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() { if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int()) template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokensResult.Int())
} }
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() { if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int()) template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokensResult.Int())
} }
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() { if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int()) template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
} }
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() { if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int()) template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
} }
} }
if dataType == "response.reasoning_summary_text.delta" { if dataType == "response.reasoning_summary_text.delta" {
if deltaResult := rootResult.Get("delta"); deltaResult.Exists() { if deltaResult := rootResult.Get("delta"); deltaResult.Exists() {
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", deltaResult.String()) template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", deltaResult.String())
} }
} else if dataType == "response.reasoning_summary_text.done" { } else if dataType == "response.reasoning_summary_text.done" {
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", "\n\n") template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", "\n\n")
} else if dataType == "response.output_text.delta" { } else if dataType == "response.output_text.delta" {
if deltaResult := rootResult.Get("delta"); deltaResult.Exists() { if deltaResult := rootResult.Get("delta"); deltaResult.Exists() {
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.Set(template, "choices.0.delta.content", deltaResult.String()) template, _ = sjson.SetBytes(template, "choices.0.delta.content", deltaResult.String())
} }
} else if dataType == "response.completed" { } else if dataType == "response.completed" {
finishReason := "stop" finishReason := "stop"
if (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex != -1 { if (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex != -1 {
finishReason = "tool_calls" finishReason = "tool_calls"
} }
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason) template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason) template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
} else if dataType == "response.output_item.added" { } else if dataType == "response.output_item.added" {
itemResult := rootResult.Get("item") itemResult := rootResult.Get("item")
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" { if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
return []string{} return [][]byte{}
} }
// Increment index for this new function call item. // Increment index for this new function call item.
@@ -138,9 +138,9 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}` functionCallItemTemplate := []byte(`{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex) functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String()) functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened. // Restore original tool name if it was shortened.
name := itemResult.Get("name").String() name := itemResult.Get("name").String()
@@ -148,59 +148,59 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
if orig, ok := rev[name]; ok { if orig, ok := rev[name]; ok {
name = orig name = orig
} }
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name) functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", "") functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", "")
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`) template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate) template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.function_call_arguments.delta" { } else if dataType == "response.function_call_arguments.delta" {
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true
deltaValue := rootResult.Get("delta").String() deltaValue := rootResult.Get("delta").String()
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}` functionCallItemTemplate := []byte(`{"index":0,"function":{"arguments":""}}`)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex) functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", deltaValue) functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", deltaValue)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`) template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate) template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.function_call_arguments.done" { } else if dataType == "response.function_call_arguments.done" {
if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta { if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta {
// Arguments were already streamed via delta events; nothing to emit. // Arguments were already streamed via delta events; nothing to emit.
return []string{} return [][]byte{}
} }
// Fallback: no delta events were received, emit the full arguments as a single chunk. // Fallback: no delta events were received, emit the full arguments as a single chunk.
fullArgs := rootResult.Get("arguments").String() fullArgs := rootResult.Get("arguments").String()
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}` functionCallItemTemplate := []byte(`{"index":0,"function":{"arguments":""}}`)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex) functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fullArgs) functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", fullArgs)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`) template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate) template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.output_item.done" { } else if dataType == "response.output_item.done" {
itemResult := rootResult.Get("item") itemResult := rootResult.Get("item")
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" { if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
return []string{} return [][]byte{}
} }
if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced { if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced {
// Tool call was already announced via output_item.added; skip emission. // Tool call was already announced via output_item.added; skip emission.
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false
return []string{} return [][]byte{}
} }
// Fallback path: model skipped output_item.added, so emit complete tool call now. // Fallback path: model skipped output_item.added, so emit complete tool call now.
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++ (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}` functionCallItemTemplate := []byte(`{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex) functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`) template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String()) functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened. // Restore original tool name if it was shortened.
name := itemResult.Get("name").String() name := itemResult.Get("name").String()
@@ -208,17 +208,17 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
if orig, ok := rev[name]; ok { if orig, ok := rev[name]; ok {
name = orig name = orig
} }
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name) functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String()) functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate) template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else { } else {
return []string{} return [][]byte{}
} }
return []string{template} return [][]byte{template}
} }
// ConvertCodexResponseToOpenAINonStream converts a non-streaming Codex response to a non-streaming OpenAI response. // ConvertCodexResponseToOpenAINonStream converts a non-streaming Codex response to a non-streaming OpenAI response.
@@ -233,53 +233,53 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
// - param: A pointer to a parameter object for the conversion (unused in current implementation) // - param: A pointer to a parameter object for the conversion (unused in current implementation)
// //
// Returns: // Returns:
// - string: An OpenAI-compatible JSON response containing all message content and metadata // - []byte: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
rootResult := gjson.ParseBytes(rawJSON) rootResult := gjson.ParseBytes(rawJSON)
// Verify this is a response.completed event // Verify this is a response.completed event
if rootResult.Get("type").String() != "response.completed" { if rootResult.Get("type").String() != "response.completed" {
return "" return []byte{}
} }
unixTimestamp := time.Now().Unix() unixTimestamp := time.Now().Unix()
responseResult := rootResult.Get("response") responseResult := rootResult.Get("response")
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` template := []byte(`{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
// Extract and set the model version. // Extract and set the model version.
if modelResult := responseResult.Get("model"); modelResult.Exists() { if modelResult := responseResult.Get("model"); modelResult.Exists() {
template, _ = sjson.Set(template, "model", modelResult.String()) template, _ = sjson.SetBytes(template, "model", modelResult.String())
} }
// Extract and set the creation timestamp. // Extract and set the creation timestamp.
if createdAtResult := responseResult.Get("created_at"); createdAtResult.Exists() { if createdAtResult := responseResult.Get("created_at"); createdAtResult.Exists() {
template, _ = sjson.Set(template, "created", createdAtResult.Int()) template, _ = sjson.SetBytes(template, "created", createdAtResult.Int())
} else { } else {
template, _ = sjson.Set(template, "created", unixTimestamp) template, _ = sjson.SetBytes(template, "created", unixTimestamp)
} }
// Extract and set the response ID. // Extract and set the response ID.
if idResult := responseResult.Get("id"); idResult.Exists() { if idResult := responseResult.Get("id"); idResult.Exists() {
template, _ = sjson.Set(template, "id", idResult.String()) template, _ = sjson.SetBytes(template, "id", idResult.String())
} }
// Extract and set usage metadata (token counts). // Extract and set usage metadata (token counts).
if usageResult := responseResult.Get("usage"); usageResult.Exists() { if usageResult := responseResult.Get("usage"); usageResult.Exists() {
if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() { if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int()) template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokensResult.Int())
} }
if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() { if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int()) template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokensResult.Int())
} }
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() { if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int()) template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokensResult.Int())
} }
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() { if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int()) template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
} }
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() { if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int()) template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
} }
} }
@@ -289,7 +289,7 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
outputArray := outputResult.Array() outputArray := outputResult.Array()
var contentText string var contentText string
var reasoningText string var reasoningText string
var toolCalls []string var toolCalls [][]byte
for _, outputItem := range outputArray { for _, outputItem := range outputArray {
outputType := outputItem.Get("type").String() outputType := outputItem.Get("type").String()
@@ -319,10 +319,10 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
} }
case "function_call": case "function_call":
// Handle function call content // Handle function call content
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}` functionCallTemplate := []byte(`{"id":"","type":"function","function":{"name":"","arguments":""}}`)
if callIdResult := outputItem.Get("call_id"); callIdResult.Exists() { if callIdResult := outputItem.Get("call_id"); callIdResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", callIdResult.String()) functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", callIdResult.String())
} }
if nameResult := outputItem.Get("name"); nameResult.Exists() { if nameResult := outputItem.Get("name"); nameResult.Exists() {
@@ -331,11 +331,11 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
if orig, ok := rev[n]; ok { if orig, ok := rev[n]; ok {
n = orig n = orig
} }
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", n) functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", n)
} }
if argsResult := outputItem.Get("arguments"); argsResult.Exists() { if argsResult := outputItem.Get("arguments"); argsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", argsResult.String()) functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.arguments", argsResult.String())
} }
toolCalls = append(toolCalls, functionCallTemplate) toolCalls = append(toolCalls, functionCallTemplate)
@@ -344,22 +344,22 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
// Set content and reasoning content if found // Set content and reasoning content if found
if contentText != "" { if contentText != "" {
template, _ = sjson.Set(template, "choices.0.message.content", contentText) template, _ = sjson.SetBytes(template, "choices.0.message.content", contentText)
template, _ = sjson.Set(template, "choices.0.message.role", "assistant") template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
} }
if reasoningText != "" { if reasoningText != "" {
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningText) template, _ = sjson.SetBytes(template, "choices.0.message.reasoning_content", reasoningText)
template, _ = sjson.Set(template, "choices.0.message.role", "assistant") template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
} }
// Add tool calls if any // Add tool calls if any
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`) template, _ = sjson.SetRawBytes(template, "choices.0.message.tool_calls", []byte(`[]`))
for _, toolCall := range toolCalls { for _, toolCall := range toolCalls {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", toolCall) template, _ = sjson.SetRawBytes(template, "choices.0.message.tool_calls.-1", toolCall)
} }
template, _ = sjson.Set(template, "choices.0.message.role", "assistant") template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
} }
} }
@@ -367,8 +367,8 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
if statusResult := responseResult.Get("status"); statusResult.Exists() { if statusResult := responseResult.Get("status"); statusResult.Exists() {
status := statusResult.String() status := statusResult.String()
if status == "completed" { if status == "completed" {
template, _ = sjson.Set(template, "choices.0.finish_reason", "stop") template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "stop")
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "stop") template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "stop")
} }
} }

View File

@@ -23,7 +23,7 @@ func TestConvertCodexResponseToOpenAI_StreamSetsModelFromResponseCreated(t *test
t.Fatalf("expected 1 chunk, got %d", len(out)) t.Fatalf("expected 1 chunk, got %d", len(out))
} }
gotModel := gjson.Get(out[0], "model").String() gotModel := gjson.GetBytes(out[0], "model").String()
if gotModel != modelName { if gotModel != modelName {
t.Fatalf("expected model %q, got %q", modelName, gotModel) t.Fatalf("expected model %q, got %q", modelName, gotModel)
} }
@@ -40,8 +40,53 @@ func TestConvertCodexResponseToOpenAI_FirstChunkUsesRequestModelName(t *testing.
t.Fatalf("expected 1 chunk, got %d", len(out)) t.Fatalf("expected 1 chunk, got %d", len(out))
} }
gotModel := gjson.Get(out[0], "model").String() gotModel := gjson.GetBytes(out[0], "model").String()
if gotModel != modelName { if gotModel != modelName {
t.Fatalf("expected model %q, got %q", modelName, gotModel) t.Fatalf("expected model %q, got %q", modelName, gotModel)
} }
} }
func TestConvertCodexResponseToOpenAI_ToolCallChunkOmitsNullContentFields(t *testing.T) {
ctx := context.Background()
var param any
out := ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_123","name":"websearch"}}`), &param)
if len(out) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(out))
}
if gjson.GetBytes(out[0], "choices.0.delta.content").Exists() {
t.Fatalf("expected content to be omitted, got %s", string(out[0]))
}
if gjson.GetBytes(out[0], "choices.0.delta.reasoning_content").Exists() {
t.Fatalf("expected reasoning_content to be omitted, got %s", string(out[0]))
}
if !gjson.GetBytes(out[0], "choices.0.delta.tool_calls").Exists() {
t.Fatalf("expected tool_calls to exist, got %s", string(out[0]))
}
}
func TestConvertCodexResponseToOpenAI_ToolCallArgumentsDeltaOmitsNullContentFields(t *testing.T) {
ctx := context.Background()
var param any
out := ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_123","name":"websearch"}}`), &param)
if len(out) != 1 {
t.Fatalf("expected tool call announcement chunk, got %d", len(out))
}
out = ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.function_call_arguments.delta","delta":"{\"query\":\"OpenAI\"}"}`), &param)
if len(out) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(out))
}
if gjson.GetBytes(out[0], "choices.0.delta.content").Exists() {
t.Fatalf("expected content to be omitted, got %s", string(out[0]))
}
if gjson.GetBytes(out[0], "choices.0.delta.reasoning_content").Exists() {
t.Fatalf("expected reasoning_content to be omitted, got %s", string(out[0]))
}
if !gjson.GetBytes(out[0], "choices.0.delta.tool_calls.0.function.arguments").Exists() {
t.Fatalf("expected tool call arguments delta to exist, got %s", string(out[0]))
}
}

View File

@@ -3,6 +3,7 @@ package responses
import ( import (
"fmt" "fmt"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -12,8 +13,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
inputResult := gjson.GetBytes(rawJSON, "input") inputResult := gjson.GetBytes(rawJSON, "input")
if inputResult.Type == gjson.String { if inputResult.Type == gjson.String {
input, _ := sjson.Set(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`, "0.content.0.text", inputResult.String()) input, _ := sjson.SetBytes([]byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`), "0.content.0.text", inputResult.String())
rawJSON, _ = sjson.SetRawBytes(rawJSON, "input", []byte(input)) rawJSON, _ = sjson.SetRawBytes(rawJSON, "input", input)
} }
rawJSON, _ = sjson.SetBytes(rawJSON, "stream", true) rawJSON, _ = sjson.SetBytes(rawJSON, "stream", true)
@@ -39,6 +40,7 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
// Convert role "system" to "developer" in input array to comply with Codex API requirements. // Convert role "system" to "developer" in input array to comply with Codex API requirements.
rawJSON = convertSystemRoleToDeveloper(rawJSON) rawJSON = convertSystemRoleToDeveloper(rawJSON)
rawJSON = normalizeCodexBuiltinTools(rawJSON)
return rawJSON return rawJSON
} }
@@ -82,3 +84,59 @@ func convertSystemRoleToDeveloper(rawJSON []byte) []byte {
return result return result
} }
// normalizeCodexBuiltinTools rewrites legacy/preview built-in tool variants to the
// stable names expected by the current Codex upstream.
func normalizeCodexBuiltinTools(rawJSON []byte) []byte {
result := rawJSON
tools := gjson.GetBytes(result, "tools")
if tools.IsArray() {
toolArray := tools.Array()
for i := 0; i < len(toolArray); i++ {
typePath := fmt.Sprintf("tools.%d.type", i)
result = normalizeCodexBuiltinToolAtPath(result, typePath)
}
}
result = normalizeCodexBuiltinToolAtPath(result, "tool_choice.type")
toolChoiceTools := gjson.GetBytes(result, "tool_choice.tools")
if toolChoiceTools.IsArray() {
toolArray := toolChoiceTools.Array()
for i := 0; i < len(toolArray); i++ {
typePath := fmt.Sprintf("tool_choice.tools.%d.type", i)
result = normalizeCodexBuiltinToolAtPath(result, typePath)
}
}
return result
}
func normalizeCodexBuiltinToolAtPath(rawJSON []byte, path string) []byte {
currentType := gjson.GetBytes(rawJSON, path).String()
normalizedType := normalizeCodexBuiltinToolType(currentType)
if normalizedType == "" {
return rawJSON
}
updated, err := sjson.SetBytes(rawJSON, path, normalizedType)
if err != nil {
return rawJSON
}
log.Debugf("codex responses: normalized builtin tool type at %s from %q to %q", path, currentType, normalizedType)
return updated
}
// normalizeCodexBuiltinToolType centralizes the current known Codex Responses
// built-in tool alias compatibility. If Codex introduces more legacy aliases,
// extend this helper instead of adding path-specific rewrite logic elsewhere.
func normalizeCodexBuiltinToolType(toolType string) string {
switch toolType {
case "web_search_preview", "web_search_preview_2025_03_11":
return "web_search"
default:
return ""
}
}

View File

@@ -264,6 +264,52 @@ func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) {
} }
} }
func TestConvertOpenAIResponsesRequestToCodex_NormalizesWebSearchPreview(t *testing.T) {
inputJSON := []byte(`{
"model": "gpt-5.4-mini",
"input": "find latest OpenAI model news",
"tools": [
{"type": "web_search_preview_2025_03_11"}
],
"tool_choice": {
"type": "allowed_tools",
"tools": [
{"type": "web_search_preview"},
{"type": "web_search_preview_2025_03_11"}
]
}
}`)
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.4-mini", inputJSON, false)
if got := gjson.GetBytes(output, "tools.0.type").String(); got != "web_search" {
t.Fatalf("tools.0.type = %q, want %q: %s", got, "web_search", string(output))
}
if got := gjson.GetBytes(output, "tool_choice.type").String(); got != "allowed_tools" {
t.Fatalf("tool_choice.type = %q, want %q: %s", got, "allowed_tools", string(output))
}
if got := gjson.GetBytes(output, "tool_choice.tools.0.type").String(); got != "web_search" {
t.Fatalf("tool_choice.tools.0.type = %q, want %q: %s", got, "web_search", string(output))
}
if got := gjson.GetBytes(output, "tool_choice.tools.1.type").String(); got != "web_search" {
t.Fatalf("tool_choice.tools.1.type = %q, want %q: %s", got, "web_search", string(output))
}
}
func TestConvertOpenAIResponsesRequestToCodex_NormalizesTopLevelToolChoicePreviewAlias(t *testing.T) {
inputJSON := []byte(`{
"model": "gpt-5.4-mini",
"input": "find latest OpenAI model news",
"tool_choice": {"type": "web_search_preview_2025_03_11"}
}`)
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.4-mini", inputJSON, false)
if got := gjson.GetBytes(output, "tool_choice.type").String(); got != "web_search" {
t.Fatalf("tool_choice.type = %q, want %q: %s", got, "web_search", string(output))
}
}
func TestUserFieldDeletion(t *testing.T) { func TestUserFieldDeletion(t *testing.T) {
inputJSON := []byte(`{ inputJSON := []byte(`{
"model": "gpt-5.2", "model": "gpt-5.2",

View File

@@ -3,7 +3,6 @@ package responses
import ( import (
"bytes" "bytes"
"context" "context"
"fmt"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -11,23 +10,25 @@ import (
// ConvertCodexResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks // ConvertCodexResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
// to OpenAI Responses SSE events (response.*). // to OpenAI Responses SSE events (response.*).
func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []string { func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) [][]byte {
if bytes.HasPrefix(rawJSON, []byte("data:")) { if bytes.HasPrefix(rawJSON, []byte("data:")) {
rawJSON = bytes.TrimSpace(rawJSON[5:]) rawJSON = bytes.TrimSpace(rawJSON[5:])
out := fmt.Sprintf("data: %s", string(rawJSON)) out := make([]byte, 0, len(rawJSON)+len("data: "))
return []string{out} out = append(out, []byte("data: ")...)
out = append(out, rawJSON...)
return [][]byte{out}
} }
return []string{string(rawJSON)} return [][]byte{rawJSON}
} }
// ConvertCodexResponseToOpenAIResponsesNonStream builds a single Responses JSON // ConvertCodexResponseToOpenAIResponsesNonStream builds a single Responses JSON
// from a non-streaming OpenAI Chat Completions response. // from a non-streaming OpenAI Chat Completions response.
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) string { func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []byte {
rootResult := gjson.ParseBytes(rawJSON) rootResult := gjson.ParseBytes(rawJSON)
// Verify this is a response.completed event // Verify this is a response.completed event
if rootResult.Get("type").String() != "response.completed" { if rootResult.Get("type").String() != "response.completed" {
return "" return []byte{}
} }
responseResult := rootResult.Get("response") responseResult := rootResult.Get("response")
return responseResult.Raw return []byte(responseResult.Raw)
} }

View File

@@ -0,0 +1,67 @@
package common
import (
"strconv"
"github.com/tidwall/sjson"
)
func WrapGeminiCLIResponse(response []byte) []byte {
out, err := sjson.SetRawBytes([]byte(`{"response":{}}`), "response", response)
if err != nil {
return response
}
return out
}
func GeminiTokenCountJSON(count int64) []byte {
out := make([]byte, 0, 96)
out = append(out, `{"totalTokens":`...)
out = strconv.AppendInt(out, count, 10)
out = append(out, `,"promptTokensDetails":[{"modality":"TEXT","tokenCount":`...)
out = strconv.AppendInt(out, count, 10)
out = append(out, `}]}`...)
return out
}
func ClaudeInputTokensJSON(count int64) []byte {
out := make([]byte, 0, 32)
out = append(out, `{"input_tokens":`...)
out = strconv.AppendInt(out, count, 10)
out = append(out, '}')
return out
}
func SSEEventData(event string, payload []byte) []byte {
out := make([]byte, 0, len(event)+len(payload)+14)
out = append(out, "event: "...)
out = append(out, event...)
out = append(out, '\n')
out = append(out, "data: "...)
out = append(out, payload...)
return out
}
func AppendSSEEventString(out []byte, event, payload string, trailingNewlines int) []byte {
out = append(out, "event: "...)
out = append(out, event...)
out = append(out, '\n')
out = append(out, "data: "...)
out = append(out, payload...)
for i := 0; i < trailingNewlines; i++ {
out = append(out, '\n')
}
return out
}
func AppendSSEEventBytes(out []byte, event string, payload []byte, trailingNewlines int) []byte {
out = append(out, "event: "...)
out = append(out, event...)
out = append(out, '\n')
out = append(out, "data: "...)
out = append(out, payload...)
for i := 0; i < trailingNewlines; i++ {
out = append(out, '\n')
}
return out
}

View File

@@ -6,10 +6,10 @@
package claude package claude
import ( import (
"bytes"
"strings" "strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -36,33 +36,32 @@ const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
// - []byte: The transformed request data in Gemini CLI API format // - []byte: The transformed request data in Gemini CLI API format
func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []byte { func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON rawJSON := inputRawJSON
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// Build output Gemini CLI request JSON // Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}` out := []byte(`{"model":"","request":{"contents":[]}}`)
out, _ = sjson.Set(out, "model", modelName) out, _ = sjson.SetBytes(out, "model", modelName)
// system instruction // system instruction
if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() { if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() {
systemInstruction := `{"role":"user","parts":[]}` systemInstruction := []byte(`{"role":"user","parts":[]}`)
hasSystemParts := false hasSystemParts := false
systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool { systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool {
if systemPromptResult.Get("type").String() == "text" { if systemPromptResult.Get("type").String() == "text" {
textResult := systemPromptResult.Get("text") textResult := systemPromptResult.Get("text")
if textResult.Type == gjson.String { if textResult.Type == gjson.String {
part := `{"text":""}` part := []byte(`{"text":""}`)
part, _ = sjson.Set(part, "text", textResult.String()) part, _ = sjson.SetBytes(part, "text", textResult.String())
systemInstruction, _ = sjson.SetRaw(systemInstruction, "parts.-1", part) systemInstruction, _ = sjson.SetRawBytes(systemInstruction, "parts.-1", part)
hasSystemParts = true hasSystemParts = true
} }
} }
return true return true
}) })
if hasSystemParts { if hasSystemParts {
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstruction) out, _ = sjson.SetRawBytes(out, "request.systemInstruction", systemInstruction)
} }
} else if systemResult.Type == gjson.String { } else if systemResult.Type == gjson.String {
out, _ = sjson.Set(out, "request.systemInstruction.parts.-1.text", systemResult.String()) out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.-1.text", systemResult.String())
} }
// contents // contents
@@ -77,28 +76,28 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
role = "model" role = "model"
} }
contentJSON := `{"role":"","parts":[]}` contentJSON := []byte(`{"role":"","parts":[]}`)
contentJSON, _ = sjson.Set(contentJSON, "role", role) contentJSON, _ = sjson.SetBytes(contentJSON, "role", role)
contentsResult := messageResult.Get("content") contentsResult := messageResult.Get("content")
if contentsResult.IsArray() { if contentsResult.IsArray() {
contentsResult.ForEach(func(_, contentResult gjson.Result) bool { contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
switch contentResult.Get("type").String() { switch contentResult.Get("type").String() {
case "text": case "text":
part := `{"text":""}` part := []byte(`{"text":""}`)
part, _ = sjson.Set(part, "text", contentResult.Get("text").String()) part, _ = sjson.SetBytes(part, "text", contentResult.Get("text").String())
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part) contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
case "tool_use": case "tool_use":
functionName := contentResult.Get("name").String() functionName := util.SanitizeFunctionName(contentResult.Get("name").String())
functionArgs := contentResult.Get("input").String() functionArgs := contentResult.Get("input").String()
argsResult := gjson.Parse(functionArgs) argsResult := gjson.Parse(functionArgs)
if argsResult.IsObject() && gjson.Valid(functionArgs) { if argsResult.IsObject() && gjson.Valid(functionArgs) {
part := `{"thoughtSignature":"","functionCall":{"name":"","args":{}}}` part := []byte(`{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`)
part, _ = sjson.Set(part, "thoughtSignature", geminiCLIClaudeThoughtSignature) part, _ = sjson.SetBytes(part, "thoughtSignature", geminiCLIClaudeThoughtSignature)
part, _ = sjson.Set(part, "functionCall.name", functionName) part, _ = sjson.SetBytes(part, "functionCall.name", functionName)
part, _ = sjson.SetRaw(part, "functionCall.args", functionArgs) part, _ = sjson.SetRawBytes(part, "functionCall.args", []byte(functionArgs))
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part) contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
} }
case "tool_result": case "tool_result":
@@ -112,10 +111,10 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-") funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
} }
responseData := contentResult.Get("content").Raw responseData := contentResult.Get("content").Raw
part := `{"functionResponse":{"name":"","response":{"result":""}}}` part := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
part, _ = sjson.Set(part, "functionResponse.name", funcName) part, _ = sjson.SetBytes(part, "functionResponse.name", util.SanitizeFunctionName(funcName))
part, _ = sjson.Set(part, "functionResponse.response.result", responseData) part, _ = sjson.SetBytes(part, "functionResponse.response.result", responseData)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part) contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
case "image": case "image":
source := contentResult.Get("source") source := contentResult.Get("source")
@@ -123,21 +122,21 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
mimeType := source.Get("media_type").String() mimeType := source.Get("media_type").String()
data := source.Get("data").String() data := source.Get("data").String()
if mimeType != "" && data != "" { if mimeType != "" && data != "" {
part := `{"inlineData":{"mime_type":"","data":""}}` part := []byte(`{"inlineData":{"mime_type":"","data":""}}`)
part, _ = sjson.Set(part, "inlineData.mime_type", mimeType) part, _ = sjson.SetBytes(part, "inlineData.mime_type", mimeType)
part, _ = sjson.Set(part, "inlineData.data", data) part, _ = sjson.SetBytes(part, "inlineData.data", data)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part) contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
} }
} }
} }
return true return true
}) })
out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON) out, _ = sjson.SetRawBytes(out, "request.contents.-1", contentJSON)
} else if contentsResult.Type == gjson.String { } else if contentsResult.Type == gjson.String {
part := `{"text":""}` part := []byte(`{"text":""}`)
part, _ = sjson.Set(part, "text", contentsResult.String()) part, _ = sjson.SetBytes(part, "text", contentsResult.String())
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part) contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON) out, _ = sjson.SetRawBytes(out, "request.contents.-1", contentJSON)
} }
return true return true
}) })
@@ -149,26 +148,28 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
toolsResult.ForEach(func(_, toolResult gjson.Result) bool { toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
inputSchemaResult := toolResult.Get("input_schema") inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw)
tool, _ := sjson.Delete(toolResult.Raw, "input_schema") tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema) tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
tool, _ = sjson.Delete(tool, "strict") tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
tool, _ = sjson.Delete(tool, "input_examples") tool, _ = sjson.DeleteBytes(tool, "strict")
tool, _ = sjson.Delete(tool, "type") tool, _ = sjson.DeleteBytes(tool, "input_examples")
tool, _ = sjson.Delete(tool, "cache_control") tool, _ = sjson.DeleteBytes(tool, "type")
tool, _ = sjson.Delete(tool, "defer_loading") tool, _ = sjson.DeleteBytes(tool, "cache_control")
if gjson.Valid(tool) && gjson.Parse(tool).IsObject() { tool, _ = sjson.DeleteBytes(tool, "defer_loading")
tool, _ = sjson.DeleteBytes(tool, "eager_input_streaming")
if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() {
if !hasTools { if !hasTools {
out, _ = sjson.SetRaw(out, "request.tools", `[{"functionDeclarations":[]}]`) out, _ = sjson.SetRawBytes(out, "request.tools", []byte(`[{"functionDeclarations":[]}]`))
hasTools = true hasTools = true
} }
out, _ = sjson.SetRaw(out, "request.tools.0.functionDeclarations.-1", tool) out, _ = sjson.SetRawBytes(out, "request.tools.0.functionDeclarations.-1", tool)
} }
} }
return true return true
}) })
if !hasTools { if !hasTools {
out, _ = sjson.Delete(out, "request.tools") out, _ = sjson.DeleteBytes(out, "request.tools")
} }
} }
@@ -186,15 +187,15 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
switch toolChoiceType { switch toolChoiceType {
case "auto": case "auto":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO") out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
case "none": case "none":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE") out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
case "any": case "any":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY") out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
case "tool": case "tool":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY") out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
if toolChoiceName != "" { if toolChoiceName != "" {
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName}) out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)})
} }
} }
} }
@@ -206,8 +207,8 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
case "enabled": case "enabled":
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number { if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := int(b.Int()) budget := int(b.Int())
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
} }
case "adaptive", "auto": case "adaptive", "auto":
// For adaptive thinking: // For adaptive thinking:
@@ -219,25 +220,23 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
effort = strings.ToLower(strings.TrimSpace(v.String())) effort = strings.ToLower(strings.TrimSpace(v.String()))
} }
if effort != "" { if effort != "" {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
} else { } else {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high") out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
} }
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true) out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
} }
} }
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number { if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.temperature", v.Num) out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", v.Num)
} }
if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number { if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topP", v.Num) out, _ = sjson.SetBytes(out, "request.generationConfig.topP", v.Num)
} }
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number { if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num) out, _ = sjson.SetBytes(out, "request.generationConfig.topK", v.Num)
} }
outBytes := []byte(out) out = common.AttachDefaultSafetySettings(out, "request.safetySettings")
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings") return out
return outBytes
} }

View File

@@ -14,6 +14,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
@@ -27,6 +28,9 @@ type Params struct {
ResponseType int // Current response type: 0=none, 1=content, 2=thinking, 3=function ResponseType int // Current response type: 0=none, 1=content, 2=thinking, 3=function
ResponseIndex int // Index counter for content blocks in the streaming response ResponseIndex int // Index counter for content blocks in the streaming response
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
// Reverse map: sanitized Gemini function name → original Claude tool name.
ToolNameMap map[string]string
} }
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers. // toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
@@ -47,47 +51,47 @@ var toolUseIDCounter uint64
// - param: A pointer to a parameter object for maintaining state between calls // - param: A pointer to a parameter object for maintaining state between calls
// //
// Returns: // Returns:
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response // - [][]byte: A slice of bytes, each containing a Claude Code-compatible SSE payload.
func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil { if *param == nil {
*param = &Params{ *param = &Params{
HasFirstResponse: false, HasFirstResponse: false,
ResponseType: 0, ResponseType: 0,
ResponseIndex: 0, ResponseIndex: 0,
ToolNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
} }
} }
if bytes.Equal(rawJSON, []byte("[DONE]")) { if bytes.Equal(rawJSON, []byte("[DONE]")) {
// Only send message_stop if we have actually output content // Only send message_stop if we have actually output content
if (*param).(*Params).HasContent { if (*param).(*Params).HasContent {
return []string{ return [][]byte{translatorcommon.AppendSSEEventString(nil, "message_stop", `{"type":"message_stop"}`, 3)}
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n\n",
}
} }
return []string{} return [][]byte{}
} }
// Track whether tools are being used in this response chunk // Track whether tools are being used in this response chunk
usedTool := false usedTool := false
output := "" output := make([]byte, 0, 1024)
appendEvent := func(event, payload string) {
output = translatorcommon.AppendSSEEventString(output, event, payload, 3)
}
// Initialize the streaming session with a message_start event // Initialize the streaming session with a message_start event
// This is only sent for the very first response chunk to establish the streaming session // This is only sent for the very first response chunk to establish the streaming session
if !(*param).(*Params).HasFirstResponse { if !(*param).(*Params).HasFirstResponse {
output = "event: message_start\n"
// Create the initial message structure with default values according to Claude Code API specification // Create the initial message structure with default values according to Claude Code API specification
// This follows the Claude Code API specification for streaming message initialization // This follows the Claude Code API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}` messageStartTemplate := []byte(`{"type":"message_start","message":{"id":"msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY","type":"message","role":"assistant","content":[],"model":"claude-3-5-sonnet-20241022","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`)
// Override default values with actual response metadata if available from the Gemini CLI response // Override default values with actual response metadata if available from the Gemini CLI response
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String()) messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.model", modelVersionResult.String())
} }
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() { if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String()) messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.id", responseIDResult.String())
} }
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate) appendEvent("message_start", string(messageStartTemplate))
(*param).(*Params).HasFirstResponse = true (*param).(*Params).HasFirstResponse = true
} }
@@ -110,9 +114,8 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
if partResult.Get("thought").Bool() { if partResult.Get("thought").Bool() {
// Continue existing thinking block if already in thinking state // Continue existing thinking block if already in thinking state
if (*param).(*Params).ResponseType == 2 { if (*param).(*Params).ResponseType == 2 {
output = output + "event: content_block_delta\n" data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex)), "delta.thinking", partTextResult.String())
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String()) appendEvent("content_block_delta", string(data))
output = output + fmt.Sprintf("data: %s\n\n\n", data)
(*param).(*Params).HasContent = true (*param).(*Params).HasContent = true
} else { } else {
// Transition from another state to thinking // Transition from another state to thinking
@@ -123,19 +126,14 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex) // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex)
// output = output + "\n\n\n" // output = output + "\n\n\n"
} }
output = output + "event: content_block_stop\n" appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
(*param).(*Params).ResponseIndex++ (*param).(*Params).ResponseIndex++
} }
// Start a new thinking content block // Start a new thinking content block
output = output + "event: content_block_start\n" appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex))
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex) data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex)), "delta.thinking", partTextResult.String())
output = output + "\n\n\n" appendEvent("content_block_delta", string(data))
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
(*param).(*Params).ResponseType = 2 // Set state to thinking (*param).(*Params).ResponseType = 2 // Set state to thinking
(*param).(*Params).HasContent = true (*param).(*Params).HasContent = true
} }
@@ -143,9 +141,8 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Process regular text content (user-visible output) // Process regular text content (user-visible output)
// Continue existing text block if already in content state // Continue existing text block if already in content state
if (*param).(*Params).ResponseType == 1 { if (*param).(*Params).ResponseType == 1 {
output = output + "event: content_block_delta\n" data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex)), "delta.text", partTextResult.String())
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String()) appendEvent("content_block_delta", string(data))
output = output + fmt.Sprintf("data: %s\n\n\n", data)
(*param).(*Params).HasContent = true (*param).(*Params).HasContent = true
} else { } else {
// Transition from another state to text content // Transition from another state to text content
@@ -156,19 +153,14 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex) // output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex)
// output = output + "\n\n\n" // output = output + "\n\n\n"
} }
output = output + "event: content_block_stop\n" appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
(*param).(*Params).ResponseIndex++ (*param).(*Params).ResponseIndex++
} }
// Start a new text content block // Start a new text content block
output = output + "event: content_block_start\n" appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex))
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex) data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex)), "delta.text", partTextResult.String())
output = output + "\n\n\n" appendEvent("content_block_delta", string(data))
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
(*param).(*Params).ResponseType = 1 // Set state to content (*param).(*Params).ResponseType = 1 // Set state to content
(*param).(*Params).HasContent = true (*param).(*Params).HasContent = true
} }
@@ -177,14 +169,12 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Handle function/tool calls from the AI model // Handle function/tool calls from the AI model
// This processes tool usage requests and formats them for Claude Code API compatibility // This processes tool usage requests and formats them for Claude Code API compatibility
usedTool = true usedTool = true
fcName := functionCallResult.Get("name").String() fcName := util.RestoreSanitizedToolName((*param).(*Params).ToolNameMap, functionCallResult.Get("name").String())
// Handle state transitions when switching to function calls // Handle state transitions when switching to function calls
// Close any existing function call block first // Close any existing function call block first
if (*param).(*Params).ResponseType == 3 { if (*param).(*Params).ResponseType == 3 {
output = output + "event: content_block_stop\n" appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
(*param).(*Params).ResponseIndex++ (*param).(*Params).ResponseIndex++
(*param).(*Params).ResponseType = 0 (*param).(*Params).ResponseType = 0
} }
@@ -198,26 +188,21 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Close any other existing content block // Close any other existing content block
if (*param).(*Params).ResponseType != 0 { if (*param).(*Params).ResponseType != 0 {
output = output + "event: content_block_stop\n" appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
(*param).(*Params).ResponseIndex++ (*param).(*Params).ResponseIndex++
} }
// Start a new tool use content block // Start a new tool use content block
// This creates the structure for a function call in Claude Code format // This creates the structure for a function call in Claude Code format
output = output + "event: content_block_start\n"
// Create the tool use block with unique ID and function details // Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex) data := []byte(fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex))
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))) data, _ = sjson.SetBytes(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", fcName) data, _ = sjson.SetBytes(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data) appendEvent("content_block_start", string(data))
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n" data, _ = sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex)), "delta.partial_json", fcArgsResult.Raw)
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw) appendEvent("content_block_delta", string(data))
output = output + fmt.Sprintf("data: %s\n\n\n", data)
} }
(*param).(*Params).ResponseType = 3 (*param).(*Params).ResponseType = 3
(*param).(*Params).HasContent = true (*param).(*Params).HasContent = true
@@ -232,34 +217,28 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Only send final events if we have actually output content // Only send final events if we have actually output content
if (*param).(*Params).HasContent { if (*param).(*Params).HasContent {
// Close the final content block // Close the final content block
output = output + "event: content_block_stop\n" appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
// Send the final message delta with usage information and stop reason
output = output + "event: message_delta\n"
output = output + `data: `
// Create the message delta template with appropriate stop reason // Create the message delta template with appropriate stop reason
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` template := []byte(`{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
// Set tool_use stop reason if tools were used in this response // Set tool_use stop reason if tools were used in this response
if usedTool { if usedTool {
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
} else if finish := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finish.Exists() && finish.String() == "MAX_TOKENS" { } else if finish := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finish.Exists() && finish.String() == "MAX_TOKENS" {
template = `{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` template = []byte(`{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
} }
// Include thinking tokens in output token count if present // Include thinking tokens in output token count if present
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount) template, _ = sjson.SetBytes(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int()) template, _ = sjson.SetBytes(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
output = output + template + "\n\n\n" appendEvent("message_delta", string(template))
} }
} }
} }
return []string{output} return [][]byte{output}
} }
// ConvertGeminiCLIResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response. // ConvertGeminiCLIResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response.
@@ -271,21 +250,21 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// - param: A pointer to a parameter object for the conversion. // - param: A pointer to a parameter object for the conversion.
// //
// Returns: // Returns:
// - string: A Claude-compatible JSON response. // - []byte: A Claude-compatible JSON response.
func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
_ = originalRequestRawJSON toolNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
_ = requestRawJSON _ = requestRawJSON
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` out := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
out, _ = sjson.Set(out, "id", root.Get("response.responseId").String()) out, _ = sjson.SetBytes(out, "id", root.Get("response.responseId").String())
out, _ = sjson.Set(out, "model", root.Get("response.modelVersion").String()) out, _ = sjson.SetBytes(out, "model", root.Get("response.modelVersion").String())
inputTokens := root.Get("response.usageMetadata.promptTokenCount").Int() inputTokens := root.Get("response.usageMetadata.promptTokenCount").Int()
outputTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int() + root.Get("response.usageMetadata.thoughtsTokenCount").Int() outputTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int() + root.Get("response.usageMetadata.thoughtsTokenCount").Int()
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens) out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens) out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens)
parts := root.Get("response.candidates.0.content.parts") parts := root.Get("response.candidates.0.content.parts")
textBuilder := strings.Builder{} textBuilder := strings.Builder{}
@@ -297,9 +276,9 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
if textBuilder.Len() == 0 { if textBuilder.Len() == 0 {
return return
} }
block := `{"type":"text","text":""}` block := []byte(`{"type":"text","text":""}`)
block, _ = sjson.Set(block, "text", textBuilder.String()) block, _ = sjson.SetBytes(block, "text", textBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block) out, _ = sjson.SetRawBytes(out, "content.-1", block)
textBuilder.Reset() textBuilder.Reset()
} }
@@ -307,9 +286,9 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
if thinkingBuilder.Len() == 0 { if thinkingBuilder.Len() == 0 {
return return
} }
block := `{"type":"thinking","thinking":""}` block := []byte(`{"type":"thinking","thinking":""}`)
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String()) block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block) out, _ = sjson.SetRawBytes(out, "content.-1", block)
thinkingBuilder.Reset() thinkingBuilder.Reset()
} }
@@ -331,17 +310,17 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
flushText() flushText()
hasToolCall = true hasToolCall = true
name := functionCall.Get("name").String() name := util.RestoreSanitizedToolName(toolNameMap, functionCall.Get("name").String())
toolIDCounter++ toolIDCounter++
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}` toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter)) toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "name", name) toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
inputRaw := "{}" inputRaw := "{}"
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() { if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
inputRaw = args.Raw inputRaw = args.Raw
} }
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw) toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(inputRaw))
out, _ = sjson.SetRaw(out, "content.-1", toolBlock) out, _ = sjson.SetRawBytes(out, "content.-1", toolBlock)
continue continue
} }
} }
@@ -365,15 +344,15 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
} }
} }
} }
out, _ = sjson.Set(out, "stop_reason", stopReason) out, _ = sjson.SetBytes(out, "stop_reason", stopReason)
if inputTokens == int64(0) && outputTokens == int64(0) && !root.Get("response.usageMetadata").Exists() { if inputTokens == int64(0) && outputTokens == int64(0) && !root.Get("response.usageMetadata").Exists() {
out, _ = sjson.Delete(out, "usage") out, _ = sjson.DeleteBytes(out, "usage")
} }
return out return out
} }
func ClaudeTokenCount(ctx context.Context, count int64) string { func ClaudeTokenCount(ctx context.Context, count int64) []byte {
return fmt.Sprintf(`{"input_tokens":%d}`, count) return translatorcommon.ClaudeInputTokensJSON(count)
} }

View File

@@ -7,6 +7,7 @@ package gemini
import ( import (
"fmt" "fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -33,23 +34,23 @@ import (
// - []byte: The transformed request data in Gemini API format // - []byte: The transformed request data in Gemini API format
func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []byte { func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON rawJSON := inputRawJSON
template := "" template := []byte(`{"project":"","request":{},"model":""}`)
template = `{"project":"","request":{},"model":""}` template, _ = sjson.SetRawBytes(template, "request", rawJSON)
template, _ = sjson.SetRaw(template, "request", string(rawJSON)) template, _ = sjson.SetBytes(template, "model", gjson.GetBytes(template, "request.model").String())
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) template, _ = sjson.DeleteBytes(template, "request.model")
template, _ = sjson.Delete(template, "request.model")
template, errFixCLIToolResponse := fixCLIToolResponse(template) templateStr, errFixCLIToolResponse := fixCLIToolResponse(string(template))
if errFixCLIToolResponse != nil { if errFixCLIToolResponse != nil {
return []byte{} return []byte{}
} }
template = []byte(templateStr)
systemInstructionResult := gjson.Get(template, "request.system_instruction") systemInstructionResult := gjson.GetBytes(template, "request.system_instruction")
if systemInstructionResult.Exists() { if systemInstructionResult.Exists() {
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) template, _ = sjson.SetRawBytes(template, "request.systemInstruction", []byte(systemInstructionResult.Raw))
template, _ = sjson.Delete(template, "request.system_instruction") template, _ = sjson.DeleteBytes(template, "request.system_instruction")
} }
rawJSON = []byte(template) rawJSON = template
// Normalize roles in request.contents: default to valid values if missing/invalid // Normalize roles in request.contents: default to valid values if missing/invalid
contents := gjson.GetBytes(rawJSON, "request.contents") contents := gjson.GetBytes(rawJSON, "request.contents")
@@ -110,12 +111,41 @@ func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []by
return true return true
}) })
// Filter out contents with empty parts to avoid Gemini API error:
// "required oneof field 'data' must have one initialized field"
filteredContents := []byte(`[]`)
hasFiltered := false
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(_, content gjson.Result) bool {
parts := content.Get("parts")
if !parts.IsArray() || len(parts.Array()) == 0 {
hasFiltered = true
return true
}
filteredContents, _ = sjson.SetRawBytes(filteredContents, "-1", []byte(content.Raw))
return true
})
if hasFiltered {
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents", filteredContents)
}
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings") return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
} }
// FunctionCallGroup represents a group of function calls and their responses // FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct { type FunctionCallGroup struct {
ResponsesNeeded int ResponsesNeeded int
CallNames []string // ordered function call names for backfilling empty response names
}
// backfillFunctionResponseName ensures that a functionResponse JSON object has a non-empty name,
// falling back to fallbackName if the original is empty.
func backfillFunctionResponseName(raw string, fallbackName string) string {
name := gjson.Get(raw, "functionResponse.name").String()
if strings.TrimSpace(name) == "" && fallbackName != "" {
rawBytes, _ := sjson.SetBytes([]byte(raw), "functionResponse.name", fallbackName)
raw = string(rawBytes)
}
return raw
} }
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping. // fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
@@ -142,7 +172,7 @@ func fixCLIToolResponse(input string) (string, error) {
} }
// Initialize data structures for processing and grouping // Initialize data structures for processing and grouping
contentsWrapper := `{"contents":[]}` contentsWrapper := []byte(`{"contents":[]}`)
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched var collectedResponses []gjson.Result // Standalone responses to be matched
@@ -165,31 +195,28 @@ func fixCLIToolResponse(input string) (string, error) {
if len(responsePartsInThisContent) > 0 { if len(responsePartsInThisContent) > 0 {
collectedResponses = append(collectedResponses, responsePartsInThisContent...) collectedResponses = append(collectedResponses, responsePartsInThisContent...)
// Check if any pending groups can be satisfied // Check if pending groups can be satisfied (FIFO: oldest group first)
for i := len(pendingGroups) - 1; i >= 0; i-- { for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded {
group := pendingGroups[i] group := pendingGroups[0]
if len(collectedResponses) >= group.ResponsesNeeded { pendingGroups = pendingGroups[1:]
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content // Take the needed responses for this group
functionResponseContent := `{"parts":[],"role":"function"}` groupResponses := collectedResponses[:group.ResponsesNeeded]
for _, response := range groupResponses { collectedResponses = collectedResponses[group.ResponsesNeeded:]
if !response.IsObject() {
log.Warnf("failed to parse function response") // Create merged function response content
continue functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
} for ri, response := range groupResponses {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw) if !response.IsObject() {
log.Warnf("failed to parse function response")
continue
} }
raw := backfillFunctionResponseName(response.Raw, group.CallNames[ri])
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(raw))
}
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 { if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent) contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
}
// Remove this group as it's been satisfied
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
break
} }
} }
@@ -198,25 +225,26 @@ func fixCLIToolResponse(input string) (string, error) {
// If this is a model with function calls, create a new group // If this is a model with function calls, create a new group
if role == "model" { if role == "model" {
functionCallsCount := 0 var callNames []string
parts.ForEach(func(_, part gjson.Result) bool { parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() { if part.Get("functionCall").Exists() {
functionCallsCount++ callNames = append(callNames, part.Get("functionCall.name").String())
} }
return true return true
}) })
if functionCallsCount > 0 { if len(callNames) > 0 {
// Add the model content // Add the model content
if !value.IsObject() { if !value.IsObject() {
log.Warnf("failed to parse model content") log.Warnf("failed to parse model content")
return true return true
} }
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw) contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
// Create a new group for tracking responses // Create a new group for tracking responses
group := &FunctionCallGroup{ group := &FunctionCallGroup{
ResponsesNeeded: functionCallsCount, ResponsesNeeded: len(callNames),
CallNames: callNames,
} }
pendingGroups = append(pendingGroups, group) pendingGroups = append(pendingGroups, group)
} else { } else {
@@ -225,7 +253,7 @@ func fixCLIToolResponse(input string) (string, error) {
log.Warnf("failed to parse content") log.Warnf("failed to parse content")
return true return true
} }
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw) contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
} }
} else { } else {
// Non-model content (user, etc.) // Non-model content (user, etc.)
@@ -233,7 +261,7 @@ func fixCLIToolResponse(input string) (string, error) {
log.Warnf("failed to parse content") log.Warnf("failed to parse content")
return true return true
} }
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw) contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
} }
return true return true
@@ -245,24 +273,25 @@ func fixCLIToolResponse(input string) (string, error) {
groupResponses := collectedResponses[:group.ResponsesNeeded] groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:] collectedResponses = collectedResponses[group.ResponsesNeeded:]
functionResponseContent := `{"parts":[],"role":"function"}` functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
for _, response := range groupResponses { for ri, response := range groupResponses {
if !response.IsObject() { if !response.IsObject() {
log.Warnf("failed to parse function response") log.Warnf("failed to parse function response")
continue continue
} }
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw) raw := backfillFunctionResponseName(response.Raw, group.CallNames[ri])
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(raw))
} }
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 { if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent) contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
} }
} }
} }
// Update the original JSON with the new contents // Update the original JSON with the new contents
result := input result := []byte(input)
result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw) result, _ = sjson.SetRawBytes(result, "request.contents", []byte(gjson.GetBytes(contentsWrapper, "contents").Raw))
return result, nil return string(result), nil
} }

View File

@@ -8,8 +8,8 @@ package gemini
import ( import (
"bytes" "bytes"
"context" "context"
"fmt"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -29,8 +29,8 @@ import (
// - param: A pointer to a parameter object for the conversion (unused in current implementation) // - param: A pointer to a parameter object for the conversion (unused in current implementation)
// //
// Returns: // Returns:
// - []string: The transformed request data in Gemini API format // - [][]byte: The transformed request data in Gemini API format
func ConvertGeminiCliResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string { func ConvertGeminiCliResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) [][]byte {
if bytes.HasPrefix(rawJSON, []byte("data:")) { if bytes.HasPrefix(rawJSON, []byte("data:")) {
rawJSON = bytes.TrimSpace(rawJSON[5:]) rawJSON = bytes.TrimSpace(rawJSON[5:])
} }
@@ -43,22 +43,22 @@ func ConvertGeminiCliResponseToGemini(ctx context.Context, _ string, originalReq
chunk = []byte(responseResult.Raw) chunk = []byte(responseResult.Raw)
} }
} else { } else {
chunkTemplate := "[]" chunkTemplate := []byte(`[]`)
responseResult := gjson.ParseBytes(chunk) responseResult := gjson.ParseBytes(chunk)
if responseResult.IsArray() { if responseResult.IsArray() {
responseResultItems := responseResult.Array() responseResultItems := responseResult.Array()
for i := 0; i < len(responseResultItems); i++ { for i := 0; i < len(responseResultItems); i++ {
responseResultItem := responseResultItems[i] responseResultItem := responseResultItems[i]
if responseResultItem.Get("response").Exists() { if responseResultItem.Get("response").Exists() {
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw) chunkTemplate, _ = sjson.SetRawBytes(chunkTemplate, "-1", []byte(responseResultItem.Get("response").Raw))
} }
} }
} }
chunk = []byte(chunkTemplate) chunk = chunkTemplate
} }
return []string{string(chunk)} return [][]byte{chunk}
} }
return []string{} return [][]byte{}
} }
// ConvertGeminiCliResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response. // ConvertGeminiCliResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response.
@@ -72,15 +72,15 @@ func ConvertGeminiCliResponseToGemini(ctx context.Context, _ string, originalReq
// - param: A pointer to a parameter object for the conversion (unused in current implementation) // - param: A pointer to a parameter object for the conversion (unused in current implementation)
// //
// Returns: // Returns:
// - string: A Gemini-compatible JSON response containing the response data // - []byte: A Gemini-compatible JSON response containing the response data
func ConvertGeminiCliResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { func ConvertGeminiCliResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
responseResult := gjson.GetBytes(rawJSON, "response") responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() { if responseResult.Exists() {
return responseResult.Raw return []byte(responseResult.Raw)
} }
return string(rawJSON) return rawJSON
} }
func GeminiTokenCount(ctx context.Context, count int64) string { func GeminiTokenCount(ctx context.Context, count int64) []byte {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count) return translatorcommon.GeminiTokenCountJSON(count)
} }

View File

@@ -251,7 +251,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
continue continue
} }
fid := tc.Get("id").String() fid := tc.Get("id").String()
fname := tc.Get("function.name").String() fname := util.SanitizeFunctionName(tc.Get("function.name").String())
fargs := tc.Get("function.arguments").String() fargs := tc.Get("function.arguments").String()
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs)) node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
@@ -268,7 +268,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
pp := 0 pp := 0
for _, fid := range fIDs { for _, fid := range fIDs {
if name, ok := tcID2Name[fid]; ok { if name, ok := tcID2Name[fid]; ok {
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name) toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", util.SanitizeFunctionName(name))
resp := toolResponses[fid] resp := toolResponses[fid]
if resp == "" { if resp == "" {
resp = "{}" resp = "{}"
@@ -299,43 +299,44 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if t.Get("type").String() == "function" { if t.Get("type").String() == "function" {
fn := t.Get("function") fn := t.Get("function")
if fn.Exists() && fn.IsObject() { if fn.Exists() && fn.IsObject() {
fnRaw := fn.Raw fnRaw := []byte(fn.Raw)
if fn.Get("parameters").Exists() { if fn.Get("parameters").Exists() {
renamed, errRename := util.RenameKey(fnRaw, "parameters", "parametersJsonSchema") renamed, errRename := util.RenameKey(fn.Raw, "parameters", "parametersJsonSchema")
if errRename != nil { if errRename != nil {
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename) log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
var errSet error var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object") fnRaw, errSet = sjson.SetBytes(fnRaw, "parametersJsonSchema.type", "object")
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`) fnRaw, errSet = sjson.SetRawBytes(fnRaw, "parametersJsonSchema.properties", []byte(`{}`))
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
} else { } else {
fnRaw = renamed fnRaw = []byte(renamed)
} }
} else { } else {
var errSet error var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object") fnRaw, errSet = sjson.SetBytes(fnRaw, "parametersJsonSchema.type", "object")
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`) fnRaw, errSet = sjson.SetRawBytes(fnRaw, "parametersJsonSchema.properties", []byte(`{}`))
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
} }
fnRaw, _ = sjson.Delete(fnRaw, "strict") fnRaw, _ = sjson.SetBytes(fnRaw, "name", util.SanitizeFunctionName(fn.Get("name").String()))
fnRaw, _ = sjson.DeleteBytes(fnRaw, "strict")
if !hasFunction { if !hasFunction {
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]")) functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
} }
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw)) tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", fnRaw)
if errSet != nil { if errSet != nil {
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
continue continue

Some files were not shown because too many files have changed in this diff Show More