Compare commits

..

153 Commits

Author SHA1 Message Date
Luis Pater
f8d1bc06ea Merge pull request #469 from router-for-me/plus
v6.9.5
2026-03-29 12:40:26 +08:00
Luis Pater
d5930f4e44 Merge branch 'main' into plus 2026-03-29 12:40:17 +08:00
Luis Pater
9b7d7021af docs(readme): update LingtrueAPI link in all README translations 2026-03-29 12:30:24 +08:00
Luis Pater
e41c22ef44 docs(readme): add LingtrueAPI sponsorship details to all README translations 2026-03-29 12:23:37 +08:00
Luis Pater
55271403fb Merge pull request #2374 from VooDisss/codex-cache-clean
fix(codex): restore prompt cache continuity for Codex requests
2026-03-28 21:16:51 +08:00
Luis Pater
36fba66619 Merge pull request #2371 from RaviTharuma/docs/provider-specific-routes
docs: clarify provider-specific routing for aliased models
2026-03-28 21:11:29 +08:00
Luis Pater
b9b127a7ea Merge pull request #2347 from edlsh/fix/codex-strip-stream-options
fix(codex): strip stream_options from Responses API requests
2026-03-28 21:03:01 +08:00
Luis Pater
2741e7b7b3 Merge pull request #2346 from pjpjq/codex/fix-codex-capacity-retry
fix(codex): Treat Codex capacity errors as retryable
2026-03-28 21:00:50 +08:00
Luis Pater
1767a56d4f Merge pull request #2343 from kongkk233/fix/proxy-transport-defaults
Preserve default transport settings for proxy clients
2026-03-28 20:58:24 +08:00
Luis Pater
779e6c2d2f Merge pull request #2231 from 7RPH/fix/responses-stream-multi-tool-calls
fix: preserve separate streamed tool calls in Responses API
2026-03-28 20:53:19 +08:00
Luis Pater
73c831747b Merge pull request #2133 from DragonFSKY/fix/2061-stale-modelstates
fix(auth): prevent stale runtime state inheritance from disabled auth entries
2026-03-28 20:50:57 +08:00
Luis Pater
b8b89f34f4 Merge pull request #442 from LuxVTZ/feat/gitlab-duo-panel-parity
Improve GitLab Duo gateway compatibility\n\nRestore internal/runtime/executor/claude_executor.go to main during merge.
2026-03-28 05:06:41 +08:00
Luis Pater
1fa094dac6 Merge pull request #461 from MrHuangJser/main
feat(cursor): Full Cursor provider with H2 streaming, MCP tools, multi-turn & multi-account
2026-03-28 05:01:27 +08:00
Luis Pater
f55754621f Merge pull request #464 from router-for-me/plus
v6.9.4
2026-03-28 04:51:27 +08:00
Luis Pater
ac26e7db43 Merge branch 'main' into plus 2026-03-28 04:51:18 +08:00
Luis Pater
10b824fcac fix(security): validate auth file names to prevent unsafe input 2026-03-28 04:48:23 +08:00
VooDisss
e5d3541b5a refactor(codex): remove stale affinity cleanup leftovers
Drop the last affinity-related executor artifacts so the PR stays focused on the minimal Codex continuity fix set: stable prompt cache identity, stable session_id, and the executor-only behavior that was validated to restore cache reads.
2026-03-27 20:40:26 +02:00
VooDisss
79755e76ea refactor(pr): remove forbidden translator changes
Drop the chat-completions translator edits from this PR so the branch complies with the repository policy that forbids pull-request changes under internal/translator. The remaining PR stays focused on the executor-level Codex continuity fix that was validated to restore cache reuse.
2026-03-27 19:34:13 +02:00
VooDisss
35f158d526 refactor(pr): narrow Codex cache fix scope
Remove the experimental auth-affinity routing changes from this PR so it stays focused on the validated Codex continuity fix. This keeps the prompt-cache repair while avoiding unrelated routing-policy concerns such as provider/model affinity scope, lifecycle cleanup, and hard-pin fallback semantics.
2026-03-27 19:06:34 +02:00
VooDisss
6962e09dd9 fix(auth): scope affinity by provider
Keep sticky auth affinity limited to matching providers and stop persisting execution-session IDs as long-lived affinity keys so provider switching and normal streaming traffic do not create incorrect pins or stale affinity state.
2026-03-27 18:52:58 +02:00
VooDisss
4c4cbd44da fix(auth): avoid leaking or over-persisting affinity keys
Stop using one-shot idempotency keys as long-lived auth-affinity identifiers and remove raw affinity-key values from debug logs so sticky routing keeps its continuity benefits without creating avoidable memory growth or credential exposure risks.
2026-03-27 18:34:51 +02:00
VooDisss
26eca8b6ba fix(codex): preserve continuity and safe affinity fallback
Restore Claude continuity after the continuity refactor, keep auth-affinity keys out of upstream Codex session identifiers, and only persist affinity after successful execution so retries can still rotate to healthy credentials when the first auth fails.
2026-03-27 18:27:33 +02:00
VooDisss
62b17f40a1 refactor(codex): align continuity helpers with review feedback
Align websocket continuity resolution with the HTTP Codex path, make auth-affinity principal keys use a stable string representation, and extract small helpers that remove duplicated continuity and affinity logic without changing the validated cache-hit behavior.
2026-03-27 18:11:57 +02:00
VooDisss
511b8a992e fix(codex): restore prompt cache continuity for Codex requests
Prompt caching on Codex was not reliably reusable through the proxy because repeated chat-completions requests could reach the upstream without the same continuity envelope. In practice this showed up most clearly with OpenCode, where cache reads worked in the reference client but not through CLIProxyAPI, although the root cause is broader than OpenCode itself.

The proxy was breaking continuity in several ways: executor-layer Codex request preparation stripped prompt_cache_retention, chat-completions translation did not preserve that field, continuity headers used a different shape than the working client behavior, and OpenAI-style Codex requests could be sent without a stable prompt_cache_key. When that happened, session_id fell back to a fresh random value per request, so upstream Codex treated repeated requests as unrelated turns instead of as part of the same cacheable context.

This change fixes that by preserving caller-provided prompt_cache_retention on Codex execution paths, preserving prompt_cache_retention when translating OpenAI chat-completions requests to Codex, aligning Codex continuity headers to session_id, and introducing an explicit Codex continuity policy that derives a stable continuity key from the best available signal. The resolution order prefers an explicit prompt_cache_key, then execution session metadata, then an explicit idempotency key, then stable request-affinity metadata, then a stable client-principal hash, and finally a stable auth-ID hash when no better continuity signal exists.

The same continuity key is applied to both prompt_cache_key in the request body and session_id in the request headers so repeated requests reuse the same upstream cache/session identity. The auth manager also keeps auth selection sticky for repeated request sequences, preventing otherwise-equivalent Codex requests from drifting across different upstream auth contexts and accidentally breaking cache reuse.

To keep the implementation maintainable, the continuity resolution and diagnostics are centralized in a dedicated Codex continuity helper instead of being scattered across executor flow code. Regression coverage now verifies retention preservation, continuity-key precedence, stable auth-ID fallback, websocket parity, translator preservation, and auth-affinity behavior. Manual validation confirmed prompt cache reads now occur through CLIProxyAPI when using Codex via OpenCode, and the fix should also benefit other clients that rely on stable repeated Codex request continuity.
2026-03-27 17:49:29 +02:00
Luis Pater
7dccc7ba2f docs(readme): remove redundant whitespace in BmoPlus sponsorship section of Chinese README 2026-03-27 20:52:14 +08:00
Luis Pater
70c90687fd docs(readme): fix formatting in BmoPlus sponsorship section of Chinese README 2026-03-27 20:49:43 +08:00
Luis Pater
8144ffd5c8 Merge pull request #2370 from B3o/add-bmoplus-sponsor
docs: add BmoPlus sponsorship banners to READMEs
2026-03-27 20:48:22 +08:00
Ravi Tharuma
0ab977c236 docs: clarify provider path limitations 2026-03-27 11:13:08 +01:00
Ravi Tharuma
224f0de353 docs: neutralize provider-specific path wording 2026-03-27 11:11:06 +01:00
B3o
6b45d311ec add BmoPlus sponsorship banners to READMEs 2026-03-27 18:01:35 +08:00
Ravi Tharuma
d54de441d3 docs: clarify provider-specific routing for aliased models 2026-03-27 10:53:09 +01:00
MrHuangJser
7386a70724 feat(cursor): auto-identify accounts from JWT sub for multi-account support
Previously Cursor required a manual ?label=xxx parameter to distinguish
accounts (unlike Codex which auto-generates filenames from JWT claims).

Cursor JWTs contain a "sub" claim (e.g. "auth0|user_XXXX") that uniquely
identifies each account. Now we:

- Add ParseJWTSub() + SubToShortHash() to extract and hash the sub claim
- Refactor GetTokenExpiry() to share the new decodeJWTPayload() helper
- Update CredentialFileName(label, subHash) to auto-generate filenames
  from the sub hash when no explicit label is provided
  (e.g. "cursor.8f202e67.json" instead of always "cursor.json")
- Add DisplayLabel() for human-readable account identification
- Store "sub" in metadata for observability
- Update both management API handler and SDK authenticator

Same account always produces the same filename (deterministic), different
accounts get different files. Explicit ?label= still takes priority.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 17:40:02 +08:00
白金
1821bf7051 docs: add BmoPlus sponsorship banners to READMEs 2026-03-27 17:39:29 +08:00
Luis Pater
d42b5d4e78 docs(readme): update QQ group information in Chinese README 2026-03-27 11:46:21 +08:00
MrHuangJser
1b7447b682 feat(cursor): implement StatusError for conductor cooldown integration
Cursor executor errors were plain fmt.Errorf — the conductor couldn't
extract HTTP status codes, so exhausted accounts never entered cooldown.

Changes:
- Add ConnectError struct to proto/connect.go: ParseConnectEndStream now
  returns *ConnectError with Code/Message fields for precise matching
- Add cursorStatusErr implementing StatusError + RetryAfter interfaces
- Add classifyCursorError() with two-layer classification:
  Layer 1: exact match on ConnectError.Code (gRPC standard codes)
    resource_exhausted → 429, unauthenticated → 401,
    permission_denied → 403, unavailable → 503, internal → 500
  Layer 2: fuzzy string match for H2 errors (RST_STREAM → 502)
- Log all ConnectError code/message pairs for observing real server
  error codes (we have no samples yet)
- Wrap Execute and ExecuteStream error returns with classifyCursorError

Now the conductor properly marks Cursor auths as cooldown on quota errors,
enabling exponential backoff and round-robin failover.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 11:42:22 +08:00
MrHuangJser
40dee4453a feat(cursor): auto-migrate sessions to healthy account on quota exhaustion
When a Cursor account's quota is exhausted, sessions bound to it can now
seamlessly continue on a different account:

Layer 1 — Checkpoint decoupling:
  Key checkpoints by conversationId (not authID:conversationId). Store
  authID inside savedCheckpoint. On lookup, if auth changed, discard the
  stale checkpoint and flatten conversation history into userText.

Layer 2 — Cross-account session cleanup:
  When a request arrives for a conversation whose session belongs to a
  different (now-exhausted) auth, close the old H2 stream and remove
  the stale session to free resources.

Layer 3 — H2Stream.Err() exposure:
  New Err() method on H2Stream so callers can inspect RST_STREAM,
  GOAWAY, or other stream-level errors after closure.

Layer 4 — processH2SessionFrames error propagation:
  Returns error instead of bare return. Connect EndStream errors (quota,
  rate limit) are now propagated instead of being logged and swallowed.

Layer 5 — Pre-response transparent retry:
  If the stream fails before any data is sent to the client, return an
  error to the conductor so it retries with a different auth — fully
  transparent to the client.

Layer 6 — Post-response error logging:
  If the stream fails after data was already sent, log a warning. The
  conductor's existing cooldown mechanism ensures the next request routes
  to a healthy account.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 10:50:32 +08:00
MrHuangJser
8902e1cccb style(cursor): replace fmt.Print* with log package for consistent logging
Address Gemini Code Assist review feedback: use logrus log package
instead of fmt.Printf/Println in Cursor auth handlers and CLI for
unified log formatting and level control.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 17:03:32 +08:00
黄姜恒
de5fe71478 feat(cursor): multi-account routing with round-robin and session isolation
- Add cursor/filename.go for multi-account credential file naming
- Include auth.ID in session and checkpoint keys for per-account isolation
- Record authID in cursorSession, validate on resume to prevent cross-account access
- Management API /cursor-auth-url supports ?label= for creating named accounts
- Leverages existing conductor round-robin + failover framework

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 11:27:49 +08:00
黄姜恒
dcfbec2990 feat(cursor): add management API for Cursor OAuth authentication
- Add RequestCursorToken handler with PKCE + polling flow
- Register /v0/management/cursor-auth-url route
- Returns login URL + state for browser auth, polls in background
- Saves cursor.json with access/refresh tokens on success

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 11:10:07 +08:00
黄姜恒
c95620f90e feat(cursor): conversation checkpoint + session_id for multi-turn context
- Capture conversation_checkpoint_update from Cursor server (was ignored)
- Store checkpoint per conversationId, replay as conversation_state on next request
- Use protowire to embed raw checkpoint bytes directly (no deserialization)
- Extract session_id from Claude Code metadata for stable conversationId across resume
- Flatten conversation history into userText as fallback when no checkpoint available
- Use conversationId as session key for reliable tool call resume
- Add checkpoint TTL cleanup (30min)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 10:51:47 +08:00
edlsh
754f3bcbc3 fix(codex): strip stream_options from Responses API requests
The Codex/OpenAI Responses API does not support the stream_options
parameter. When clients (e.g. Amp CLI) include stream_options in their
requests, CLIProxyAPI forwards it as-is, causing a 400 error:

  {"detail":"Unsupported parameter: stream_options"}

Strip stream_options alongside the other unsupported parameters
(previous_response_id, prompt_cache_retention, safety_identifier)
in Execute, ExecuteStream, and CountTokens.
2026-03-25 11:58:36 -04:00
pjpj
36973d4a6f Handle Codex capacity errors as retryable 2026-03-25 23:25:31 +08:00
黄姜恒
9613f0b3f9 feat(cursor): deterministic conversation_id from Claude Code session cch
Extract the cch hash from Claude Code's billing header in the system
prompt (x-anthropic-billing-header: ...cch=XXXXX;) and use it to derive
a deterministic conversation_id instead of generating a random UUID.

Same Claude Code session → same cch → same conversation_id → Cursor
server can reuse conversation state across multiple turns, preserving
tool call results and other context without re-encoding history.

Also cleans up temporary debug logging from previous iterations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:29:49 +08:00
黄姜恒
274f29e26b fix(cursor): improve session key uniqueness for multi-session safety
Include system prompt prefix (first 200 chars) in session key derivation.
Claude Code sessions have unique system prompts containing cwd, session_id,
file paths, etc., making collisions between concurrent sessions from the
same user virtually impossible.

Session key now = SHA256(apiKey + model + systemPrompt[:200] + firstUserMsg)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 17:24:37 +08:00
黄姜恒
c8e79c3787 fix(cursor): prevent session key collision across users
Include client API key in session key derivation to prevent different
users sharing the same proxy from accidentally resuming each other's
H2 streams when they send identical first messages with the same model.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 17:19:11 +08:00
黄姜恒
8afef43887 fix(cursor): preserve tool call context in multi-turn conversations
When an assistant message appears after tool results without a pending
user message, append it to the last turn's assistant text instead of
dropping it. Also add bakeToolResultsIntoTurns() to merge tool results
into turn context when no active H2 session exists for resume, ensuring
the model sees the full tool interaction history in follow-up requests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 17:15:24 +08:00
黄姜恒
c1083cbfc6 fix(cursor): MCP tool call resume, H2 flow control, and token usage
- Rewrite tool call mechanism from interrupt-resume to inline-wait mode:
  processH2SessionFrames no longer exits on mcpArgs; instead blocks on
  toolResultCh while continuing to handle KV/heartbeat messages, then
  sends MCP result and continues processing text in the same goroutine.
  Fixes the issue where server stopped generating text after resume.

- Add switchable output channel (outMu/currentOut) so first HTTP response
  closes after tool_calls+[DONE], and resumed text goes to a new channel
  returned by resumeWithToolResults. Reset streamParam on switch so
  Translator produces fresh message_start/content_block_start events.

- Implement send-side H2 flow control: track server's initial window size
  and WINDOW_UPDATE increments; Write() blocks when window exhausted.
  Fixes RST_STREAM FLOW_CONTROL_ERROR on large requests (178KB+).

- Decode new InteractionUpdate fields: TurnEndedUpdate (field 14) as
  stream termination signal, HeartbeatUpdate (field 13) silently ignored,
  TokenDeltaUpdate (field 8) for token usage tracking.

- Include token usage in final stop chunk (prompt_tokens estimated from
  payload size, completion_tokens from accumulated TokenDeltaUpdate deltas)
  so Claude CLI status bar shows non-zero token counts.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 17:03:14 +08:00
kwz
c89d19b300 Preserve default transport settings for proxy clients 2026-03-25 15:33:09 +08:00
Luis Pater
1e6bc81cfd refactor(config): replace auto-update-panel with disable-auto-update-panel for clarity 2026-03-25 10:31:44 +08:00
Luis Pater
1a149475e0 Merge pull request #2293 from Xvvln/fix/management-asset-security
fix(security): harden management panel asset updater
2026-03-25 10:22:49 +08:00
Luis Pater
e5166841db Merge pull request #2310 from shellus/fix/claude-openai-system-top-level
fix: preserve OpenAI system messages as Claude top-level system
2026-03-25 10:21:18 +08:00
黄姜恒
19c52bcb60 feat: stash code 2026-03-25 10:14:14 +08:00
Luis Pater
bb9b2d1758 Merge pull request #2320 from cikichen/build/freebsd-support
build: add freebsd support for releases
2026-03-25 10:12:35 +08:00
Luis Pater
7fa527193c Merge pull request #453 from HeCHieh/fix/github-copilot-gpt54-responses
Fix GitHub Copilot gpt-5.4 endpoint routing
2026-03-25 09:45:23 +08:00
Luis Pater
ed0eb51b4d Merge pull request #450 from lwiles692/feature/add-codebuddy-support
feat(auth): add CodeBuddy-CN browser OAuth authentication support
2026-03-25 09:43:52 +08:00
Luis Pater
0e4f669c8b Merge branch 'router-for-me:main' into main 2026-03-25 09:38:34 +08:00
Luis Pater
76c064c729 Merge pull request #2335 from router-for-me/auth
Support batch upload and delete for auth files
2026-03-25 09:34:44 +08:00
Luis Pater
d2f652f436 Merge pull request #2333 from router-for-me/codex
feat(codex): pass through codex client identity headers
2026-03-25 09:34:09 +08:00
Luis Pater
6a452a54d5 Merge pull request #2316 from router-for-me/openai
Add per-model thinking support for OpenAI compatibility
2026-03-25 09:31:28 +08:00
hkfires
9e5693e74f feat(api): support batch auth file upload and delete 2026-03-25 09:20:17 +08:00
hkfires
528b1a2307 feat(codex): pass through codex client identity headers 2026-03-25 08:48:18 +08:00
Luis Pater
0cc978ec1d Merge pull request #2297 from router-for-me/readme
docs(readme): update japanese documentation links
2026-03-25 03:11:24 +08:00
simon
d312422ab4 build: add freebsd support to releases 2026-03-24 16:49:04 +08:00
hkfires
fee736933b feat(openai-compat): add per-model thinking support 2026-03-24 14:21:12 +08:00
GeJiaXiang
09c92aa0b5 fix: keep a fallback turn for system-only Claude inputs 2026-03-24 13:54:25 +08:00
GeJiaXiang
8c67b3ae64 test: verify remaining user message after system merge 2026-03-24 13:47:52 +08:00
GeJiaXiang
000e4ceb4e fix: map OpenAI system messages to Claude top-level system 2026-03-24 13:42:33 +08:00
hkfires
5c99846ecf docs(readme): update japanese documentation links 2026-03-24 09:47:01 +08:00
trph
cc32f5ff61 fix: unify Responses output indexes for streamed items 2026-03-24 08:59:09 +08:00
trph
fbff68b9e0 fix: preserve choice-aware output indexes for streamed tool calls 2026-03-24 08:54:43 +08:00
trph
7e1a543b79 fix: preserve separate streamed tool calls in Responses API 2026-03-24 08:51:15 +08:00
Luis Pater
d475aaba96 Fixed: #2274
fix(translator): omit null content fields in Codex OpenAI tool call responses
2026-03-24 01:00:57 +08:00
Luis Pater
1dc4ecb1b8 Merge pull request #456 from router-for-me/plus
v6.9.1
2026-03-24 00:43:35 +08:00
Luis Pater
1315f710f5 Merge branch 'main' into plus 2026-03-24 00:43:26 +08:00
Luis Pater
96f55570f7 Merge pull request #2282 from eltociear/add-ja-doc
docs: add Japanese README
2026-03-24 00:40:58 +08:00
Luis Pater
0906aeca87 Merge pull request #2254 from clcc2019/main
refactor: streamline usage reporting by consolidating record publishi…
2026-03-24 00:39:31 +08:00
Xvvln
7333619f15 fix: reject oversized downloads instead of truncating; warn on unverified fallback
- Read maxAssetDownloadSize+1 bytes and error if exceeded, preventing
  silent truncation that could write a broken management.html to disk
- Log explicit warning when fallback URL is used without digest
  verification, so users are aware of the reduced security guarantee
2026-03-24 00:27:44 +08:00
Luis Pater
97c0487add Merge pull request #2223 from cnrpman/fix/codex-responses-web-search-preview-compat
fix: normalize web_search_preview for codex responses
2026-03-24 00:25:37 +08:00
DragonFSKY
74b862d8b8 test(cliproxy): cover delete re-add stale state flow 2026-03-24 00:21:04 +08:00
Xvvln
2db8df8e38 fix(security): harden management panel asset updater
- Abort update when SHA256 digest mismatch is detected instead of
  logging a warning and proceeding (prevents MITM asset replacement)
- Cap asset download size to 10 MB via io.LimitReader (defense-in-depth
  against OOM from oversized responses)
- Add `auto-update-panel` config option (default: false) to make the
  periodic background updater opt-in; the panel is still downloaded
  on first access when missing, but no longer silently auto-updated
  every 3 hours unless explicitly enabled
2026-03-24 00:10:04 +08:00
Luis Pater
a576088d5f Merge pull request #2222 from kaitranntt/kai/fix/758-openai-proxy-alternating-model-support
fix: fall back on model support errors during auth rotation
2026-03-24 00:03:28 +08:00
Luis Pater
66ff916838 Merge pull request #2220 from xulongwu4/main
fix: normalize model name in TranslateRequest fallback to prevent prefix leak
2026-03-23 23:56:15 +08:00
Luis Pater
7b0453074e Merge pull request #2219 from beck-8/fix/context-done-race
fix: avoid data race when watching request cancellation
2026-03-23 22:57:21 +08:00
Luis Pater
a000eb523d Merge pull request #2213 from TTTPOB/ua-fix
feat(claude): stabilize device fingerprint across mixed Claude Code and cloaked clients
2026-03-23 22:53:51 +08:00
Luis Pater
18a4fedc7f Merge pull request #2126 from ailuntz/fix/watcher-auth-cache-memory
perf(watcher): reduce auth cache memory
2026-03-23 22:47:34 +08:00
Luis Pater
5d6cdccda0 Merge pull request #2268 from sususu98/fix/sanitize-tool-names
fix(translator): sanitize tool names for Gemini function_declarations compatibility
2026-03-23 21:42:22 +08:00
Luis Pater
1b7f4ac3e1 Merge pull request #2252 from sususu98/fix/antigravity-empty-thought-text
fix(antigravity): always include text field in thought parts to prevent Google 500
2026-03-23 21:41:25 +08:00
Luis Pater
afc1a5b814 Fixed: #2281
refactor(claude): centralize usage token calculation logic and add tests for cached token handling
2026-03-23 21:30:03 +08:00
Ikko Eltociear Ashimine
7ed38db54f docs: update README_JA.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-23 16:57:43 +09:00
Ikko Eltociear Ashimine
28c10f4e69 docs: update README_JA.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-23 16:57:32 +09:00
Ikko Eltociear Ashimine
6e12441a3b Update README_JA.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-23 16:57:19 +09:00
Ikko Ashimine
65c439c18d docs: add Japanese README 2026-03-23 15:23:18 +09:00
dslife2025
0ed2d16596 Merge branch 'router-for-me:main' into main 2026-03-23 09:50:43 +08:00
Supra4E8C
db335ac616 Merge pull request #2269 from router-for-me/auth-fix
fix(auth): ensure absolute paths for auth file handling
2026-03-22 22:53:44 +08:00
Luis Pater
f3c59165d7 Merge branch 'pr-454'
# Conflicts:
#	cmd/server/main.go
#	internal/translator/claude/openai/chat-completions/claude_openai_response.go
2026-03-22 22:52:46 +08:00
hechieh
e6690cb447 Refine GitHub Copilot endpoint selection
Amp-Thread-ID: https://ampcode.com/threads/T-019d14cd-bc90-70ce-b1ae-87bc97332650
Co-authored-by: Amp <amp@ampcode.com>
2026-03-22 19:43:35 +08:00
hechieh
35907416b8 Fix GitHub Copilot gpt-5.4 endpoint routing
Amp-Thread-ID: https://ampcode.com/threads/T-019d14cd-bc90-70ce-b1ae-87bc97332650
Co-authored-by: Amp <amp@ampcode.com>
2026-03-22 19:05:44 +08:00
sususu98
e8bb350467 fix: extend tool name sanitization to all remaining Gemini-bound translators
Apply SanitizeFunctionName on request and RestoreSanitizedToolName on
response for: gemini/claude, gemini/openai/chat-completions,
gemini/openai/responses, antigravity/openai/chat-completions,
gemini-cli/openai/chat-completions.

Also update SanitizedToolNameMap to handle OpenAI format
(tools[].function.name) in addition to Claude format (tools[].name).
2026-03-22 14:06:46 +08:00
Supra4E8C
5331d51f27 fix(auth): ensure absolute paths for auth file handling 2026-03-22 13:58:16 +08:00
sususu98
755ca75879 fix: address review feedback - init ToolNameMap eagerly, log collisions, add collision test 2026-03-22 13:24:03 +08:00
sususu98
2398ebad55 fix(translator): sanitize tool names for Gemini function_declarations compatibility
Claude Code and MCP clients may send tool names containing characters
invalid for Gemini's function_declarations (e.g. '/', '@', spaces).
Sanitize on request via SanitizeFunctionName and restore original names
on response for both antigravity/claude and gemini-cli/claude translators.
2026-03-22 13:10:53 +08:00
clcc2019
c1bf298216 refactor: streamline usage reporting by consolidating record publishing logic
- Introduced a new method `buildRecord` in `usageReporter` to encapsulate record creation, improving code readability and maintainability.
- Added latency tracking to usage records, ensuring accurate reporting of request latencies.
- Updated tests to validate the inclusion of latency in usage records and ensure proper functionality of the new reporting structure.
2026-03-20 19:44:26 +08:00
sususu
e005208d76 fix(antigravity): always include text field in thought parts to prevent Google 500
When Claude sends redacted thinking with empty text, the translator
was omitting the "text" field from thought parts. Google Antigravity
API requires this field, causing 500 "Unknown Error" responses.

Verified: 129/129 error logs with empty thought → 500, 0/97 success
logs had empty thought. After fix: 0 new "Unknown Error" 500s.
2026-03-20 18:59:25 +08:00
Junyi Du
d1df70d02f chore: add codex builtin tool normalization logging 2026-03-20 14:08:37 +08:00
Luis Pater
f81acd0760 Merge pull request #2243 from router-for-me/oauth
Improve OAuth callback handling with async prompts
2026-03-20 12:35:44 +08:00
hkfires
636da4c932 refactor(auth): replace manual input handling with AsyncPrompt for callback URLs 2026-03-20 12:24:27 +08:00
hkfires
cccb77b552 fix(auth): avoid blocking oauth callback wait on prompt 2026-03-20 11:48:30 +08:00
Luis Pater
2bd646ad70 refactor: replace sjson.Set usage with sjson.SetBytes to optimize mutable JSON transformations 2026-03-19 17:58:54 +08:00
tpob
52c1fa025e fix(claude): learn official fingerprints after custom baselines 2026-03-19 13:59:41 +08:00
tpob
680105f84d fix(claude): refresh cached fingerprint after baseline upgrades 2026-03-19 13:28:58 +08:00
tpob
f7069e9548 fix(claude): pin stabilized OS arch to baseline 2026-03-19 13:07:16 +08:00
lwiles692
7275e99b41 Update internal/auth/codebuddy/codebuddy_auth.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-19 09:46:59 +08:00
lwiles692
c28b65f849 Update internal/auth/codebuddy/codebuddy_auth.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-19 09:46:40 +08:00
Junyi Du
793840cdb4 fix: cover dated and nested codex web search aliases 2026-03-19 03:41:12 +08:00
Junyi Du
8f421de532 fix: handle sjson errors in codex tool normalization 2026-03-19 03:36:06 +08:00
Junyi Du
be2dd60ee7 fix: normalize web_search_preview for codex responses 2026-03-19 03:23:14 +08:00
Tam Nhu Tran
ea3e0b713e fix: harden pooled model-support fallback state 2026-03-18 13:19:20 -04:00
tpob
8179d5a8a4 fix(claude): avoid racy fingerprint downgrades 2026-03-19 01:03:41 +08:00
tpob
6fa7abe434 fix(claude): keep configured baseline above older fingerprints 2026-03-19 01:02:04 +08:00
Tam Nhu Tran
5135c22cd6 fix: fall back on model support errors during auth rotation 2026-03-18 12:43:45 -04:00
Longwu Ou
1e27990561 address PR review: log sjson error and add unit tests
- Log a warning instead of silently ignoring sjson.SetBytes errors in the TranslateRequest fallback path
  - Add registry_test.go with tests covering the fallback model normalization and verifying registered transforms take precedence
2026-03-18 12:43:40 -04:00
Longwu Ou
e1e9fc43c1 fix: normalize model name in TranslateRequest fallback to prevent prefix leak
When no request translator is registered for a format pair (e.g.
        openai-response → openai-response), TranslateRequest returned the raw
        payload unchanged. This caused client-side model prefixes (e.g.
        "copilot/gpt-5-mini") to leak into upstream requests, resulting in
        "The requested model is not supported" errors from providers.

        The fallback path now updates the "model" field in the payload to
        match the resolved model name before returning.
2026-03-18 12:30:22 -04:00
beck-8
b2921518ac fix: avoid data race when watching request cancellation 2026-03-19 00:15:52 +08:00
tpob
dd64adbeeb fix(claude): preserve legacy user agent overrides 2026-03-19 00:03:09 +08:00
tpob
616d41c06a fix(claude): restore legacy runtime OS arch fallback 2026-03-19 00:01:50 +08:00
tpob
e0e337aeb9 feat(claude): add switch for device profile stabilization 2026-03-18 19:31:59 +08:00
tpob
d52839fced fix: stabilize claude device fingerprint 2026-03-18 18:46:54 +08:00
Wei Lee
4022e69651 feat(auth): add CodeBuddy-CN browser OAuth authentication support 2026-03-18 17:50:12 +08:00
Luis Pater
56073ded69 Merge pull request #2200 from sususu98/feat/local-model-flag
feat: add -local-model flag to skip remote model catalog fetching
2026-03-18 10:58:07 +08:00
sususu98
9738a53f49 feat: add -local-model flag to skip remote model catalog fetching
When enabled, the server uses only the embedded models.json loaded at
init() and skips registry.StartModelsUpdater(), disabling the initial
remote fetch and 3-hour periodic refresh. The management panel
auto-updater (managementasset.StartAutoUpdater) is unaffected.
2026-03-18 10:48:03 +08:00
Luis Pater
be3f8dbf7e Merge pull request #2187 from Darley-Wey/fix/claude-disable-parallel-tool-calls
fix(claude): honor disable_parallel_tool_use
2026-03-17 21:06:08 +08:00
Darley
9c6c3612a8 fix(claude): read disable_parallel_tool_use from tool_choice 2026-03-17 19:35:41 +08:00
Darley
19e1a4447a fix(claude): honor disable_parallel_tool_use 2026-03-17 19:17:41 +08:00
Luis Pater
7c2ad4cda2 Merge branch 'router-for-me:main' into main 2026-03-17 00:09:43 +08:00
Luis Pater
fb95813fbf Merge pull request #2142 from Muran-prog/fix/strip-uniqueItems-gemini-2123
fix: strip uniqueItems from Gemini function_declarations (#2123)
2026-03-16 20:34:28 +08:00
Luis Pater
db63f9b5d6 Merge pull request #2162 from enieuwy/fix/responses-api-json-valid-check
fix: validate JSON before raw-embedding function call outputs in Responses API
2026-03-16 18:42:31 +08:00
Luis Pater
25f6c4a250 Merge pull request #2158 from sususu98/fix/antigravity-functionresponse-name
fix(antigravity): resolve empty functionResponse.name for toolu_* tool_use_id format
2026-03-16 18:39:40 +08:00
enieuwy
b24ae74216 fix: validate JSON before raw-embedding function call outputs in Responses API
gjson.Parse() marks any string starting with { or [ as gjson.JSON type,
even when the content is not valid JSON (e.g. macOS plist format, truncated
tool results). This caused sjson.SetRaw to embed non-JSON content directly
into the Gemini API request payload, producing 400 errors.

Add json.Valid() check before using SetRaw to ensure only actually valid
JSON is embedded raw. Non-JSON content now falls through to sjson.Set
which properly escapes it as a JSON string.

Fixes #2161
2026-03-16 15:29:18 +08:00
Luis Pater
59ad8f40dc Merge pull request #2124 from RGBadmin/feat/auth-list-priority-note
feat(api): expose priority and note in GET /auth-files response
2026-03-16 12:31:11 +08:00
sususu98
ff03dc6a2c fix(antigravity): resolve empty functionResponse.name for toolu_* tool_use_id format
The Claude-to-Gemini translator derived function names by splitting
tool_use_id on "-", which produced empty strings for IDs with exactly
2 segments (e.g. toolu_tool-<uuid>). Replace the string-splitting
heuristic with a lookup map built from tool_use blocks during the
main processing loop, with fallback to the raw ID on miss.
2026-03-16 11:18:29 +08:00
Luis Pater
dc7187ca5b fix(websocket): pin only websocket-capable auth IDs and add corresponding test 2026-03-16 09:57:38 +08:00
Luis Pater
b1dcff778c Merge pull request #2141 from Muran-prog/fix/tool-calling-translation-2132
fix: skip empty assistant message in tool call translation (#2132)
2026-03-16 01:42:27 +08:00
RGBadmin
c1241a98e2 fix(api): restrict fallback note to string-typed JSON values
Only emit note in listAuthFilesFromDisk when the JSON value is actually
a string (gjson.String), matching the synthesizer/buildAuthFileEntry
behavior. Non-string values like numbers or booleans are now ignored
instead of being coerced via gjson.String().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 23:00:17 +08:00
RGBadmin
8d8f5970ee fix(api): fallback to Metadata for priority/note on uploaded auths
buildAuthFileEntry now falls back to reading priority/note from
auth.Metadata when Attributes lacks them. This covers auths registered
via UploadAuthFile which bypass the synthesizer and only populate
Metadata from the raw JSON.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 17:36:11 +08:00
RGBadmin
f90120f846 fix(api): propagate note to Gemini virtual auths and align priority parsing
- Read note from Attributes (consistent with priority) in buildAuthFileEntry,
  fixing missing note on Gemini multi-project virtual auth cards.
- Propagate note from primary to virtual auths in SynthesizeGeminiVirtualAuths,
  mirroring existing priority propagation.
- Sync note/priority writes to both Metadata and Attributes in PatchAuthFileFields,
  with refactored nil-check to reduce duplication (review feedback).
- Validate priority type in fallback disk-read path instead of coercing all values
  to 0 via gjson.Int(), aligning with the auth-manager code path.
- Add regression tests for note synthesis, virtual-auth note propagation, and
  end-to-end multi-project Gemini note inheritance.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-15 16:47:01 +08:00
Muran-prog
0b94d36c4a test: use exact match for tool name assertion
Address review feedback - drop function.name fallback and
strings.Contains in favor of direct == comparison.
2026-03-14 21:45:28 +02:00
Muran-prog
152c310bb7 test: add uniqueItems stripping test
Covers the fix from the previous commit — verifies uniqueItems is
removed from the schema and moved to the description hint.
2026-03-14 21:22:14 +02:00
Muran-prog
f6bbca35ab fix: strip uniqueItems from Gemini function_declarations (#2123)
Gemini API rejects uniqueItems in tool schemas with 400. Add it to
unsupportedConstraints alongside minItems/maxItems where it belongs.

Same class of fix as #1424 and #1531.
2026-03-14 21:18:06 +02:00
Muran-prog
c8cee6a209 fix: skip empty assistant message in tool call translation (#2132)
When assistant has tool_calls but no text content, the translator
emitted an empty message into the Responses API input array before
function_call items. The API then couldn't match function_call_output
to its function_call by call_id, returning:

  No tool output found for function call ...

Only emit assistant messages that have content parts. Tool-call-only
messages now produce function_call items directly.

Added 9 tests for tool calling translation covering single/parallel
calls, multi-turn conversations, name shortening, empty content
edge cases, and call_id integrity.
2026-03-14 21:01:01 +02:00
DragonFSKY
5c817a9b42 fix(auth): prevent stale ModelStates inheritance from disabled auth entries
When an auth file is deleted and re-created with the same path/ID, the
new auth could inherit stale ModelStates (cooldown/backoff) from the
previously disabled entry, preventing it from being routed.

Gate runtime state inheritance (ModelStates, LastRefreshedAt,
NextRefreshAfter) on both existing and incoming auth being non-disabled
in Manager.Update and Service.applyCoreAuthAddOrUpdate.

Closes #2061
2026-03-14 23:46:23 +08:00
luxvtz
5da0decef6 Improve GitLab Duo gateway compatibility 2026-03-14 03:18:43 -07:00
RGBadmin
5b6342e6ac feat(api): expose priority and note fields in GET /auth-files list response
The list endpoint previously omitted priority and note, which are stored
inside each auth file's JSON content. This adds them to both the normal
(auth-manager) and fallback (disk-read) code paths, and extends
PATCH /auth-files/fields to support writing the note field.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-14 14:47:31 +08:00
ailuntz
c3762328a5 perf(watcher): reduce auth cache memory 2026-03-10 16:27:10 +08:00
171 changed files with 16774 additions and 3681 deletions

1
.gitignore vendored
View File

@@ -1,6 +1,7 @@
# Binaries
cli-proxy-api
cliproxy
/server
*.exe

View File

@@ -8,6 +8,7 @@ builds:
- linux
- windows
- darwin
- freebsd
goarch:
- amd64
- arm64

View File

@@ -1,6 +1,6 @@
# CLIProxyAPI Plus
[English](README.md) | 中文
[English](README.md) | 中文 | [日本語](README_JA.md)
这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。

199
README_JA.md Normal file
View File

@@ -0,0 +1,199 @@
# CLI Proxy API
[English](README.md) | [中文](README_CN.md) | 日本語
CLI向けのOpenAI/Gemini/Claude/Codex互換APIインターフェースを提供するプロキシサーバーです。
OAuth経由でOpenAI CodexGPTモデルおよびClaude Codeもサポートしています。
ローカルまたはマルチアカウントのCLIアクセスを、OpenAIResponses含む/Gemini/Claude互換のクライアントやSDKで利用できます。
## スポンサー
[![z.ai](https://assets.router-for.me/english-5-0.jpg)](https://z.ai/subscribe?ic=8JVLJQFSKB)
本プロジェクトはZ.aiにスポンサーされており、GLM CODING PLANの提供を受けています。
GLM CODING PLANはAIコーディング向けに設計されたサブスクリプションサービスで、月額わずか$10から利用可能です。フラッグシップのGLM-4.7およびGLM-5はProユーザーのみ利用可能モデルを10以上の人気AIコーディングツールClaude Code、Cline、Roo Codeなどで利用でき、開発者にトップクラスの高速かつ安定したコーディング体験を提供します。
GLM CODING PLANを10%割引で取得https://z.ai/subscribe?ic=8JVLJQFSKB
---
<table>
<tbody>
<tr>
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
<td>PackyCodeのスポンサーシップに感謝しますPackyCodeは信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどのリレーサービスを提供しています。PackyCodeは当ソフトウェアのユーザーに特別割引を提供しています<a href="https://www.packyapi.com/register?aff=cliproxyapi">こちらのリンク</a>から登録し、チャージ時にプロモーションコード「cliproxyapi」を入力すると10%割引になります。</td>
</tr>
<tr>
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>AICodeMirrorのスポンサーシップに感謝しますAICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引がありますCLIProxyAPIユーザー向けの特別特典<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">こちらのリンク</a>から登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます!</td>
</tr>
<tr>
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
<td>本プロジェクトにご支援いただいた BmoPlus に感謝いたしますBmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらの<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AIアカウント専門店/代行チャージ</a>経由でご登録・ご注文いただいたユーザー様は、GPTを <b>公式サイト価格の約1割90% OFF</b> という驚異的な価格でご利用いただけます!</td>
</tr>
<tr>
<td width="180"><a href="https://www.lingtrue.com/register"><img src="./assets/lingtrue.png" alt="LingtrueAPI" width="150"></a></td>
<td>LingtrueAPIのスポンサーシップに感謝しますLingtrueAPIはグローバルな大規模モデルAPIリレーサービスプラットフォームで、Claude Code、Codex、GeminiなどのトップモデルAPI呼び出しサービスを提供し、ユーザーが低コストかつ高い安定性で世界中のAI能力に接続できるよう支援しています。LingtrueAPIは本ソフトウェアのユーザーに特別割引を提供しています<a href="https://www.lingtrue.com/register">こちらのリンク</a>から登録し、初回チャージ時にプロモーションコード「LingtrueAPI」を入力すると10%割引になります。</td>
</tr>
</tbody>
</table>
## 概要
- CLIモデル向けのOpenAI/Gemini/Claude互換APIエンドポイント
- OAuthログインによるOpenAI CodexサポートGPTモデル
- OAuthログインによるClaude Codeサポート
- OAuthログインによるQwen Codeサポート
- OAuthログインによるiFlowサポート
- プロバイダールーティングによるAmp CLIおよびIDE拡張機能のサポート
- ストリーミングおよび非ストリーミングレスポンス
- 関数呼び出し/ツールのサポート
- マルチモーダル入力サポート(テキストと画像)
- ラウンドロビン負荷分散による複数アカウント対応Gemini、OpenAI、Claude、QwenおよびiFlow
- シンプルなCLI認証フローGemini、OpenAI、Claude、QwenおよびiFlow
- Generative Language APIキーのサポート
- AI Studioビルドのマルチアカウント負荷分散
- Gemini CLIのマルチアカウント負荷分散
- Claude Codeのマルチアカウント負荷分散
- Qwen Codeのマルチアカウント負荷分散
- iFlowのマルチアカウント負荷分散
- OpenAI Codexのマルチアカウント負荷分散
- 設定によるOpenAI互換アップストリームプロバイダーOpenRouter
- プロキシ埋め込み用の再利用可能なGo SDK`docs/sdk-usage.md`を参照)
## はじめに
CLIProxyAPIガイド[https://help.router-for.me/](https://help.router-for.me/)
## 管理API
[MANAGEMENT_API.md](https://help.router-for.me/management/api)を参照
## Amp CLIサポート
CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統合サポートを含んでおり、Google/ChatGPT/ClaudeのOAuthサブスクリプションをAmpのコーディングツールで使用できます
- Ampの APIパターン用のプロバイダールートエイリアス`/api/provider/{provider}/v1...`
- OAuth認証およびアカウント機能用の管理プロキシ
- 自動ルーティングによるスマートモデルフォールバック
- 利用できないモデルを代替モデルにルーティングする**モデルマッピング**(例:`claude-opus-4.5``claude-sonnet-4`
- localhostのみの管理エンドポイントによるセキュリティファーストの設計
特定のバックエンド系統のリクエスト/レスポンス形状が必要な場合は、統合された `/v1/...` エンドポイントよりも provider-specific のパスを優先してください。
- messages 系のバックエンドには `/api/provider/{provider}/v1/messages`
- モデル単位の generate 系エンドポイントには `/api/provider/{provider}/v1beta/models/...`
- chat-completions 系のバックエンドには `/api/provider/{provider}/v1/chat/completions`
これらのパスはプロトコル面の選択には役立ちますが、同じクライアント向けモデル名が複数バックエンドで再利用されている場合、それだけで推論実行系が一意に固定されるわけではありません。実際の推論ルーティングは、引き続きリクエスト内の model/alias 解決に従います。厳密にバックエンドを固定したい場合は、一意な alias や prefix を使うか、クライアント向けモデル名の重複自体を避けてください。
**→ [Amp CLI統合ガイドの完全版](https://help.router-for.me/agent-client/amp-cli.html)**
## SDKドキュメント
- 使い方:[docs/sdk-usage.md](docs/sdk-usage.md)
- 上級(エグゼキューターとトランスレーター):[docs/sdk-advanced.md](docs/sdk-advanced.md)
- アクセス:[docs/sdk-access.md](docs/sdk-access.md)
- ウォッチャー:[docs/sdk-watcher.md](docs/sdk-watcher.md)
- カスタムプロバイダーの例:`examples/custom-provider`
## コントリビューション
コントリビューションを歓迎しますお気軽にPull Requestを送ってください。
1. リポジトリをフォーク
2. フィーチャーブランチを作成(`git checkout -b feature/amazing-feature`
3. 変更をコミット(`git commit -m 'Add some amazing feature'`
4. ブランチにプッシュ(`git push origin feature/amazing-feature`
5. Pull Requestを作成
## 関連プロジェクト
CLIProxyAPIをベースにした以下のプロジェクトがあります
### [vibeproxy](https://github.com/automazeio/vibeproxy)
macOSネイティブのメニューバーアプリで、Claude CodeとChatGPTのサブスクリプションをAIコーディングツールで使用可能 - APIキー不要
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
CLIProxyAPI経由でGeminiサブスクリプションを使用してSRT字幕を翻訳するブラウザベースのツール。自動検証/エラー修正機能付き - APIキー不要
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
CLIProxyAPI OAuthを使用して複数のClaudeアカウントや代替モデルGemini、Codex、Antigravityを即座に切り替えるCLIラッパー - APIキー不要
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
CLIProxyAPI管理用のmacOSネイティブGUIOAuth経由でプロバイダー、モデルマッピング、エンドポイントを設定 - APIキー不要
### [Quotio](https://github.com/nguyenphutrong/quotio)
Claude、Gemini、OpenAI、Qwen、Antigravityのサブスクリプションを統合し、リアルタイムのクォータ追跡とスマート自動フェイルオーバーを備えたmacOSネイティブのメニューバーアプリ。Claude Code、OpenCode、Droidなどのコーディングツール向け - APIキー不要
### [CodMate](https://github.com/loocor/CodMate)
CLI AIセッションCodex、Claude Code、Gemini CLIを管理するmacOS SwiftUIネイティブアプリ。統合プロバイダー管理、Gitレビュー、プロジェクト整理、グローバル検索、ターミナル統合機能を搭載。CLIProxyAPIと統合し、Codex、Claude、Gemini、Antigravity、Qwen CodeのOAuth認証を提供。単一のプロキシエンドポイントを通じた組み込みおよびサードパーティプロバイダーの再ルーティングに対応 - OAuthプロバイダーではAPIキー不要
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
TUI、システムトレイ、マルチプロバイダーOAuthを備えたWindows向けCLIProxyAPIフォーク - AIコーディングツール用、APIキー不要
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
Claude Codeモデルを素早く切り替えるVSCode拡張機能。バックエンドとしてCLIProxyAPIを統合し、バックグラウンドでの自動ライフサイクル管理を搭載
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
CLIProxyAPIを使用してAIコーディングアシスタントのクォータを監視するTauri + React製のWindowsデスクトップアプリ。Gemini、Claude、OpenAI Codex、Antigravityアカウントの使用量をリアルタイムダッシュボード、システムトレイ統合、ワンクリックプロキシコントロールで追跡 - APIキー不要
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
CLIProxyAPI向けの軽量Web管理パネル。ヘルスチェック、リソース監視、リアルタイムログ、自動更新、リクエスト統計、料金表示機能を搭載。ワンクリックインストールとsystemdサービスに対応
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
PowerShellスクリプトで実装されたWindowsトレイアプリケーション。サードパーティライブラリに依存せず、ショートカットの自動作成、サイレント実行、パスワード管理、チャネル切り替えMain / Plus、自動ダウンロードおよび自動更新に対応
### [霖君](https://github.com/wangdabaoqq/LinJun)
霖君はAIプログラミングアシスタントを管理するクロスプラットフォームデスクトップアプリケーションで、macOS、Windows、Linuxシステムに対応。Claude Code、Gemini CLI、OpenAI Codex、Qwen Codeなどのコーディングツールを統合管理し、ローカルプロキシによるマルチアカウントクォータ追跡とワンクリック設定が可能
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
Next.js、React、PostgreSQLで構築されたCLIProxyAPI用のモダンなWebベース管理ダッシュボード。リアルタイムログストリーミング、構造化された設定編集、APIキー管理、Claude/Gemini/Codex向けOAuthプロバイダー統合、使用量分析、コンテナ管理、コンパニオンプラグインによるOpenCodeとの設定同期機能を搭載 - 手動でのYAML編集は不要
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
New API互換リレーサイトアカウントをワンストップで管理するブラウザ拡張機能。残高と使用量のダッシュボード、自動チェックイン、一般的なアプリへのワンクリックキーエクスポート、ページ内API可用性テスト、チャネル/モデルの同期とリダイレクト機能を搭載。Management APIを通じてCLIProxyAPIと統合し、ワンクリックでプロバイダーのインポートと設定同期が可能
### [Shadow AI](https://github.com/HEUDavid/shadow-ai)
Shadow AIは制限された環境向けに特別に設計されたAIアシスタントツールです。ウィンドウや痕跡のないステルス動作モードを提供し、LANローカルエリアネットワークを介したクロスデバイスAI質疑応答のインタラクションと制御を可能にします。本質的には「画面/音声キャプチャ + AI推論 + 低摩擦デリバリー」の自動化コラボレーションレイヤーであり、制御されたデバイスや制限された環境でアプリケーション横断的にAIアシスタントを没入的に使用できるようユーザーを支援します。
> [!NOTE]
> CLIProxyAPIをベースにプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
## その他の選択肢
以下のプロジェクトはCLIProxyAPIの移植版またはそれに触発されたものです
### [9Router](https://github.com/decolua/9router)
CLIProxyAPIに触発されたNext.js実装。インストールと使用が簡単で、フォーマット変換OpenAI/Claude/Gemini/Ollama、自動フォールバック付きコンボシステム、指数バックオフ付きマルチアカウント管理、Next.js Webダッシュボード、CLIツールCursor、Claude Code、Cline、RooCodeのサポートをゼロから構築 - APIキー不要
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
コーディングを止めない。無料および低コストのAIモデルへのスマートルーティングと自動フォールバック。
OmniRouteはマルチプロバイダーLLM向けのAIゲートウェイですスマートルーティング、負荷分散、リトライ、フォールバックを備えたOpenAI互換エンドポイント。ポリシー、レート制限、キャッシュ、可観測性を追加して、信頼性が高くコストを意識した推論を実現します。
> [!NOTE]
> CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
## ライセンス
本プロジェクトはMITライセンスの下でライセンスされています - 詳細は[LICENSE](LICENSE)ファイルを参照してください。

BIN
assets/bmoplus.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

BIN
assets/lingtrue.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 129 KiB

20
cmd/mcpdebug/main.go Normal file
View File

@@ -0,0 +1,20 @@
package main
import (
"encoding/hex"
"fmt"
"os"
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
)
func main() {
// Encode MCP result with empty execId
resultBytes := cursorproto.EncodeExecMcpResult(1, "", `{"test": "data"}`, false)
fmt.Printf("Result protobuf hex: %s\n", hex.EncodeToString(resultBytes))
fmt.Printf("Result length: %d bytes\n", len(resultBytes))
// Write to file for analysis
os.WriteFile("mcp_result.bin", resultBytes)
fmt.Println("Wrote mcp_result.bin")
}

32
cmd/protocheck/main.go Normal file
View File

@@ -0,0 +1,32 @@
package main
import (
"fmt"
cursorproto "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor/proto"
)
func main() {
ecm := cursorproto.NewMsg("ExecClientMessage")
// Try different field names
names := []string{
"mcp_result", "mcpResult", "McpResult", "MCP_RESULT",
"shell_result", "shellResult",
}
for _, name := range names {
fd := ecm.Descriptor().Fields().ByName(name)
if fd != nil {
fmt.Printf("Found field %q: number=%d, kind=%s\n", name, fd.Number(), fd.Kind())
} else {
fmt.Printf("Field %q NOT FOUND\n", name)
}
}
// List all fields
fmt.Println("\nAll fields in ExecClientMessage:")
for i := 0; i < ecm.Descriptor().Fields().Len(); i++ {
f := ecm.Descriptor().Fields().Get(i)
fmt.Printf(" %d: %q (number=%d)\n", i, f.Name(), f.Number())
}
}

View File

@@ -85,6 +85,7 @@ func main() {
var oauthCallbackPort int
var antigravityLogin bool
var kimiLogin bool
var cursorLogin bool
var kiroLogin bool
var kiroGoogleLogin bool
var kiroAWSLogin bool
@@ -95,6 +96,7 @@ func main() {
var kiroIDCRegion string
var kiroIDCFlow string
var githubCopilotLogin bool
var codeBuddyLogin bool
var projectID string
var vertexImport string
var configPath string
@@ -103,6 +105,7 @@ func main() {
var standalone bool
var noIncognito bool
var useIncognito bool
var localModel bool
// Define command-line flags for different operation modes.
flag.BoolVar(&login, "login", false, "Login Google Account")
@@ -121,6 +124,7 @@ func main() {
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
flag.BoolVar(&cursorLogin, "cursor-login", false, "Login to Cursor using OAuth")
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
@@ -131,12 +135,14 @@ func main() {
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
flag.BoolVar(&codeBuddyLogin, "codebuddy-login", false, "Login to CodeBuddy using browser OAuth flow")
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
flag.StringVar(&password, "password", "", "")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
flag.CommandLine.Usage = func() {
out := flag.CommandLine.Output()
@@ -514,6 +520,9 @@ func main() {
} else if githubCopilotLogin {
// Handle GitHub Copilot login
cmd.DoGitHubCopilotLogin(cfg, options)
} else if codeBuddyLogin {
// Handle CodeBuddy login
cmd.DoCodeBuddyLogin(cfg, options)
} else if codexLogin {
// Handle Codex login
cmd.DoCodexLogin(cfg, options)
@@ -537,6 +546,8 @@ func main() {
cmd.DoGitLabTokenLogin(cfg, options)
} else if kimiLogin {
cmd.DoKimiLogin(cfg, options)
} else if cursorLogin {
cmd.DoCursorLogin(cfg, options)
} else if kiroLogin {
// For Kiro auth, default to incognito mode for multi-account support
// Users can explicitly override with --no-incognito
@@ -578,11 +589,16 @@ func main() {
cmd.WaitForCloudDeploy()
return
}
if localModel && (!tuiMode || standalone) {
log.Info("Local model mode: using embedded model catalog, remote model updates disabled")
}
if tuiMode {
if standalone {
// Standalone mode: start an embedded local server and connect TUI client to it.
managementasset.StartAutoUpdater(context.Background(), configFilePath)
registry.StartModelsUpdater(context.Background())
if !localModel {
registry.StartModelsUpdater(context.Background())
}
hook := tui.NewLogHook(2000)
hook.SetFormatter(&logging.LogFormatter{})
log.AddHook(hook)
@@ -655,7 +671,9 @@ func main() {
} else {
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
registry.StartModelsUpdater(context.Background())
if !localModel {
registry.StartModelsUpdater(context.Background())
}
if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg)

View File

@@ -25,6 +25,10 @@ remote-management:
# Disable the bundled management control panel asset download and HTTP route when true.
disable-control-panel: false
# Disable automatic periodic background updates of the management panel from GitHub (default: false).
# When enabled, the panel is only downloaded on first access if missing, and never auto-updated afterward.
# disable-auto-update-panel: false
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
@@ -175,12 +179,19 @@ nonstream-keepalive-interval: 0
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
# Default headers for Claude API requests. Update when Claude Code releases new versions.
# These are used as fallbacks when the client does not send its own headers.
# In legacy mode, user-agent/package-version/runtime-version/timeout are used as fallbacks
# when the client omits them, while OS/arch remain runtime-derived. When
# stabilize-device-profile is enabled, OS/arch stay pinned to the baseline values below,
# while user-agent/package-version/runtime-version seed a software fingerprint that can
# still upgrade to newer official Claude client versions.
# claude-header-defaults:
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
# package-version: "0.74.0"
# runtime-version: "v24.3.0"
# os: "MacOS"
# arch: "arm64"
# timeout: "600"
# stabilize-device-profile: false # optional, default false; set true to enable per-auth/API-key fingerprint pinning
# Default headers for Codex OAuth model requests.
# These are used only for file-backed/OAuth Codex requests when the client
@@ -231,7 +242,9 @@ nonstream-keepalive-interval: 0
# - api-key: "sk-or-v1-...b781" # without proxy-url
# models: # The models supported by the provider.
# - name: "moonshotai/kimi-k2:free" # The actual model name.
# alias: "kimi-k2" # The alias used in the API.
# alias: "kimi-k2" # The alias used in the API.
# thinking: # optional: omit to default to levels ["low","medium","high"]
# levels: ["low", "medium", "high"]
# # You may repeat the same alias to build an internal model pool.
# # The client still sees only one alias in the model list.
# # Requests to that alias will round-robin across the upstream names below,
@@ -300,6 +313,10 @@ nonstream-keepalive-interval: 0
# These aliases rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
# you select the protocol surface, but inference backend selection can still follow the resolved
# model/alias. For strict backend pinning, use unique aliases/prefixes or avoid overlapping names.
# You can repeat the same name with different aliases to expose multiple client model names.
# oauth-model-alias:
# antigravity:

View File

@@ -52,11 +52,11 @@ func init() {
sdktr.Register(fOpenAI, fMyProv,
func(model string, raw []byte, stream bool) []byte { return raw },
sdktr.ResponseTransform{
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []string {
return []string{string(raw)}
Stream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) [][]byte {
return [][]byte{raw}
},
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) string {
return string(raw)
NonStream: func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []byte {
return raw
},
},
)

2
go.mod
View File

@@ -91,8 +91,8 @@ require (
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect

View File

@@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"io"
"mime/multipart"
"net"
"net/http"
"net/url"
@@ -28,6 +29,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
cursorauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cursor"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
@@ -66,8 +68,10 @@ type callbackForwarder struct {
}
var (
callbackForwardersMu sync.Mutex
callbackForwarders = make(map[int]*callbackForwarder)
callbackForwardersMu sync.Mutex
callbackForwarders = make(map[int]*callbackForwarder)
errAuthFileMustBeJSON = errors.New("auth file must be .json")
errAuthFileNotFound = errors.New("auth file not found")
)
func extractLastRefreshTimestamp(meta map[string]any) (time.Time, bool) {
@@ -341,6 +345,21 @@ func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
emailValue := gjson.GetBytes(data, "email").String()
fileData["type"] = typeValue
fileData["email"] = emailValue
if pv := gjson.GetBytes(data, "priority"); pv.Exists() {
switch pv.Type {
case gjson.Number:
fileData["priority"] = int(pv.Int())
case gjson.String:
if parsed, errAtoi := strconv.Atoi(strings.TrimSpace(pv.String())); errAtoi == nil {
fileData["priority"] = parsed
}
}
}
if nv := gjson.GetBytes(data, "note"); nv.Exists() && nv.Type == gjson.String {
if trimmed := strings.TrimSpace(nv.String()); trimmed != "" {
fileData["note"] = trimmed
}
}
}
files = append(files, fileData)
@@ -424,6 +443,37 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
if claims := extractCodexIDTokenClaims(auth); claims != nil {
entry["id_token"] = claims
}
// Expose priority from Attributes (set by synthesizer from JSON "priority" field).
// Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer).
if p := strings.TrimSpace(authAttribute(auth, "priority")); p != "" {
if parsed, err := strconv.Atoi(p); err == nil {
entry["priority"] = parsed
}
} else if auth.Metadata != nil {
if rawPriority, ok := auth.Metadata["priority"]; ok {
switch v := rawPriority.(type) {
case float64:
entry["priority"] = int(v)
case int:
entry["priority"] = v
case string:
if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
entry["priority"] = parsed
}
}
}
}
// Expose note from Attributes (set by synthesizer from JSON "note" field).
// Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer).
if note := strings.TrimSpace(authAttribute(auth, "note")); note != "" {
entry["note"] = note
} else if auth.Metadata != nil {
if rawNote, ok := auth.Metadata["note"].(string); ok {
if trimmed := strings.TrimSpace(rawNote); trimmed != "" {
entry["note"] = trimmed
}
}
}
return entry
}
@@ -501,10 +551,23 @@ func isRuntimeOnlyAuth(auth *coreauth.Auth) bool {
return strings.EqualFold(strings.TrimSpace(auth.Attributes["runtime_only"]), "true")
}
func isUnsafeAuthFileName(name string) bool {
if strings.TrimSpace(name) == "" {
return true
}
if strings.ContainsAny(name, "/\\") {
return true
}
if filepath.VolumeName(name) != "" {
return true
}
return false
}
// Download single auth file by name
func (h *Handler) DownloadAuthFile(c *gin.Context) {
name := c.Query("name")
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
name := strings.TrimSpace(c.Query("name"))
if isUnsafeAuthFileName(name) {
c.JSON(400, gin.H{"error": "invalid name"})
return
}
@@ -533,36 +596,61 @@ func (h *Handler) UploadAuthFile(c *gin.Context) {
return
}
ctx := c.Request.Context()
if file, err := c.FormFile("file"); err == nil && file != nil {
name := filepath.Base(file.Filename)
if !strings.HasSuffix(strings.ToLower(name), ".json") {
c.JSON(400, gin.H{"error": "file must be .json"})
return
}
dst := filepath.Join(h.cfg.AuthDir, name)
if !filepath.IsAbs(dst) {
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
dst = abs
}
}
if errSave := c.SaveUploadedFile(file, dst); errSave != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to save file: %v", errSave)})
return
}
data, errRead := os.ReadFile(dst)
if errRead != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read saved file: %v", errRead)})
return
}
if errReg := h.registerAuthFromFile(ctx, dst, data); errReg != nil {
c.JSON(500, gin.H{"error": errReg.Error()})
return
}
c.JSON(200, gin.H{"status": "ok"})
fileHeaders, errMultipart := h.multipartAuthFileHeaders(c)
if errMultipart != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid multipart form: %v", errMultipart)})
return
}
name := c.Query("name")
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
if len(fileHeaders) == 1 {
if _, errUpload := h.storeUploadedAuthFile(ctx, fileHeaders[0]); errUpload != nil {
if errors.Is(errUpload, errAuthFileMustBeJSON) {
c.JSON(http.StatusBadRequest, gin.H{"error": "file must be .json"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": errUpload.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok"})
return
}
if len(fileHeaders) > 1 {
uploaded := make([]string, 0, len(fileHeaders))
failed := make([]gin.H, 0)
for _, file := range fileHeaders {
name, errUpload := h.storeUploadedAuthFile(ctx, file)
if errUpload != nil {
failureName := ""
if file != nil {
failureName = filepath.Base(file.Filename)
}
msg := errUpload.Error()
if errors.Is(errUpload, errAuthFileMustBeJSON) {
msg = "file must be .json"
}
failed = append(failed, gin.H{"name": failureName, "error": msg})
continue
}
uploaded = append(uploaded, name)
}
if len(failed) > 0 {
c.JSON(http.StatusMultiStatus, gin.H{
"status": "partial",
"uploaded": len(uploaded),
"files": uploaded,
"failed": failed,
})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok", "uploaded": len(uploaded), "files": uploaded})
return
}
if c.ContentType() == "multipart/form-data" {
c.JSON(http.StatusBadRequest, gin.H{"error": "no files uploaded"})
return
}
name := strings.TrimSpace(c.Query("name"))
if isUnsafeAuthFileName(name) {
c.JSON(400, gin.H{"error": "invalid name"})
return
}
@@ -575,17 +663,7 @@ func (h *Handler) UploadAuthFile(c *gin.Context) {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
if !filepath.IsAbs(dst) {
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
dst = abs
}
}
if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to write file: %v", errWrite)})
return
}
if err = h.registerAuthFromFile(ctx, dst, data); err != nil {
if err = h.writeAuthFile(ctx, filepath.Base(name), data); err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
@@ -632,11 +710,182 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok", "deleted": deleted})
return
}
name := c.Query("name")
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
names, errNames := requestedAuthFileNamesForDelete(c)
if errNames != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": errNames.Error()})
return
}
if len(names) == 0 {
c.JSON(400, gin.H{"error": "invalid name"})
return
}
if len(names) == 1 {
if _, status, errDelete := h.deleteAuthFileByName(ctx, names[0]); errDelete != nil {
c.JSON(status, gin.H{"error": errDelete.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok"})
return
}
deletedFiles := make([]string, 0, len(names))
failed := make([]gin.H, 0)
for _, name := range names {
deletedName, _, errDelete := h.deleteAuthFileByName(ctx, name)
if errDelete != nil {
failed = append(failed, gin.H{"name": name, "error": errDelete.Error()})
continue
}
deletedFiles = append(deletedFiles, deletedName)
}
if len(failed) > 0 {
c.JSON(http.StatusMultiStatus, gin.H{
"status": "partial",
"deleted": len(deletedFiles),
"files": deletedFiles,
"failed": failed,
})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok", "deleted": len(deletedFiles), "files": deletedFiles})
}
func (h *Handler) multipartAuthFileHeaders(c *gin.Context) ([]*multipart.FileHeader, error) {
if h == nil || c == nil || c.ContentType() != "multipart/form-data" {
return nil, nil
}
form, err := c.MultipartForm()
if err != nil {
return nil, err
}
if form == nil || len(form.File) == 0 {
return nil, nil
}
keys := make([]string, 0, len(form.File))
for key := range form.File {
keys = append(keys, key)
}
sort.Strings(keys)
headers := make([]*multipart.FileHeader, 0)
for _, key := range keys {
headers = append(headers, form.File[key]...)
}
return headers, nil
}
func (h *Handler) storeUploadedAuthFile(ctx context.Context, file *multipart.FileHeader) (string, error) {
if file == nil {
return "", fmt.Errorf("no file uploaded")
}
name := filepath.Base(strings.TrimSpace(file.Filename))
if !strings.HasSuffix(strings.ToLower(name), ".json") {
return "", errAuthFileMustBeJSON
}
src, err := file.Open()
if err != nil {
return "", fmt.Errorf("failed to open uploaded file: %w", err)
}
defer src.Close()
data, err := io.ReadAll(src)
if err != nil {
return "", fmt.Errorf("failed to read uploaded file: %w", err)
}
if err := h.writeAuthFile(ctx, name, data); err != nil {
return "", err
}
return name, nil
}
func (h *Handler) writeAuthFile(ctx context.Context, name string, data []byte) error {
dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
if !filepath.IsAbs(dst) {
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
dst = abs
}
}
auth, err := h.buildAuthFromFileData(dst, data)
if err != nil {
return err
}
if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil {
return fmt.Errorf("failed to write file: %w", errWrite)
}
if err := h.upsertAuthRecord(ctx, auth); err != nil {
return err
}
return nil
}
func requestedAuthFileNamesForDelete(c *gin.Context) ([]string, error) {
if c == nil {
return nil, nil
}
names := uniqueAuthFileNames(c.QueryArray("name"))
if len(names) > 0 {
return names, nil
}
body, err := io.ReadAll(c.Request.Body)
if err != nil {
return nil, fmt.Errorf("failed to read body")
}
body = bytes.TrimSpace(body)
if len(body) == 0 {
return nil, nil
}
var objectBody struct {
Name string `json:"name"`
Names []string `json:"names"`
}
if body[0] == '[' {
var arrayBody []string
if err := json.Unmarshal(body, &arrayBody); err != nil {
return nil, fmt.Errorf("invalid request body")
}
return uniqueAuthFileNames(arrayBody), nil
}
if err := json.Unmarshal(body, &objectBody); err != nil {
return nil, fmt.Errorf("invalid request body")
}
out := make([]string, 0, len(objectBody.Names)+1)
if strings.TrimSpace(objectBody.Name) != "" {
out = append(out, objectBody.Name)
}
out = append(out, objectBody.Names...)
return uniqueAuthFileNames(out), nil
}
func uniqueAuthFileNames(names []string) []string {
if len(names) == 0 {
return nil
}
seen := make(map[string]struct{}, len(names))
out := make([]string, 0, len(names))
for _, name := range names {
name = strings.TrimSpace(name)
if name == "" {
continue
}
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
out = append(out, name)
}
return out
}
func (h *Handler) deleteAuthFileByName(ctx context.Context, name string) (string, int, error) {
name = strings.TrimSpace(name)
if isUnsafeAuthFileName(name) {
return "", http.StatusBadRequest, fmt.Errorf("invalid name")
}
targetPath := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
targetID := ""
@@ -653,22 +902,19 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
}
if errRemove := os.Remove(targetPath); errRemove != nil {
if os.IsNotExist(errRemove) {
c.JSON(404, gin.H{"error": "file not found"})
} else {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", errRemove)})
return filepath.Base(name), http.StatusNotFound, errAuthFileNotFound
}
return
return filepath.Base(name), http.StatusInternalServerError, fmt.Errorf("failed to remove file: %w", errRemove)
}
if errDeleteRecord := h.deleteTokenRecord(ctx, targetPath); errDeleteRecord != nil {
c.JSON(500, gin.H{"error": errDeleteRecord.Error()})
return
return filepath.Base(name), http.StatusInternalServerError, errDeleteRecord
}
if targetID != "" {
h.disableAuth(ctx, targetID)
} else {
h.disableAuth(ctx, targetPath)
}
c.JSON(200, gin.H{"status": "ok"})
return filepath.Base(name), http.StatusOK, nil
}
func (h *Handler) findAuthForDelete(name string) *coreauth.Auth {
@@ -702,10 +948,25 @@ func (h *Handler) authIDForPath(path string) string {
if path == "" {
return ""
}
path = filepath.Clean(path)
if !filepath.IsAbs(path) {
if abs, errAbs := filepath.Abs(path); errAbs == nil {
path = abs
}
}
id := path
if h != nil && h.cfg != nil {
authDir := strings.TrimSpace(h.cfg.AuthDir)
if resolvedAuthDir, errResolve := util.ResolveAuthDir(authDir); errResolve == nil && resolvedAuthDir != "" {
authDir = resolvedAuthDir
}
if authDir != "" {
authDir = filepath.Clean(authDir)
if !filepath.IsAbs(authDir) {
if abs, errAbs := filepath.Abs(authDir); errAbs == nil {
authDir = abs
}
}
if rel, errRel := filepath.Rel(authDir, path); errRel == nil && rel != "" {
id = rel
}
@@ -722,19 +983,27 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
if h.authManager == nil {
return nil
}
auth, err := h.buildAuthFromFileData(path, data)
if err != nil {
return err
}
return h.upsertAuthRecord(ctx, auth)
}
func (h *Handler) buildAuthFromFileData(path string, data []byte) (*coreauth.Auth, error) {
if path == "" {
return fmt.Errorf("auth path is empty")
return nil, fmt.Errorf("auth path is empty")
}
if data == nil {
var err error
data, err = os.ReadFile(path)
if err != nil {
return fmt.Errorf("failed to read auth file: %w", err)
return nil, fmt.Errorf("failed to read auth file: %w", err)
}
}
metadata := make(map[string]any)
if err := json.Unmarshal(data, &metadata); err != nil {
return fmt.Errorf("invalid auth file: %w", err)
return nil, fmt.Errorf("invalid auth file: %w", err)
}
provider, _ := metadata["type"].(string)
if provider == "" {
@@ -768,13 +1037,25 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
if hasLastRefresh {
auth.LastRefreshedAt = lastRefresh
}
if existing, ok := h.authManager.GetByID(authID); ok {
auth.CreatedAt = existing.CreatedAt
if !hasLastRefresh {
auth.LastRefreshedAt = existing.LastRefreshedAt
if h != nil && h.authManager != nil {
if existing, ok := h.authManager.GetByID(authID); ok {
auth.CreatedAt = existing.CreatedAt
if !hasLastRefresh {
auth.LastRefreshedAt = existing.LastRefreshedAt
}
auth.NextRefreshAfter = existing.NextRefreshAfter
auth.Runtime = existing.Runtime
}
auth.NextRefreshAfter = existing.NextRefreshAfter
auth.Runtime = existing.Runtime
}
return auth, nil
}
func (h *Handler) upsertAuthRecord(ctx context.Context, auth *coreauth.Auth) error {
if h == nil || h.authManager == nil || auth == nil {
return nil
}
if existing, ok := h.authManager.GetByID(auth.ID); ok {
auth.CreatedAt = existing.CreatedAt
_, err := h.authManager.Update(ctx, auth)
return err
}
@@ -848,7 +1129,7 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
}
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority, note) of an auth file.
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
@@ -860,6 +1141,7 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
Prefix *string `json:"prefix"`
ProxyURL *string `json:"proxy_url"`
Priority *int `json:"priority"`
Note *string `json:"note"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
@@ -902,14 +1184,32 @@ func (h *Handler) PatchAuthFileFields(c *gin.Context) {
targetAuth.ProxyURL = *req.ProxyURL
changed = true
}
if req.Priority != nil {
if req.Priority != nil || req.Note != nil {
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if *req.Priority == 0 {
delete(targetAuth.Metadata, "priority")
} else {
targetAuth.Metadata["priority"] = *req.Priority
if targetAuth.Attributes == nil {
targetAuth.Attributes = make(map[string]string)
}
if req.Priority != nil {
if *req.Priority == 0 {
delete(targetAuth.Metadata, "priority")
delete(targetAuth.Attributes, "priority")
} else {
targetAuth.Metadata["priority"] = *req.Priority
targetAuth.Attributes["priority"] = strconv.Itoa(*req.Priority)
}
}
if req.Note != nil {
trimmedNote := strings.TrimSpace(*req.Note)
if trimmedNote == "" {
delete(targetAuth.Metadata, "note")
delete(targetAuth.Attributes, "note")
} else {
targetAuth.Metadata["note"] = trimmedNote
targetAuth.Attributes["note"] = trimmedNote
}
}
changed = true
}
@@ -1838,9 +2138,6 @@ func (h *Handler) RequestGitLabToken(c *gin.Context) {
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
metadata["auth_kind"] = "oauth"
metadata["oauth_client_id"] = clientID
if clientSecret != "" {
metadata["oauth_client_secret"] = clientSecret
}
metadata["username"] = strings.TrimSpace(user.Username)
if email := primaryGitLabEmail(user); email != "" {
metadata["email"] = email
@@ -3408,3 +3705,84 @@ func (h *Handler) RequestKiloToken(c *gin.Context) {
"verification_uri": resp.VerificationURL,
})
}
// RequestCursorToken initiates the Cursor PKCE authentication flow.
// Supports multiple accounts via ?label=xxx query parameter.
// The user opens the returned URL in a browser, logs in, and the server polls
// until the authentication completes.
func (h *Handler) RequestCursorToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
label := strings.TrimSpace(c.Query("label"))
log.Infof("Initializing Cursor authentication (label=%q)...", label)
authParams, err := cursorauth.GenerateAuthParams()
if err != nil {
log.Errorf("Failed to generate Cursor auth params: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate auth params"})
return
}
state := fmt.Sprintf("cur-%d", time.Now().UnixNano())
RegisterOAuthSession(state, "cursor")
go func() {
log.Info("Waiting for Cursor authentication...")
log.Infof("Open this URL in your browser: %s", authParams.LoginURL)
tokens, errPoll := cursorauth.PollForAuth(ctx, authParams.UUID, authParams.Verifier)
if errPoll != nil {
SetOAuthSessionError(state, "Authentication failed: "+errPoll.Error())
log.Errorf("Cursor authentication failed: %v", errPoll)
return
}
// Build metadata
metadata := map[string]any{
"type": "cursor",
"access_token": tokens.AccessToken,
"refresh_token": tokens.RefreshToken,
"timestamp": time.Now().UnixMilli(),
}
// Extract expiry and account identity from JWT
expiry := cursorauth.GetTokenExpiry(tokens.AccessToken)
if !expiry.IsZero() {
metadata["expires_at"] = expiry.Format(time.RFC3339)
}
// Auto-identify account from JWT sub claim for multi-account support
sub := cursorauth.ParseJWTSub(tokens.AccessToken)
subHash := cursorauth.SubToShortHash(sub)
if sub != "" {
metadata["sub"] = sub
}
fileName := cursorauth.CredentialFileName(label, subHash)
displayLabel := cursorauth.DisplayLabel(label, subHash)
record := &coreauth.Auth{
ID: fileName,
Provider: "cursor",
FileName: fileName,
Label: displayLabel,
Metadata: metadata,
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save Cursor tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save tokens")
return
}
log.Infof("Cursor authentication successful! Token saved to %s", savedPath)
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("cursor")
}()
c.JSON(200, gin.H{
"status": "ok",
"url": authParams.LoginURL,
"state": state,
})
}

View File

@@ -0,0 +1,197 @@
package management
import (
"bytes"
"encoding/json"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestUploadAuthFile_BatchMultipart(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
manager := coreauth.NewManager(nil, nil, nil)
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
files := []struct {
name string
content string
}{
{name: "alpha.json", content: `{"type":"codex","email":"alpha@example.com"}`},
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
}
var body bytes.Buffer
writer := multipart.NewWriter(&body)
for _, file := range files {
part, err := writer.CreateFormFile("file", file.name)
if err != nil {
t.Fatalf("failed to create multipart file: %v", err)
}
if _, err = part.Write([]byte(file.content)); err != nil {
t.Fatalf("failed to write multipart content: %v", err)
}
}
if err := writer.Close(); err != nil {
t.Fatalf("failed to close multipart writer: %v", err)
}
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
req.Header.Set("Content-Type", writer.FormDataContentType())
ctx.Request = req
h.UploadAuthFile(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
var payload map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got, ok := payload["uploaded"].(float64); !ok || int(got) != len(files) {
t.Fatalf("expected uploaded=%d, got %#v", len(files), payload["uploaded"])
}
for _, file := range files {
fullPath := filepath.Join(authDir, file.name)
data, err := os.ReadFile(fullPath)
if err != nil {
t.Fatalf("expected uploaded file %s to exist: %v", file.name, err)
}
if string(data) != file.content {
t.Fatalf("expected file %s content %q, got %q", file.name, file.content, string(data))
}
}
auths := manager.List()
if len(auths) != len(files) {
t.Fatalf("expected %d auth entries, got %d", len(files), len(auths))
}
}
func TestUploadAuthFile_BatchMultipart_InvalidJSONDoesNotOverwriteExistingFile(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
manager := coreauth.NewManager(nil, nil, nil)
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
existingName := "alpha.json"
existingContent := `{"type":"codex","email":"alpha@example.com"}`
if err := os.WriteFile(filepath.Join(authDir, existingName), []byte(existingContent), 0o600); err != nil {
t.Fatalf("failed to seed existing auth file: %v", err)
}
files := []struct {
name string
content string
}{
{name: existingName, content: `{"type":"codex"`},
{name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`},
}
var body bytes.Buffer
writer := multipart.NewWriter(&body)
for _, file := range files {
part, err := writer.CreateFormFile("file", file.name)
if err != nil {
t.Fatalf("failed to create multipart file: %v", err)
}
if _, err = part.Write([]byte(file.content)); err != nil {
t.Fatalf("failed to write multipart content: %v", err)
}
}
if err := writer.Close(); err != nil {
t.Fatalf("failed to close multipart writer: %v", err)
}
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body)
req.Header.Set("Content-Type", writer.FormDataContentType())
ctx.Request = req
h.UploadAuthFile(ctx)
if rec.Code != http.StatusMultiStatus {
t.Fatalf("expected upload status %d, got %d with body %s", http.StatusMultiStatus, rec.Code, rec.Body.String())
}
data, err := os.ReadFile(filepath.Join(authDir, existingName))
if err != nil {
t.Fatalf("expected existing auth file to remain readable: %v", err)
}
if string(data) != existingContent {
t.Fatalf("expected existing auth file to remain %q, got %q", existingContent, string(data))
}
betaData, err := os.ReadFile(filepath.Join(authDir, "beta.json"))
if err != nil {
t.Fatalf("expected valid auth file to be created: %v", err)
}
if string(betaData) != files[1].content {
t.Fatalf("expected beta auth file content %q, got %q", files[1].content, string(betaData))
}
}
func TestDeleteAuthFile_BatchQuery(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
files := []string{"alpha.json", "beta.json"}
for _, name := range files {
if err := os.WriteFile(filepath.Join(authDir, name), []byte(`{"type":"codex"}`), 0o600); err != nil {
t.Fatalf("failed to write auth file %s: %v", name, err)
}
}
manager := coreauth.NewManager(nil, nil, nil)
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
h.tokenStore = &memoryAuthStore{}
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(
http.MethodDelete,
"/v0/management/auth-files?name="+url.QueryEscape(files[0])+"&name="+url.QueryEscape(files[1]),
nil,
)
ctx.Request = req
h.DeleteAuthFile(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
var payload map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if got, ok := payload["deleted"].(float64); !ok || int(got) != len(files) {
t.Fatalf("expected deleted=%d, got %#v", len(files), payload["deleted"])
}
for _, name := range files {
if _, err := os.Stat(filepath.Join(authDir, name)); !os.IsNotExist(err) {
t.Fatalf("expected auth file %s to be removed, stat err: %v", name, err)
}
}
}

View File

@@ -0,0 +1,62 @@
package management
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
func TestDownloadAuthFile_ReturnsFile(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
fileName := "download-user.json"
expected := []byte(`{"type":"codex"}`)
if err := os.WriteFile(filepath.Join(authDir, fileName), expected, 0o600); err != nil {
t.Fatalf("failed to write auth file: %v", err)
}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(fileName), nil)
h.DownloadAuthFile(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected download status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
if got := rec.Body.Bytes(); string(got) != string(expected) {
t.Fatalf("unexpected download content: %q", string(got))
}
}
func TestDownloadAuthFile_RejectsPathSeparators(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, nil)
for _, name := range []string{
"../external/secret.json",
`..\\external\\secret.json`,
"nested/secret.json",
`nested\\secret.json`,
} {
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(name), nil)
h.DownloadAuthFile(ctx)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected %d for name %q, got %d with body %s", http.StatusBadRequest, name, rec.Code, rec.Body.String())
}
}
}

View File

@@ -0,0 +1,51 @@
//go:build windows
package management
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
func TestDownloadAuthFile_PreventsWindowsSlashTraversal(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
tempDir := t.TempDir()
authDir := filepath.Join(tempDir, "auth")
externalDir := filepath.Join(tempDir, "external")
if err := os.MkdirAll(authDir, 0o700); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
if err := os.MkdirAll(externalDir, 0o700); err != nil {
t.Fatalf("failed to create external dir: %v", err)
}
secretName := "secret.json"
secretPath := filepath.Join(externalDir, secretName)
if err := os.WriteFile(secretPath, []byte(`{"secret":true}`), 0o600); err != nil {
t.Fatalf("failed to write external file: %v", err)
}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(
http.MethodGet,
"/v0/management/auth-files/download?name="+url.QueryEscape("../external/"+secretName),
nil,
)
h.DownloadAuthFile(ctx)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected status %d, got %d with body %s", http.StatusBadRequest, rec.Code, rec.Body.String())
}
}

View File

@@ -682,6 +682,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
mgmt.GET("/cursor-auth-url", s.mgmt.RequestCursorToken)
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)

View File

@@ -0,0 +1,335 @@
package codebuddy
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
const (
BaseURL = "https://copilot.tencent.com"
DefaultDomain = "www.codebuddy.cn"
UserAgent = "CLI/2.63.2 CodeBuddy/2.63.2"
codeBuddyStatePath = "/v2/plugin/auth/state"
codeBuddyTokenPath = "/v2/plugin/auth/token"
codeBuddyRefreshPath = "/v2/plugin/auth/token/refresh"
pollInterval = 5 * time.Second
maxPollDuration = 5 * time.Minute
codeLoginPending = 11217
codeSuccess = 0
)
type CodeBuddyAuth struct {
httpClient *http.Client
cfg *config.Config
baseURL string
}
func NewCodeBuddyAuth(cfg *config.Config) *CodeBuddyAuth {
httpClient := &http.Client{Timeout: 30 * time.Second}
if cfg != nil {
httpClient = util.SetProxy(&cfg.SDKConfig, httpClient)
}
return &CodeBuddyAuth{httpClient: httpClient, cfg: cfg, baseURL: BaseURL}
}
// AuthState holds the state and auth URL returned by the auth state API.
type AuthState struct {
State string
AuthURL string
}
// FetchAuthState calls POST /v2/plugin/auth/state?platform=CLI to get the state and login URL.
func (a *CodeBuddyAuth) FetchAuthState(ctx context.Context) (*AuthState, error) {
stateURL := fmt.Sprintf("%s%s?platform=CLI", a.baseURL, codeBuddyStatePath)
body := []byte("{}")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, stateURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("codebuddy: failed to create auth state request: %w", err)
}
requestID := uuid.NewString()
req.Header.Set("Accept", "application/json, text/plain, */*")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Requested-With", "XMLHttpRequest")
req.Header.Set("X-Domain", "copilot.tencent.com")
req.Header.Set("X-No-Authorization", "true")
req.Header.Set("X-No-User-Id", "true")
req.Header.Set("X-No-Enterprise-Id", "true")
req.Header.Set("X-No-Department-Info", "true")
req.Header.Set("X-Product", "SaaS")
req.Header.Set("User-Agent", UserAgent)
req.Header.Set("X-Request-ID", requestID)
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("codebuddy: auth state request failed: %w", err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("codebuddy auth state: close body error: %v", errClose)
}
}()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("codebuddy: failed to read auth state response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("codebuddy: auth state request returned status %d: %s", resp.StatusCode, string(bodyBytes))
}
var result struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data *struct {
State string `json:"state"`
AuthURL string `json:"authUrl"`
} `json:"data"`
}
if err = json.Unmarshal(bodyBytes, &result); err != nil {
return nil, fmt.Errorf("codebuddy: failed to parse auth state response: %w", err)
}
if result.Code != codeSuccess {
return nil, fmt.Errorf("codebuddy: auth state request failed with code %d: %s", result.Code, result.Msg)
}
if result.Data == nil || result.Data.State == "" || result.Data.AuthURL == "" {
return nil, fmt.Errorf("codebuddy: auth state response missing state or authUrl")
}
return &AuthState{
State: result.Data.State,
AuthURL: result.Data.AuthURL,
}, nil
}
type pollResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
RequestID string `json:"requestId"`
Data *struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ExpiresIn int64 `json:"expiresIn"`
TokenType string `json:"tokenType"`
Domain string `json:"domain"`
} `json:"data"`
}
// doPollRequest performs a single polling request, safely reading and closing the response body
func (a *CodeBuddyAuth) doPollRequest(ctx context.Context, pollURL string) ([]byte, int, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pollURL, nil)
if err != nil {
return nil, 0, fmt.Errorf("%w: %v", ErrTokenFetchFailed, err)
}
a.applyPollHeaders(req)
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, 0, err
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("codebuddy poll: close body error: %v", errClose)
}
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, resp.StatusCode, fmt.Errorf("codebuddy poll: failed to read response body: %w", err)
}
return body, resp.StatusCode, nil
}
// PollForToken polls until the user completes browser authorization and returns auth data.
func (a *CodeBuddyAuth) PollForToken(ctx context.Context, state string) (*CodeBuddyTokenStorage, error) {
deadline := time.Now().Add(maxPollDuration)
pollURL := fmt.Sprintf("%s%s?state=%s", a.baseURL, codeBuddyTokenPath, url.QueryEscape(state))
for time.Now().Before(deadline) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(pollInterval):
}
body, statusCode, err := a.doPollRequest(ctx, pollURL)
if err != nil {
log.Debugf("codebuddy poll: request error: %v", err)
continue
}
if statusCode != http.StatusOK {
log.Debugf("codebuddy poll: unexpected status %d", statusCode)
continue
}
var result pollResponse
if err := json.Unmarshal(body, &result); err != nil {
continue
}
switch result.Code {
case codeSuccess:
if result.Data == nil {
return nil, fmt.Errorf("%w: empty data in response", ErrTokenFetchFailed)
}
userID, _ := a.DecodeUserID(result.Data.AccessToken)
return &CodeBuddyTokenStorage{
AccessToken: result.Data.AccessToken,
RefreshToken: result.Data.RefreshToken,
ExpiresIn: result.Data.ExpiresIn,
TokenType: result.Data.TokenType,
Domain: result.Data.Domain,
UserID: userID,
Type: "codebuddy",
}, nil
case codeLoginPending:
// continue polling
default:
// TODO: when the CodeBuddy API error code for user denial is known,
// return ErrAccessDenied here instead of ErrTokenFetchFailed.
return nil, fmt.Errorf("%w: server returned code %d: %s", ErrTokenFetchFailed, result.Code, result.Msg)
}
}
return nil, ErrPollingTimeout
}
// DecodeUserID decodes the sub field from a JWT access token as the user ID.
func (a *CodeBuddyAuth) DecodeUserID(accessToken string) (string, error) {
parts := strings.Split(accessToken, ".")
if len(parts) < 2 {
return "", ErrJWTDecodeFailed
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
}
var claims struct {
Sub string `json:"sub"`
}
if err := json.Unmarshal(payload, &claims); err != nil {
return "", fmt.Errorf("%w: %v", ErrJWTDecodeFailed, err)
}
if claims.Sub == "" {
return "", fmt.Errorf("%w: sub claim is empty", ErrJWTDecodeFailed)
}
return claims.Sub, nil
}
// RefreshToken exchanges a refresh token for a new access token.
// It calls POST /v2/plugin/auth/token/refresh with the required headers.
func (a *CodeBuddyAuth) RefreshToken(ctx context.Context, accessToken, refreshToken, userID, domain string) (*CodeBuddyTokenStorage, error) {
if domain == "" {
domain = DefaultDomain
}
refreshURL := fmt.Sprintf("%s%s", a.baseURL, codeBuddyRefreshPath)
body := []byte("{}")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("codebuddy: failed to create refresh request: %w", err)
}
requestID := strings.ReplaceAll(uuid.New().String(), "-", "")
req.Header.Set("Accept", "application/json, text/plain, */*")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Requested-With", "XMLHttpRequest")
req.Header.Set("X-Domain", domain)
req.Header.Set("X-Refresh-Token", refreshToken)
req.Header.Set("X-Auth-Refresh-Source", "plugin")
req.Header.Set("X-Request-ID", requestID)
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("X-User-Id", userID)
req.Header.Set("X-Product", "SaaS")
req.Header.Set("User-Agent", UserAgent)
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("codebuddy: refresh request failed: %w", err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("codebuddy refresh: close body error: %v", errClose)
}
}()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("codebuddy: failed to read refresh response: %w", err)
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
return nil, fmt.Errorf("codebuddy: refresh token rejected (status %d)", resp.StatusCode)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("codebuddy: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
}
var result struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data *struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ExpiresIn int64 `json:"expiresIn"`
RefreshExpiresIn int64 `json:"refreshExpiresIn"`
TokenType string `json:"tokenType"`
Domain string `json:"domain"`
} `json:"data"`
}
if err = json.Unmarshal(bodyBytes, &result); err != nil {
return nil, fmt.Errorf("codebuddy: failed to parse refresh response: %w", err)
}
if result.Code != codeSuccess {
return nil, fmt.Errorf("codebuddy: refresh failed with code %d: %s", result.Code, result.Msg)
}
if result.Data == nil {
return nil, fmt.Errorf("codebuddy: empty data in refresh response")
}
newUserID, _ := a.DecodeUserID(result.Data.AccessToken)
if newUserID == "" {
newUserID = userID
}
tokenDomain := result.Data.Domain
if tokenDomain == "" {
tokenDomain = domain
}
return &CodeBuddyTokenStorage{
AccessToken: result.Data.AccessToken,
RefreshToken: result.Data.RefreshToken,
ExpiresIn: result.Data.ExpiresIn,
RefreshExpiresIn: result.Data.RefreshExpiresIn,
TokenType: result.Data.TokenType,
Domain: tokenDomain,
UserID: newUserID,
Type: "codebuddy",
}, nil
}
func (a *CodeBuddyAuth) applyPollHeaders(req *http.Request) {
req.Header.Set("Accept", "application/json, text/plain, */*")
req.Header.Set("User-Agent", UserAgent)
req.Header.Set("X-Requested-With", "XMLHttpRequest")
req.Header.Set("X-No-Authorization", "true")
req.Header.Set("X-No-User-Id", "true")
req.Header.Set("X-No-Enterprise-Id", "true")
req.Header.Set("X-No-Department-Info", "true")
req.Header.Set("X-Product", "SaaS")
}

View File

@@ -0,0 +1,285 @@
package codebuddy
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
// newTestAuth creates a CodeBuddyAuth pointing at the given test server.
func newTestAuth(serverURL string) *CodeBuddyAuth {
return &CodeBuddyAuth{
httpClient: http.DefaultClient,
baseURL: serverURL,
}
}
// fakeJWT builds a minimal JWT with the given sub claim for testing.
func fakeJWT(sub string) string {
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`))
payload, _ := json.Marshal(map[string]any{"sub": sub, "iat": 1234567890})
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
return header + "." + encodedPayload + ".sig"
}
// --- FetchAuthState tests ---
func TestFetchAuthState_Success(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST, got %s", r.Method)
}
if got := r.URL.Path; got != codeBuddyStatePath {
t.Errorf("expected path %s, got %s", codeBuddyStatePath, got)
}
if got := r.URL.Query().Get("platform"); got != "CLI" {
t.Errorf("expected platform=CLI, got %s", got)
}
if got := r.Header.Get("User-Agent"); got != UserAgent {
t.Errorf("expected User-Agent %s, got %s", UserAgent, got)
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"state": "test-state-abc",
"authUrl": "https://example.com/login?state=test-state-abc",
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
result, err := auth.FetchAuthState(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.State != "test-state-abc" {
t.Errorf("expected state 'test-state-abc', got '%s'", result.State)
}
if result.AuthURL != "https://example.com/login?state=test-state-abc" {
t.Errorf("unexpected authURL: %s", result.AuthURL)
}
}
func TestFetchAuthState_NonOKStatus(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("internal error"))
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.FetchAuthState(context.Background())
if err == nil {
t.Fatal("expected error for non-200 status")
}
}
func TestFetchAuthState_APIErrorCode(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 10001,
"msg": "rate limited",
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.FetchAuthState(context.Background())
if err == nil {
t.Fatal("expected error for non-zero code")
}
}
func TestFetchAuthState_MissingData(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"state": "",
"authUrl": "",
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.FetchAuthState(context.Background())
if err == nil {
t.Fatal("expected error for empty state/authUrl")
}
}
// --- RefreshToken tests ---
func TestRefreshToken_Success(t *testing.T) {
newAccessToken := fakeJWT("refreshed-user-456")
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST, got %s", r.Method)
}
if got := r.URL.Path; got != codeBuddyRefreshPath {
t.Errorf("expected path %s, got %s", codeBuddyRefreshPath, got)
}
if got := r.Header.Get("X-Refresh-Token"); got != "old-refresh-token" {
t.Errorf("expected X-Refresh-Token 'old-refresh-token', got '%s'", got)
}
if got := r.Header.Get("Authorization"); got != "Bearer old-access-token" {
t.Errorf("expected Authorization 'Bearer old-access-token', got '%s'", got)
}
if got := r.Header.Get("X-User-Id"); got != "user-123" {
t.Errorf("expected X-User-Id 'user-123', got '%s'", got)
}
if got := r.Header.Get("X-Domain"); got != "custom.domain.com" {
t.Errorf("expected X-Domain 'custom.domain.com', got '%s'", got)
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"accessToken": newAccessToken,
"refreshToken": "new-refresh-token",
"expiresIn": 3600,
"refreshExpiresIn": 86400,
"tokenType": "bearer",
"domain": "custom.domain.com",
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
storage, err := auth.RefreshToken(context.Background(), "old-access-token", "old-refresh-token", "user-123", "custom.domain.com")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if storage.AccessToken != newAccessToken {
t.Errorf("expected new access token, got '%s'", storage.AccessToken)
}
if storage.RefreshToken != "new-refresh-token" {
t.Errorf("expected 'new-refresh-token', got '%s'", storage.RefreshToken)
}
if storage.UserID != "refreshed-user-456" {
t.Errorf("expected userID 'refreshed-user-456', got '%s'", storage.UserID)
}
if storage.ExpiresIn != 3600 {
t.Errorf("expected expiresIn 3600, got %d", storage.ExpiresIn)
}
if storage.RefreshExpiresIn != 86400 {
t.Errorf("expected refreshExpiresIn 86400, got %d", storage.RefreshExpiresIn)
}
if storage.Domain != "custom.domain.com" {
t.Errorf("expected domain 'custom.domain.com', got '%s'", storage.Domain)
}
if storage.Type != "codebuddy" {
t.Errorf("expected type 'codebuddy', got '%s'", storage.Type)
}
}
func TestRefreshToken_DefaultDomain(t *testing.T) {
var receivedDomain string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedDomain = r.Header.Get("X-Domain")
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"accessToken": fakeJWT("user-1"),
"refreshToken": "rt",
"expiresIn": 3600,
"tokenType": "bearer",
"domain": DefaultDomain,
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if receivedDomain != DefaultDomain {
t.Errorf("expected default domain '%s', got '%s'", DefaultDomain, receivedDomain)
}
}
func TestRefreshToken_Unauthorized(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
if err == nil {
t.Fatal("expected error for 401 response")
}
}
func TestRefreshToken_Forbidden(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusForbidden)
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
if err == nil {
t.Fatal("expected error for 403 response")
}
}
func TestRefreshToken_APIErrorCode(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 40001,
"msg": "invalid refresh token",
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
_, err := auth.RefreshToken(context.Background(), "at", "rt", "uid", "d")
if err == nil {
t.Fatal("expected error for non-zero API code")
}
}
func TestRefreshToken_FallbackUserIDAndDomain(t *testing.T) {
// When the new access token cannot be decoded for userID, it should fall back to the provided one.
// When the response domain is empty, it should fall back to the request domain.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"code": 0,
"msg": "ok",
"data": map[string]any{
"accessToken": "not-a-valid-jwt",
"refreshToken": "new-rt",
"expiresIn": 7200,
"tokenType": "bearer",
"domain": "",
},
})
}))
defer srv.Close()
auth := newTestAuth(srv.URL)
storage, err := auth.RefreshToken(context.Background(), "at", "rt", "original-uid", "original.domain.com")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if storage.UserID != "original-uid" {
t.Errorf("expected fallback userID 'original-uid', got '%s'", storage.UserID)
}
if storage.Domain != "original.domain.com" {
t.Errorf("expected fallback domain 'original.domain.com', got '%s'", storage.Domain)
}
}

View File

@@ -0,0 +1,22 @@
package codebuddy_test
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
)
func TestDecodeUserID_ValidJWT(t *testing.T) {
// JWT payload: {"sub":"test-user-id-123","iat":1234567890}
// base64url encode: eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ
token := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXItaWQtMTIzIiwiaWF0IjoxMjM0NTY3ODkwfQ.sig"
auth := codebuddy.NewCodeBuddyAuth(nil)
userID, err := auth.DecodeUserID(token)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if userID != "test-user-id-123" {
t.Errorf("expected 'test-user-id-123', got '%s'", userID)
}
}

View File

@@ -0,0 +1,25 @@
package codebuddy
import "errors"
var (
ErrPollingTimeout = errors.New("codebuddy: polling timeout, user did not authorize in time")
ErrAccessDenied = errors.New("codebuddy: access denied by user")
ErrTokenFetchFailed = errors.New("codebuddy: failed to fetch token from server")
ErrJWTDecodeFailed = errors.New("codebuddy: failed to decode JWT token")
)
func GetUserFriendlyMessage(err error) string {
switch {
case errors.Is(err, ErrPollingTimeout):
return "Authentication timed out. Please try again."
case errors.Is(err, ErrAccessDenied):
return "Access denied. Please try again and approve the login request."
case errors.Is(err, ErrJWTDecodeFailed):
return "Failed to decode token. Please try logging in again."
case errors.Is(err, ErrTokenFetchFailed):
return "Failed to fetch token from server. Please try again."
default:
return "Authentication failed: " + err.Error()
}
}

View File

@@ -0,0 +1,65 @@
// Package codebuddy provides authentication and token management functionality
// for CodeBuddy AI services. It handles OAuth2 token storage, serialization,
// and retrieval for maintaining authenticated sessions with the CodeBuddy API.
package codebuddy
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
)
// CodeBuddyTokenStorage stores OAuth token information for CodeBuddy API authentication.
// It maintains compatibility with the existing auth system while adding CodeBuddy-specific fields
// for managing access tokens and user account information.
type CodeBuddyTokenStorage struct {
// AccessToken is the OAuth2 access token used for authenticating API requests.
AccessToken string `json:"access_token"`
// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
RefreshToken string `json:"refresh_token"`
// ExpiresIn is the number of seconds until the access token expires.
ExpiresIn int64 `json:"expires_in"`
// RefreshExpiresIn is the number of seconds until the refresh token expires.
RefreshExpiresIn int64 `json:"refresh_expires_in,omitempty"`
// TokenType is the type of token, typically "bearer".
TokenType string `json:"token_type"`
// Domain is the CodeBuddy service domain/region.
Domain string `json:"domain"`
// UserID is the user ID associated with this token.
UserID string `json:"user_id"`
// Type indicates the authentication provider type, always "codebuddy" for this storage.
Type string `json:"type"`
}
// SaveTokenToFile serializes the CodeBuddy token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
//
// Returns:
// - error: An error if the operation fails, nil otherwise
func (s *CodeBuddyTokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
s.Type = "codebuddy"
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
f, err := os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to create token file: %w", err)
}
defer func() {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(s); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil
}

View File

@@ -0,0 +1,33 @@
package cursor
import (
"fmt"
"strings"
)
// CredentialFileName returns the filename used to persist Cursor credentials.
// Priority: explicit label > auto-generated from JWT sub hash.
// If both label and subHash are empty, falls back to "cursor.json".
func CredentialFileName(label, subHash string) string {
label = strings.TrimSpace(label)
subHash = strings.TrimSpace(subHash)
if label != "" {
return fmt.Sprintf("cursor.%s.json", label)
}
if subHash != "" {
return fmt.Sprintf("cursor.%s.json", subHash)
}
return "cursor.json"
}
// DisplayLabel returns a human-readable label for the Cursor account.
func DisplayLabel(label, subHash string) string {
label = strings.TrimSpace(label)
if label != "" {
return "Cursor " + label
}
if subHash != "" {
return "Cursor " + subHash
}
return "Cursor User"
}

View File

@@ -0,0 +1,249 @@
// Package cursor implements Cursor OAuth PKCE authentication and token refresh.
package cursor
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"strings"
"time"
)
const (
CursorLoginURL = "https://cursor.com/loginDeepControl"
CursorPollURL = "https://api2.cursor.sh/auth/poll"
CursorRefreshURL = "https://api2.cursor.sh/auth/exchange_user_api_key"
pollMaxAttempts = 150
pollBaseDelay = 1 * time.Second
pollMaxDelay = 10 * time.Second
pollBackoffMultiply = 1.2
maxConsecutiveErrors = 10
)
// AuthParams holds the PKCE parameters for Cursor login.
type AuthParams struct {
Verifier string
Challenge string
UUID string
LoginURL string
}
// TokenPair holds the access and refresh tokens from Cursor.
type TokenPair struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
}
// GeneratePKCE creates a PKCE verifier and challenge pair.
func GeneratePKCE() (verifier, challenge string, err error) {
verifierBytes := make([]byte, 96)
if _, err = rand.Read(verifierBytes); err != nil {
return "", "", fmt.Errorf("cursor: failed to generate PKCE verifier: %w", err)
}
verifier = base64.RawURLEncoding.EncodeToString(verifierBytes)
h := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(h[:])
return verifier, challenge, nil
}
// GenerateAuthParams creates the full set of auth params for Cursor login.
func GenerateAuthParams() (*AuthParams, error) {
verifier, challenge, err := GeneratePKCE()
if err != nil {
return nil, err
}
uuidBytes := make([]byte, 16)
if _, err = rand.Read(uuidBytes); err != nil {
return nil, fmt.Errorf("cursor: failed to generate UUID: %w", err)
}
uuid := fmt.Sprintf("%x-%x-%x-%x-%x",
uuidBytes[0:4], uuidBytes[4:6], uuidBytes[6:8], uuidBytes[8:10], uuidBytes[10:16])
loginURL := fmt.Sprintf("%s?challenge=%s&uuid=%s&mode=login&redirectTarget=cli",
CursorLoginURL, challenge, uuid)
return &AuthParams{
Verifier: verifier,
Challenge: challenge,
UUID: uuid,
LoginURL: loginURL,
}, nil
}
// PollForAuth polls the Cursor auth endpoint until the user completes login.
func PollForAuth(ctx context.Context, uuid, verifier string) (*TokenPair, error) {
delay := pollBaseDelay
consecutiveErrors := 0
client := &http.Client{Timeout: 10 * time.Second}
for attempt := 0; attempt < pollMaxAttempts; attempt++ {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(delay):
}
url := fmt.Sprintf("%s?uuid=%s&verifier=%s", CursorPollURL, uuid, verifier)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("cursor: failed to create poll request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
consecutiveErrors++
if consecutiveErrors >= maxConsecutiveErrors {
return nil, fmt.Errorf("cursor: too many consecutive poll errors (last: %v)", err)
}
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
continue
}
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// Still waiting for user to authorize
consecutiveErrors = 0
delay = minDuration(time.Duration(float64(delay)*pollBackoffMultiply), pollMaxDelay)
continue
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
var tokens TokenPair
if err := json.Unmarshal(body, &tokens); err != nil {
return nil, fmt.Errorf("cursor: failed to parse auth response: %w", err)
}
return &tokens, nil
}
return nil, fmt.Errorf("cursor: poll failed with status %d: %s", resp.StatusCode, string(body))
}
return nil, fmt.Errorf("cursor: authentication polling timeout (waited ~%.0f seconds)",
float64(pollMaxAttempts)*pollMaxDelay.Seconds()/2)
}
// RefreshToken refreshes a Cursor access token using the refresh token.
func RefreshToken(ctx context.Context, refreshToken string) (*TokenPair, error) {
client := &http.Client{Timeout: 10 * time.Second}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, CursorRefreshURL,
strings.NewReader("{}"))
if err != nil {
return nil, fmt.Errorf("cursor: failed to create refresh request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+refreshToken)
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("cursor: token refresh request failed: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("cursor: token refresh failed (status %d): %s", resp.StatusCode, string(body))
}
var tokens TokenPair
if err := json.Unmarshal(body, &tokens); err != nil {
return nil, fmt.Errorf("cursor: failed to parse refresh response: %w", err)
}
// Keep original refresh token if not returned
if tokens.RefreshToken == "" {
tokens.RefreshToken = refreshToken
}
return &tokens, nil
}
// ParseJWTSub extracts the "sub" claim from a Cursor JWT access token.
// Cursor JWTs contain "sub" like "auth0|user_XXXX" which uniquely identifies
// the account. Returns empty string if parsing fails.
func ParseJWTSub(token string) string {
decoded := decodeJWTPayload(token)
if decoded == nil {
return ""
}
var claims struct {
Sub string `json:"sub"`
}
if err := json.Unmarshal(decoded, &claims); err != nil {
return ""
}
return claims.Sub
}
// SubToShortHash converts a JWT sub claim to a short hex hash for use in filenames.
// e.g. "auth0|user_2x..." → "a3f8b2c1"
func SubToShortHash(sub string) string {
if sub == "" {
return ""
}
h := sha256.Sum256([]byte(sub))
return fmt.Sprintf("%x", h[:4]) // 8 hex chars
}
// decodeJWTPayload decodes the payload (middle) part of a JWT.
func decodeJWTPayload(token string) []byte {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil
}
payload := parts[1]
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
payload = strings.ReplaceAll(payload, "-", "+")
payload = strings.ReplaceAll(payload, "_", "/")
decoded, err := base64.StdEncoding.DecodeString(payload)
if err != nil {
return nil
}
return decoded
}
// GetTokenExpiry extracts the JWT expiry from an access token with a 5-minute safety margin.
// Falls back to 1 hour from now if the token can't be parsed.
func GetTokenExpiry(token string) time.Time {
decoded := decodeJWTPayload(token)
if decoded == nil {
return time.Now().Add(1 * time.Hour)
}
var claims struct {
Exp float64 `json:"exp"`
}
if err := json.Unmarshal(decoded, &claims); err != nil || claims.Exp == 0 {
return time.Now().Add(1 * time.Hour)
}
sec, frac := math.Modf(claims.Exp)
expiry := time.Unix(int64(sec), int64(frac*1e9))
// Subtract 5-minute safety margin
return expiry.Add(-5 * time.Minute)
}
func minDuration(a, b time.Duration) time.Duration {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,84 @@
package proto
import (
"encoding/binary"
"encoding/json"
"fmt"
)
const (
// ConnectEndStreamFlag marks the end-of-stream frame (trailers).
ConnectEndStreamFlag byte = 0x02
// ConnectCompressionFlag indicates the payload is compressed (not supported).
ConnectCompressionFlag byte = 0x01
// ConnectFrameHeaderSize is the fixed 5-byte frame header.
ConnectFrameHeaderSize = 5
)
// FrameConnectMessage wraps a protobuf payload in a Connect frame.
// Frame format: [1 byte flags][4 bytes payload length (big-endian)][payload]
func FrameConnectMessage(data []byte, flags byte) []byte {
frame := make([]byte, ConnectFrameHeaderSize+len(data))
frame[0] = flags
binary.BigEndian.PutUint32(frame[1:5], uint32(len(data)))
copy(frame[5:], data)
return frame
}
// ParseConnectFrame extracts one frame from a buffer.
// Returns (flags, payload, bytesConsumed, ok).
// ok is false when the buffer is too short for a complete frame.
func ParseConnectFrame(buf []byte) (flags byte, payload []byte, consumed int, ok bool) {
if len(buf) < ConnectFrameHeaderSize {
return 0, nil, 0, false
}
flags = buf[0]
length := binary.BigEndian.Uint32(buf[1:5])
total := ConnectFrameHeaderSize + int(length)
if len(buf) < total {
return 0, nil, 0, false
}
return flags, buf[5:total], total, true
}
// ConnectError is a structured error from the Connect protocol end-of-stream trailer.
// The Code field contains the server-defined error code (e.g. gRPC standard codes
// like "resource_exhausted", "unauthenticated", "permission_denied", "unavailable").
type ConnectError struct {
Code string // server-defined error code
Message string // human-readable error description
}
func (e *ConnectError) Error() string {
return fmt.Sprintf("Connect error %s: %s", e.Code, e.Message)
}
// ParseConnectEndStream parses a Connect end-of-stream frame payload (JSON).
// Returns nil if there is no error in the trailer.
// On error, returns a *ConnectError with the server's error code and message.
func ParseConnectEndStream(data []byte) error {
if len(data) == 0 {
return nil
}
var trailer struct {
Error *struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
if err := json.Unmarshal(data, &trailer); err != nil {
return fmt.Errorf("failed to parse Connect end stream: %w", err)
}
if trailer.Error != nil {
code := trailer.Error.Code
if code == "" {
code = "unknown"
}
msg := trailer.Error.Message
if msg == "" {
msg = "Unknown error"
}
return &ConnectError{Code: code, Message: msg}
}
return nil
}

View File

@@ -0,0 +1,564 @@
package proto
import (
"encoding/hex"
"fmt"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protowire"
)
// ServerMessageType identifies the kind of decoded server message.
type ServerMessageType int
const (
ServerMsgUnknown ServerMessageType = iota
ServerMsgTextDelta // Text content delta
ServerMsgThinkingDelta // Thinking/reasoning delta
ServerMsgThinkingCompleted // Thinking completed
ServerMsgKvGetBlob // Server wants a blob
ServerMsgKvSetBlob // Server wants to store a blob
ServerMsgExecRequestCtx // Server requests context (tools, etc.)
ServerMsgExecMcpArgs // Server wants MCP tool execution
ServerMsgExecShellArgs // Rejected: shell command
ServerMsgExecReadArgs // Rejected: file read
ServerMsgExecWriteArgs // Rejected: file write
ServerMsgExecDeleteArgs // Rejected: file delete
ServerMsgExecLsArgs // Rejected: directory listing
ServerMsgExecGrepArgs // Rejected: grep search
ServerMsgExecFetchArgs // Rejected: HTTP fetch
ServerMsgExecDiagnostics // Respond with empty diagnostics
ServerMsgExecShellStream // Rejected: shell stream
ServerMsgExecBgShellSpawn // Rejected: background shell
ServerMsgExecWriteShellStdin // Rejected: write shell stdin
ServerMsgExecOther // Other exec types (respond with empty)
ServerMsgTurnEnded // Turn has ended (no more output)
ServerMsgHeartbeat // Server heartbeat
ServerMsgTokenDelta // Token usage delta
ServerMsgCheckpoint // Conversation checkpoint update
)
// DecodedServerMessage holds parsed data from an AgentServerMessage.
type DecodedServerMessage struct {
Type ServerMessageType
// For text/thinking deltas
Text string
// For KV messages
KvId uint32
BlobId []byte // hex-encoded blob ID
BlobData []byte // for setBlobArgs
// For exec messages
ExecMsgId uint32
ExecId string
// For MCP args
McpToolName string
McpToolCallId string
McpArgs map[string][]byte // arg name -> protobuf-encoded value
// For rejection context
Path string
Command string
WorkingDirectory string
Url string
// For other exec - the raw field number for building a response
ExecFieldNumber int
// For TokenDeltaUpdate
TokenDelta int64
// For conversation checkpoint update (raw bytes, not decoded)
CheckpointData []byte
}
// DecodeAgentServerMessage parses an AgentServerMessage and returns
// a structured representation of the first meaningful message found.
func DecodeAgentServerMessage(data []byte) (*DecodedServerMessage, error) {
msg := &DecodedServerMessage{Type: ServerMsgUnknown}
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return msg, fmt.Errorf("invalid tag")
}
data = data[n:]
switch typ {
case protowire.BytesType:
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return msg, fmt.Errorf("invalid bytes field %d", num)
}
data = data[n:]
// Debug: log top-level ASM fields
log.Debugf("DecodeAgentServerMessage: found ASM field %d, len=%d", num, len(val))
switch num {
case ASM_InteractionUpdate:
log.Debugf("DecodeAgentServerMessage: calling decodeInteractionUpdate")
decodeInteractionUpdate(val, msg)
case ASM_ExecServerMessage:
log.Debugf("DecodeAgentServerMessage: calling decodeExecServerMessage")
decodeExecServerMessage(val, msg)
case ASM_KvServerMessage:
decodeKvServerMessage(val, msg)
case ASM_ConversationCheckpoint:
msg.Type = ServerMsgCheckpoint
msg.CheckpointData = append([]byte(nil), val...) // copy raw bytes
log.Debugf("DecodeAgentServerMessage: captured checkpoint %d bytes", len(val))
}
case protowire.VarintType:
_, n := protowire.ConsumeVarint(data)
if n < 0 {
return msg, fmt.Errorf("invalid varint field %d", num)
}
data = data[n:]
default:
// Skip unknown wire types
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return msg, fmt.Errorf("invalid field %d", num)
}
data = data[n:]
}
}
return msg, nil
}
func decodeInteractionUpdate(data []byte, msg *DecodedServerMessage) {
log.Debugf("decodeInteractionUpdate: input len=%d, hex=%x", len(data), data)
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
log.Debugf("decodeInteractionUpdate: invalid tag, remaining=%x", data)
return
}
data = data[n:]
log.Debugf("decodeInteractionUpdate: field=%d wire=%d remaining=%d bytes", num, typ, len(data))
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
log.Debugf("decodeInteractionUpdate: invalid bytes field %d", num)
return
}
data = data[n:]
log.Debugf("decodeInteractionUpdate: field %d content len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
switch num {
case IU_TextDelta:
msg.Type = ServerMsgTextDelta
msg.Text = decodeStringField(val, TDU_Text)
log.Debugf("decodeInteractionUpdate: TextDelta text=%q", msg.Text)
case IU_ThinkingDelta:
msg.Type = ServerMsgThinkingDelta
msg.Text = decodeStringField(val, TKD_Text)
log.Debugf("decodeInteractionUpdate: ThinkingDelta text=%q", msg.Text)
case IU_ThinkingCompleted:
msg.Type = ServerMsgThinkingCompleted
log.Debugf("decodeInteractionUpdate: ThinkingCompleted")
case 2:
// tool_call_started - ignore but log
log.Debugf("decodeInteractionUpdate: ToolCallStarted (ignored)")
case 3:
// tool_call_completed - ignore but log
log.Debugf("decodeInteractionUpdate: ToolCallCompleted (ignored)")
case 8:
// token_delta - extract token count
msg.Type = ServerMsgTokenDelta
msg.TokenDelta = decodeVarintField(val, 1)
log.Debugf("decodeInteractionUpdate: TokenDeltaUpdate tokens=%d", msg.TokenDelta)
case 13:
// heartbeat from server
msg.Type = ServerMsgHeartbeat
case 14:
// turn_ended - critical: model finished generating
msg.Type = ServerMsgTurnEnded
log.Debugf("decodeInteractionUpdate: TurnEndedUpdate - stream should end")
case 16:
// step_started - ignore
log.Debugf("decodeInteractionUpdate: StepStartedUpdate (ignored)")
case 17:
// step_completed - ignore
log.Debugf("decodeInteractionUpdate: StepCompletedUpdate (ignored)")
default:
log.Debugf("decodeInteractionUpdate: unknown field %d", num)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeKvServerMessage(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
switch typ {
case protowire.VarintType:
val, n := protowire.ConsumeVarint(data)
if n < 0 {
return
}
data = data[n:]
if num == KSM_Id {
msg.KvId = uint32(val)
}
case protowire.BytesType:
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case KSM_GetBlobArgs:
msg.Type = ServerMsgKvGetBlob
msg.BlobId = decodeBytesField(val, GBA_BlobId)
case KSM_SetBlobArgs:
msg.Type = ServerMsgKvSetBlob
decodeSetBlobArgs(val, msg)
}
default:
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeSetBlobArgs(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case SBA_BlobId:
msg.BlobId = val
case SBA_BlobData:
msg.BlobData = val
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeExecServerMessage(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
switch typ {
case protowire.VarintType:
val, n := protowire.ConsumeVarint(data)
if n < 0 {
return
}
data = data[n:]
if num == ESM_Id {
msg.ExecMsgId = uint32(val)
log.Debugf("decodeExecServerMessage: ESM_Id = %d", val)
}
case protowire.BytesType:
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
// Debug: log all fields found in ExecServerMessage
log.Debugf("decodeExecServerMessage: found field %d, len=%d, first 20 bytes: %x", num, len(val), val[:min(20, len(val))])
switch num {
case ESM_ExecId:
msg.ExecId = string(val)
log.Debugf("decodeExecServerMessage: ESM_ExecId = %q", msg.ExecId)
case ESM_RequestContextArgs:
msg.Type = ServerMsgExecRequestCtx
case ESM_McpArgs:
msg.Type = ServerMsgExecMcpArgs
decodeMcpArgs(val, msg)
case ESM_ShellArgs:
msg.Type = ServerMsgExecShellArgs
decodeShellArgs(val, msg)
case ESM_ShellStreamArgs:
msg.Type = ServerMsgExecShellStream
decodeShellArgs(val, msg)
case ESM_ReadArgs:
msg.Type = ServerMsgExecReadArgs
msg.Path = decodeStringField(val, RA_Path)
case ESM_WriteArgs:
msg.Type = ServerMsgExecWriteArgs
msg.Path = decodeStringField(val, WA_Path)
case ESM_DeleteArgs:
msg.Type = ServerMsgExecDeleteArgs
msg.Path = decodeStringField(val, DA_Path)
case ESM_LsArgs:
msg.Type = ServerMsgExecLsArgs
msg.Path = decodeStringField(val, LA_Path)
case ESM_GrepArgs:
msg.Type = ServerMsgExecGrepArgs
case ESM_FetchArgs:
msg.Type = ServerMsgExecFetchArgs
msg.Url = decodeStringField(val, FA_Url)
case ESM_DiagnosticsArgs:
msg.Type = ServerMsgExecDiagnostics
case ESM_BackgroundShellSpawn:
msg.Type = ServerMsgExecBgShellSpawn
decodeShellArgs(val, msg) // same structure
case ESM_WriteShellStdinArgs:
msg.Type = ServerMsgExecWriteShellStdin
default:
// Unknown exec types - only set if we haven't identified the type yet
// (other fields like span_context (19) come after the exec type field)
if msg.Type == ServerMsgUnknown {
msg.Type = ServerMsgExecOther
msg.ExecFieldNumber = int(num)
}
}
default:
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeMcpArgs(data []byte, msg *DecodedServerMessage) {
msg.McpArgs = make(map[string][]byte)
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case MCA_Name:
msg.McpToolName = string(val)
case MCA_Args:
// Map entries are encoded as submessages with key=1, value=2
decodeMapEntry(val, msg.McpArgs)
case MCA_ToolCallId:
msg.McpToolCallId = string(val)
case MCA_ToolName:
// ToolName takes precedence if present
if msg.McpToolName == "" || string(val) != "" {
msg.McpToolName = string(val)
}
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
func decodeMapEntry(data []byte, m map[string][]byte) {
var key string
var value []byte
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
if num == 1 {
key = string(val)
} else if num == 2 {
value = append([]byte(nil), val...)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
if key != "" {
m[key] = value
}
}
func decodeShellArgs(data []byte, msg *DecodedServerMessage) {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return
}
data = data[n:]
switch num {
case SHA_Command:
msg.Command = string(val)
case SHA_WorkingDirectory:
msg.WorkingDirectory = string(val)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return
}
data = data[n:]
}
}
}
// --- Helper decoders ---
// decodeStringField extracts a string from the first matching field in a submessage.
func decodeStringField(data []byte, targetField protowire.Number) string {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return ""
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return ""
}
data = data[n:]
if num == targetField {
return string(val)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return ""
}
data = data[n:]
}
}
return ""
}
// decodeBytesField extracts bytes from the first matching field in a submessage.
func decodeBytesField(data []byte, targetField protowire.Number) []byte {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return nil
}
data = data[n:]
if typ == protowire.BytesType {
val, n := protowire.ConsumeBytes(data)
if n < 0 {
return nil
}
data = data[n:]
if num == targetField {
return append([]byte(nil), val...)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return nil
}
data = data[n:]
}
}
return nil
}
// decodeVarintField extracts an int64 from the first matching varint field in a submessage.
func decodeVarintField(data []byte, targetField protowire.Number) int64 {
for len(data) > 0 {
num, typ, n := protowire.ConsumeTag(data)
if n < 0 {
return 0
}
data = data[n:]
if typ == protowire.VarintType {
val, n := protowire.ConsumeVarint(data)
if n < 0 {
return 0
}
data = data[n:]
if num == targetField {
return int64(val)
}
} else {
n := protowire.ConsumeFieldValue(num, typ, data)
if n < 0 {
return 0
}
data = data[n:]
}
}
return 0
}
// BlobIdHex returns the hex string of a blob ID for use as a map key.
func BlobIdHex(blobId []byte) string {
return hex.EncodeToString(blobId)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,664 @@
// Package proto provides protobuf encoding for Cursor's gRPC API,
// using dynamicpb with the embedded FileDescriptorProto from agent.proto.
// This mirrors the cursor-auth TS plugin's use of @bufbuild/protobuf create()+toBinary().
package proto
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/dynamicpb"
"google.golang.org/protobuf/types/known/structpb"
)
// --- Public types ---
// RunRequestParams holds all data needed to build an AgentRunRequest.
type RunRequestParams struct {
ModelId string
SystemPrompt string
UserText string
MessageId string
ConversationId string
Images []ImageData
Turns []TurnData
McpTools []McpToolDef
BlobStore map[string][]byte // hex(sha256) -> data, populated during encoding
RawCheckpoint []byte // if non-nil, use as conversation_state directly (from server checkpoint)
}
type ImageData struct {
MimeType string
Data []byte
}
type TurnData struct {
UserText string
AssistantText string
}
type McpToolDef struct {
Name string
Description string
InputSchema json.RawMessage
}
// --- Helper: create a dynamic message and set fields ---
func newMsg(name string) *dynamicpb.Message {
return dynamicpb.NewMessage(Msg(name))
}
func field(msg *dynamicpb.Message, name string) protoreflect.FieldDescriptor {
return msg.Descriptor().Fields().ByName(protoreflect.Name(name))
}
func setStr(msg *dynamicpb.Message, name, val string) {
if val != "" {
msg.Set(field(msg, name), protoreflect.ValueOfString(val))
}
}
func setBytes(msg *dynamicpb.Message, name string, val []byte) {
if len(val) > 0 {
msg.Set(field(msg, name), protoreflect.ValueOfBytes(val))
}
}
func setUint32(msg *dynamicpb.Message, name string, val uint32) {
msg.Set(field(msg, name), protoreflect.ValueOfUint32(val))
}
func setBool(msg *dynamicpb.Message, name string, val bool) {
msg.Set(field(msg, name), protoreflect.ValueOfBool(val))
}
func setMsg(msg *dynamicpb.Message, name string, sub *dynamicpb.Message) {
msg.Set(field(msg, name), protoreflect.ValueOfMessage(sub.ProtoReflect()))
}
func marshal(msg *dynamicpb.Message) []byte {
b, err := proto.Marshal(msg)
if err != nil {
panic("cursor proto marshal: " + err.Error())
}
return b
}
// --- Encode functions mirroring cursor-fetch.ts ---
// EncodeHeartbeat returns an encoded AgentClientMessage with clientHeartbeat.
// Mirrors: create(AgentClientMessageSchema, { message: { case: 'clientHeartbeat', value: create(ClientHeartbeatSchema, {}) } })
func EncodeHeartbeat() []byte {
hb := newMsg("ClientHeartbeat")
acm := newMsg("AgentClientMessage")
setMsg(acm, "client_heartbeat", hb)
return marshal(acm)
}
// EncodeRunRequest builds a full AgentClientMessage wrapping an AgentRunRequest.
// Mirrors buildCursorRequest() in cursor-fetch.ts.
// If p.RawCheckpoint is set, it is used directly as the conversation_state bytes
// (from a previous conversation_checkpoint_update), skipping manual turn construction.
func EncodeRunRequest(p *RunRequestParams) []byte {
if p.RawCheckpoint != nil {
return encodeRunRequestWithCheckpoint(p)
}
if p.BlobStore == nil {
p.BlobStore = make(map[string][]byte)
}
// --- Conversation turns ---
// Each turn is serialized as bytes (ConversationTurnStructure → bytes)
var turnBytes [][]byte
for _, turn := range p.Turns {
// UserMessage for this turn
um := newMsg("UserMessage")
setStr(um, "text", turn.UserText)
setStr(um, "message_id", generateId())
umBytes := marshal(um)
// Steps (assistant response)
var stepBytes [][]byte
if turn.AssistantText != "" {
am := newMsg("AssistantMessage")
setStr(am, "text", turn.AssistantText)
step := newMsg("ConversationStep")
setMsg(step, "assistant_message", am)
stepBytes = append(stepBytes, marshal(step))
}
// AgentConversationTurnStructure (fields are bytes, not submessages)
agentTurn := newMsg("AgentConversationTurnStructure")
setBytes(agentTurn, "user_message", umBytes)
for _, sb := range stepBytes {
stepsField := field(agentTurn, "steps")
list := agentTurn.Mutable(stepsField).List()
list.Append(protoreflect.ValueOfBytes(sb))
}
// ConversationTurnStructure (oneof turn → agentConversationTurn)
cts := newMsg("ConversationTurnStructure")
setMsg(cts, "agent_conversation_turn", agentTurn)
turnBytes = append(turnBytes, marshal(cts))
}
// --- System prompt blob ---
systemJSON, _ := json.Marshal(map[string]string{"role": "system", "content": p.SystemPrompt})
blobId := sha256Sum(systemJSON)
p.BlobStore[hex.EncodeToString(blobId)] = systemJSON
// --- ConversationStateStructure ---
css := newMsg("ConversationStateStructure")
// rootPromptMessagesJson: repeated bytes
rootField := field(css, "root_prompt_messages_json")
rootList := css.Mutable(rootField).List()
rootList.Append(protoreflect.ValueOfBytes(blobId))
// turns: repeated bytes (field 8) + turns_old (field 2) for compatibility
turnsField := field(css, "turns")
turnsList := css.Mutable(turnsField).List()
for _, tb := range turnBytes {
turnsList.Append(protoreflect.ValueOfBytes(tb))
}
turnsOldField := field(css, "turns_old")
if turnsOldField != nil {
turnsOldList := css.Mutable(turnsOldField).List()
for _, tb := range turnBytes {
turnsOldList.Append(protoreflect.ValueOfBytes(tb))
}
}
// --- UserMessage (current) ---
userMessage := newMsg("UserMessage")
setStr(userMessage, "text", p.UserText)
setStr(userMessage, "message_id", p.MessageId)
// Images via SelectedContext
if len(p.Images) > 0 {
sc := newMsg("SelectedContext")
imgsField := field(sc, "selected_images")
imgsList := sc.Mutable(imgsField).List()
for _, img := range p.Images {
si := newMsg("SelectedImage")
setStr(si, "uuid", generateId())
setStr(si, "mime_type", img.MimeType)
setBytes(si, "data", img.Data)
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
}
setMsg(userMessage, "selected_context", sc)
}
// --- UserMessageAction ---
uma := newMsg("UserMessageAction")
setMsg(uma, "user_message", userMessage)
// --- ConversationAction ---
ca := newMsg("ConversationAction")
setMsg(ca, "user_message_action", uma)
// --- ModelDetails ---
md := newMsg("ModelDetails")
setStr(md, "model_id", p.ModelId)
setStr(md, "display_model_id", p.ModelId)
setStr(md, "display_name", p.ModelId)
// --- AgentRunRequest ---
arr := newMsg("AgentRunRequest")
setMsg(arr, "conversation_state", css)
setMsg(arr, "action", ca)
setMsg(arr, "model_details", md)
setStr(arr, "conversation_id", p.ConversationId)
// McpTools
if len(p.McpTools) > 0 {
mcpTools := newMsg("McpTools")
toolsField := field(mcpTools, "mcp_tools")
toolsList := mcpTools.Mutable(toolsField).List()
for _, tool := range p.McpTools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
setMsg(arr, "mcp_tools", mcpTools)
}
// --- AgentClientMessage ---
acm := newMsg("AgentClientMessage")
setMsg(acm, "run_request", arr)
return marshal(acm)
}
// encodeRunRequestWithCheckpoint builds an AgentClientMessage using a raw checkpoint
// as conversation_state. The checkpoint bytes are embedded directly without deserialization.
func encodeRunRequestWithCheckpoint(p *RunRequestParams) []byte {
// Build UserMessage
userMessage := newMsg("UserMessage")
setStr(userMessage, "text", p.UserText)
setStr(userMessage, "message_id", p.MessageId)
if len(p.Images) > 0 {
sc := newMsg("SelectedContext")
imgsField := field(sc, "selected_images")
imgsList := sc.Mutable(imgsField).List()
for _, img := range p.Images {
si := newMsg("SelectedImage")
setStr(si, "uuid", generateId())
setStr(si, "mime_type", img.MimeType)
setBytes(si, "data", img.Data)
imgsList.Append(protoreflect.ValueOfMessage(si.ProtoReflect()))
}
setMsg(userMessage, "selected_context", sc)
}
// Build ConversationAction with UserMessageAction
uma := newMsg("UserMessageAction")
setMsg(uma, "user_message", userMessage)
ca := newMsg("ConversationAction")
setMsg(ca, "user_message_action", uma)
caBytes := marshal(ca)
// Build ModelDetails
md := newMsg("ModelDetails")
setStr(md, "model_id", p.ModelId)
setStr(md, "display_model_id", p.ModelId)
setStr(md, "display_name", p.ModelId)
mdBytes := marshal(md)
// Build McpTools
var mcpToolsBytes []byte
if len(p.McpTools) > 0 {
mcpTools := newMsg("McpTools")
toolsField := field(mcpTools, "mcp_tools")
toolsList := mcpTools.Mutable(toolsField).List()
for _, tool := range p.McpTools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
mcpToolsBytes = marshal(mcpTools)
}
// Manually assemble AgentRunRequest using protowire to embed raw checkpoint
var arrBuf []byte
// field 1: conversation_state = raw checkpoint bytes (length-delimited)
arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationState, protowire.BytesType)
arrBuf = protowire.AppendBytes(arrBuf, p.RawCheckpoint)
// field 2: action = ConversationAction
arrBuf = protowire.AppendTag(arrBuf, ARR_Action, protowire.BytesType)
arrBuf = protowire.AppendBytes(arrBuf, caBytes)
// field 3: model_details = ModelDetails
arrBuf = protowire.AppendTag(arrBuf, ARR_ModelDetails, protowire.BytesType)
arrBuf = protowire.AppendBytes(arrBuf, mdBytes)
// field 4: mcp_tools = McpTools
if len(mcpToolsBytes) > 0 {
arrBuf = protowire.AppendTag(arrBuf, ARR_McpTools, protowire.BytesType)
arrBuf = protowire.AppendBytes(arrBuf, mcpToolsBytes)
}
// field 5: conversation_id = string
if p.ConversationId != "" {
arrBuf = protowire.AppendTag(arrBuf, ARR_ConversationId, protowire.BytesType)
arrBuf = protowire.AppendString(arrBuf, p.ConversationId)
}
// Wrap in AgentClientMessage field 1 (run_request)
var acmBuf []byte
acmBuf = protowire.AppendTag(acmBuf, ACM_RunRequest, protowire.BytesType)
acmBuf = protowire.AppendBytes(acmBuf, arrBuf)
log.Debugf("cursor encode: built RunRequest with checkpoint (%d bytes), total=%d bytes", len(p.RawCheckpoint), len(acmBuf))
return acmBuf
}
// ResumeRequestParams holds data for a ResumeAction request.
type ResumeRequestParams struct {
ModelId string
ConversationId string
McpTools []McpToolDef
}
// EncodeResumeRequest builds an AgentClientMessage with ResumeAction.
// Used to resume a conversation by conversation_id without re-sending full history.
func EncodeResumeRequest(p *ResumeRequestParams) []byte {
// RequestContext with tools
rc := newMsg("RequestContext")
if len(p.McpTools) > 0 {
toolsField := field(rc, "tools")
toolsList := rc.Mutable(toolsField).List()
for _, tool := range p.McpTools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
}
// ResumeAction
ra := newMsg("ResumeAction")
setMsg(ra, "request_context", rc)
// ConversationAction with resume_action
ca := newMsg("ConversationAction")
setMsg(ca, "resume_action", ra)
// ModelDetails
md := newMsg("ModelDetails")
setStr(md, "model_id", p.ModelId)
setStr(md, "display_model_id", p.ModelId)
setStr(md, "display_name", p.ModelId)
// AgentRunRequest — no conversation_state needed for resume
arr := newMsg("AgentRunRequest")
setMsg(arr, "action", ca)
setMsg(arr, "model_details", md)
setStr(arr, "conversation_id", p.ConversationId)
// McpTools at top level
if len(p.McpTools) > 0 {
mcpTools := newMsg("McpTools")
toolsField := field(mcpTools, "mcp_tools")
toolsList := mcpTools.Mutable(toolsField).List()
for _, tool := range p.McpTools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
setMsg(arr, "mcp_tools", mcpTools)
}
acm := newMsg("AgentClientMessage")
setMsg(acm, "run_request", arr)
return marshal(acm)
}
// --- KV response encoders ---
// Mirrors handleKvMessage() in cursor-fetch.ts
// EncodeKvGetBlobResult responds to a getBlobArgs request.
func EncodeKvGetBlobResult(kvId uint32, blobData []byte) []byte {
result := newMsg("GetBlobResult")
if blobData != nil {
setBytes(result, "blob_data", blobData)
}
kvc := newMsg("KvClientMessage")
setUint32(kvc, "id", kvId)
setMsg(kvc, "get_blob_result", result)
acm := newMsg("AgentClientMessage")
setMsg(acm, "kv_client_message", kvc)
return marshal(acm)
}
// EncodeKvSetBlobResult responds to a setBlobArgs request.
func EncodeKvSetBlobResult(kvId uint32) []byte {
result := newMsg("SetBlobResult")
kvc := newMsg("KvClientMessage")
setUint32(kvc, "id", kvId)
setMsg(kvc, "set_blob_result", result)
acm := newMsg("AgentClientMessage")
setMsg(acm, "kv_client_message", kvc)
return marshal(acm)
}
// --- Exec response encoders ---
// Mirrors handleExecMessage() and sendExec() in cursor-fetch.ts
// EncodeExecRequestContextResult responds to requestContextArgs with tool definitions.
func EncodeExecRequestContextResult(execMsgId uint32, execId string, tools []McpToolDef) []byte {
// RequestContext with tools
rc := newMsg("RequestContext")
if len(tools) > 0 {
toolsField := field(rc, "tools")
toolsList := rc.Mutable(toolsField).List()
for _, tool := range tools {
td := newMsg("McpToolDefinition")
setStr(td, "name", tool.Name)
setStr(td, "description", tool.Description)
if len(tool.InputSchema) > 0 {
setBytes(td, "input_schema", jsonToProtobufValueBytes(tool.InputSchema))
}
setStr(td, "provider_identifier", "proxy")
setStr(td, "tool_name", tool.Name)
toolsList.Append(protoreflect.ValueOfMessage(td.ProtoReflect()))
}
}
// RequestContextSuccess
rcs := newMsg("RequestContextSuccess")
setMsg(rcs, "request_context", rc)
// RequestContextResult (oneof success)
rcr := newMsg("RequestContextResult")
setMsg(rcr, "success", rcs)
return encodeExecClientMsg(execMsgId, execId, "request_context_result", rcr)
}
// EncodeExecMcpResult responds with MCP tool result.
func EncodeExecMcpResult(execMsgId uint32, execId string, content string, isError bool) []byte {
textContent := newMsg("McpTextContent")
setStr(textContent, "text", content)
contentItem := newMsg("McpToolResultContentItem")
setMsg(contentItem, "text", textContent)
success := newMsg("McpSuccess")
contentField := field(success, "content")
contentList := success.Mutable(contentField).List()
contentList.Append(protoreflect.ValueOfMessage(contentItem.ProtoReflect()))
setBool(success, "is_error", isError)
result := newMsg("McpResult")
setMsg(result, "success", success)
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
}
// EncodeExecMcpError responds with MCP error.
func EncodeExecMcpError(execMsgId uint32, execId string, errMsg string) []byte {
mcpErr := newMsg("McpError")
setStr(mcpErr, "error", errMsg)
result := newMsg("McpResult")
setMsg(result, "error", mcpErr)
return encodeExecClientMsg(execMsgId, execId, "mcp_result", result)
}
// --- Rejection encoders (mirror handleExecMessage rejections) ---
func EncodeExecReadRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("ReadRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("ReadResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "read_result", result)
}
func EncodeExecShellRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
rej := newMsg("ShellRejected")
setStr(rej, "command", command)
setStr(rej, "working_directory", workDir)
setStr(rej, "reason", reason)
result := newMsg("ShellResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "shell_result", result)
}
func EncodeExecWriteRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("WriteRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("WriteResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "write_result", result)
}
func EncodeExecDeleteRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("DeleteRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("DeleteResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "delete_result", result)
}
func EncodeExecLsRejected(execMsgId uint32, execId string, path, reason string) []byte {
rej := newMsg("LsRejected")
setStr(rej, "path", path)
setStr(rej, "reason", reason)
result := newMsg("LsResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "ls_result", result)
}
func EncodeExecGrepError(execMsgId uint32, execId string, errMsg string) []byte {
grepErr := newMsg("GrepError")
setStr(grepErr, "error", errMsg)
result := newMsg("GrepResult")
setMsg(result, "error", grepErr)
return encodeExecClientMsg(execMsgId, execId, "grep_result", result)
}
func EncodeExecFetchError(execMsgId uint32, execId string, url, errMsg string) []byte {
fetchErr := newMsg("FetchError")
setStr(fetchErr, "url", url)
setStr(fetchErr, "error", errMsg)
result := newMsg("FetchResult")
setMsg(result, "error", fetchErr)
return encodeExecClientMsg(execMsgId, execId, "fetch_result", result)
}
func EncodeExecDiagnosticsResult(execMsgId uint32, execId string) []byte {
result := newMsg("DiagnosticsResult")
return encodeExecClientMsg(execMsgId, execId, "diagnostics_result", result)
}
func EncodeExecBackgroundShellSpawnRejected(execMsgId uint32, execId string, command, workDir, reason string) []byte {
rej := newMsg("ShellRejected")
setStr(rej, "command", command)
setStr(rej, "working_directory", workDir)
setStr(rej, "reason", reason)
result := newMsg("BackgroundShellSpawnResult")
setMsg(result, "rejected", rej)
return encodeExecClientMsg(execMsgId, execId, "background_shell_spawn_result", result)
}
func EncodeExecWriteShellStdinError(execMsgId uint32, execId string, errMsg string) []byte {
wsErr := newMsg("WriteShellStdinError")
setStr(wsErr, "error", errMsg)
result := newMsg("WriteShellStdinResult")
setMsg(result, "error", wsErr)
return encodeExecClientMsg(execMsgId, execId, "write_shell_stdin_result", result)
}
// encodeExecClientMsg wraps an exec result in AgentClientMessage.
// Mirrors sendExec() in cursor-fetch.ts.
func encodeExecClientMsg(id uint32, execId string, resultFieldName string, resultMsg *dynamicpb.Message) []byte {
ecm := newMsg("ExecClientMessage")
setUint32(ecm, "id", id)
// Force set exec_id even if empty - Cursor requires this field to be set
ecm.Set(field(ecm, "exec_id"), protoreflect.ValueOfString(execId))
// Debug: check if field exists
fd := field(ecm, resultFieldName)
if fd == nil {
panic(fmt.Sprintf("field %q NOT FOUND in ExecClientMessage! Available fields: %v", resultFieldName, listFields(ecm)))
}
// Debug: log the actual field being set
log.Debugf("encodeExecClientMsg: setting field %q (number=%d, kind=%s)", fd.Name(), fd.Number(), fd.Kind())
ecm.Set(fd, protoreflect.ValueOfMessage(resultMsg.ProtoReflect()))
acm := newMsg("AgentClientMessage")
setMsg(acm, "exec_client_message", ecm)
return marshal(acm)
}
func listFields(msg *dynamicpb.Message) []string {
var names []string
for i := 0; i < msg.Descriptor().Fields().Len(); i++ {
names = append(names, string(msg.Descriptor().Fields().Get(i).Name()))
}
return names
}
// --- Utilities ---
// jsonToProtobufValueBytes converts a JSON schema (json.RawMessage) to protobuf Value binary.
// This mirrors the TS pattern: toBinary(ValueSchema, fromJson(ValueSchema, jsonSchema))
func jsonToProtobufValueBytes(jsonData json.RawMessage) []byte {
if len(jsonData) == 0 {
return nil
}
var v interface{}
if err := json.Unmarshal(jsonData, &v); err != nil {
return jsonData // fallback to raw JSON if parsing fails
}
pbVal, err := structpb.NewValue(v)
if err != nil {
return jsonData // fallback
}
b, err := proto.Marshal(pbVal)
if err != nil {
return jsonData // fallback
}
return b
}
// ProtobufValueBytesToJSON converts protobuf Value binary back to JSON.
// This mirrors the TS pattern: toJson(ValueSchema, fromBinary(ValueSchema, value))
func ProtobufValueBytesToJSON(data []byte) (interface{}, error) {
val := &structpb.Value{}
if err := proto.Unmarshal(data, val); err != nil {
return nil, err
}
return val.AsInterface(), nil
}
func sha256Sum(data []byte) []byte {
h := sha256.Sum256(data)
return h[:]
}
var idCounter uint64
func generateId() string {
idCounter++
h := sha256.Sum256([]byte{byte(idCounter), byte(idCounter >> 8), byte(idCounter >> 16)})
return hex.EncodeToString(h[:16])
}

View File

@@ -0,0 +1,332 @@
// Package proto provides hand-rolled protobuf encode/decode for Cursor's gRPC API.
// Field numbers are extracted from the TypeScript generated proto/agent_pb.ts in alma-plugins/cursor-auth.
package proto
// AgentClientMessage (msg 118) oneof "message"
const (
ACM_RunRequest = 1 // AgentRunRequest
ACM_ExecClientMessage = 2 // ExecClientMessage
ACM_KvClientMessage = 3 // KvClientMessage
ACM_ConversationAction = 4 // ConversationAction
ACM_ExecClientControlMsg = 5 // ExecClientControlMessage
ACM_InteractionResponse = 6 // InteractionResponse
ACM_ClientHeartbeat = 7 // ClientHeartbeat
)
// AgentServerMessage (msg 119) oneof "message"
const (
ASM_InteractionUpdate = 1 // InteractionUpdate
ASM_ExecServerMessage = 2 // ExecServerMessage
ASM_ConversationCheckpoint = 3 // ConversationStateStructure
ASM_KvServerMessage = 4 // KvServerMessage
ASM_ExecServerControlMessage = 5 // ExecServerControlMessage
ASM_InteractionQuery = 7 // InteractionQuery
)
// AgentRunRequest (msg 91)
const (
ARR_ConversationState = 1 // ConversationStateStructure
ARR_Action = 2 // ConversationAction
ARR_ModelDetails = 3 // ModelDetails
ARR_McpTools = 4 // McpTools
ARR_ConversationId = 5 // string (optional)
)
// ConversationStateStructure (msg 83)
const (
CSS_RootPromptMessagesJson = 1 // repeated bytes
CSS_TurnsOld = 2 // repeated bytes (deprecated)
CSS_Todos = 3 // repeated bytes
CSS_PendingToolCalls = 4 // repeated string
CSS_Turns = 8 // repeated bytes (CURRENT field for turns)
CSS_PreviousWorkspaceUris = 9 // repeated string
CSS_SelfSummaryCount = 17 // uint32
CSS_ReadPaths = 18 // repeated string
)
// ConversationAction (msg 54) oneof "action"
const (
CA_UserMessageAction = 1 // UserMessageAction
)
// UserMessageAction (msg 55)
const (
UMA_UserMessage = 1 // UserMessage
)
// UserMessage (msg 63)
const (
UM_Text = 1 // string
UM_MessageId = 2 // string
UM_SelectedContext = 3 // SelectedContext (optional)
)
// SelectedContext
const (
SC_SelectedImages = 1 // repeated SelectedImage
)
// SelectedImage
const (
SI_BlobId = 1 // bytes (oneof dataOrBlobId)
SI_Uuid = 2 // string
SI_Path = 3 // string
SI_MimeType = 7 // string
SI_Data = 8 // bytes (oneof dataOrBlobId)
)
// ModelDetails (msg 88)
const (
MD_ModelId = 1 // string
MD_ThinkingDetails = 2 // ThinkingDetails (optional)
MD_DisplayModelId = 3 // string
MD_DisplayName = 4 // string
)
// McpTools (msg 307)
const (
MT_McpTools = 1 // repeated McpToolDefinition
)
// McpToolDefinition (msg 306)
const (
MTD_Name = 1 // string
MTD_Description = 2 // string
MTD_InputSchema = 3 // bytes
MTD_ProviderIdentifier = 4 // string
MTD_ToolName = 5 // string
)
// ConversationTurnStructure (msg 70) oneof "turn"
const (
CTS_AgentConversationTurn = 1 // AgentConversationTurnStructure
)
// AgentConversationTurnStructure (msg 72)
const (
ACTS_UserMessage = 1 // bytes (serialized UserMessage)
ACTS_Steps = 2 // repeated bytes (serialized ConversationStep)
)
// ConversationStep (msg 53) oneof "message"
const (
CS_AssistantMessage = 1 // AssistantMessage
)
// AssistantMessage
const (
AM_Text = 1 // string
)
// --- Server-side message fields ---
// InteractionUpdate oneof "message"
const (
IU_TextDelta = 1 // TextDeltaUpdate
IU_ThinkingDelta = 4 // ThinkingDeltaUpdate
IU_ThinkingCompleted = 5 // ThinkingCompletedUpdate
)
// TextDeltaUpdate (msg 92)
const (
TDU_Text = 1 // string
)
// ThinkingDeltaUpdate (msg 97)
const (
TKD_Text = 1 // string
)
// KvServerMessage (msg 271)
const (
KSM_Id = 1 // uint32
KSM_GetBlobArgs = 2 // GetBlobArgs
KSM_SetBlobArgs = 3 // SetBlobArgs
)
// GetBlobArgs (msg 267)
const (
GBA_BlobId = 1 // bytes
)
// SetBlobArgs (msg 269)
const (
SBA_BlobId = 1 // bytes
SBA_BlobData = 2 // bytes
)
// KvClientMessage (msg 272)
const (
KCM_Id = 1 // uint32
KCM_GetBlobResult = 2 // GetBlobResult
KCM_SetBlobResult = 3 // SetBlobResult
)
// GetBlobResult (msg 268)
const (
GBR_BlobData = 1 // bytes (optional)
)
// ExecServerMessage
const (
ESM_Id = 1 // uint32
ESM_ExecId = 15 // string
// oneof message:
ESM_ShellArgs = 2 // ShellArgs
ESM_WriteArgs = 3 // WriteArgs
ESM_DeleteArgs = 4 // DeleteArgs
ESM_GrepArgs = 5 // GrepArgs
ESM_ReadArgs = 7 // ReadArgs (NOTE: 6 is skipped)
ESM_LsArgs = 8 // LsArgs
ESM_DiagnosticsArgs = 9 // DiagnosticsArgs
ESM_RequestContextArgs = 10 // RequestContextArgs
ESM_McpArgs = 11 // McpArgs
ESM_ShellStreamArgs = 14 // ShellArgs (stream variant)
ESM_BackgroundShellSpawn = 16 // BackgroundShellSpawnArgs
ESM_FetchArgs = 20 // FetchArgs
ESM_WriteShellStdinArgs = 23 // WriteShellStdinArgs
)
// ExecClientMessage
const (
ECM_Id = 1 // uint32
ECM_ExecId = 15 // string
// oneof message (mirrors server fields):
ECM_ShellResult = 2
ECM_WriteResult = 3
ECM_DeleteResult = 4
ECM_GrepResult = 5
ECM_ReadResult = 7
ECM_LsResult = 8
ECM_DiagnosticsResult = 9
ECM_RequestContextResult = 10
ECM_McpResult = 11
ECM_ShellStream = 14
ECM_BackgroundShellSpawnRes = 16
ECM_FetchResult = 20
ECM_WriteShellStdinResult = 23
)
// McpArgs
const (
MCA_Name = 1 // string
MCA_Args = 2 // map<string, bytes>
MCA_ToolCallId = 3 // string
MCA_ProviderIdentifier = 4 // string
MCA_ToolName = 5 // string
)
// RequestContextResult oneof "result"
const (
RCR_Success = 1 // RequestContextSuccess
RCR_Error = 2 // RequestContextError
)
// RequestContextSuccess (msg 337)
const (
RCS_RequestContext = 1 // RequestContext
)
// RequestContext
const (
RC_Rules = 2 // repeated CursorRule
RC_Tools = 7 // repeated McpToolDefinition
)
// McpResult oneof "result"
const (
MCR_Success = 1 // McpSuccess
MCR_Error = 2 // McpError
MCR_Rejected = 3 // McpRejected
)
// McpSuccess (msg 290)
const (
MCS_Content = 1 // repeated McpToolResultContentItem
MCS_IsError = 2 // bool
)
// McpToolResultContentItem oneof "content"
const (
MTRCI_Text = 1 // McpTextContent
)
// McpTextContent (msg 287)
const (
MTC_Text = 1 // string
)
// McpError (msg 291)
const (
MCE_Error = 1 // string
)
// --- Rejection messages ---
// ReadRejected: path=1, reason=2
// ShellRejected: command=1, workingDirectory=2, reason=3, isReadonly=4
// WriteRejected: path=1, reason=2
// DeleteRejected: path=1, reason=2
// LsRejected: path=1, reason=2
// GrepError: error=1
// FetchError: url=1, error=2
// WriteShellStdinError: error=1
// ReadResult oneof: success=1, error=2, rejected=3
// ShellResult oneof: success=1 (+ various), rejected=?
// The TS code uses specific result field numbers from the oneof:
const (
RR_Rejected = 3 // ReadResult.rejected
SR_Rejected = 5 // ShellResult.rejected (from TS: ShellResult has success/various/rejected)
WR_Rejected = 5 // WriteResult.rejected
DR_Rejected = 3 // DeleteResult.rejected
LR_Rejected = 3 // LsResult.rejected
GR_Error = 2 // GrepResult.error
FR_Error = 2 // FetchResult.error
BSSR_Rejected = 2 // BackgroundShellSpawnResult.rejected (error field)
WSSR_Error = 2 // WriteShellStdinResult.error
)
// --- Rejection struct fields ---
const (
REJ_Path = 1
REJ_Reason = 2
SREJ_Command = 1
SREJ_WorkingDir = 2
SREJ_Reason = 3
SREJ_IsReadonly = 4
GERR_Error = 1
FERR_Url = 1
FERR_Error = 2
)
// ReadArgs
const (
RA_Path = 1 // string
)
// WriteArgs
const (
WA_Path = 1 // string
)
// DeleteArgs
const (
DA_Path = 1 // string
)
// LsArgs
const (
LA_Path = 1 // string
)
// ShellArgs
const (
SHA_Command = 1 // string
SHA_WorkingDirectory = 2 // string
)
// FetchArgs
const (
FA_Url = 1 // string
)

View File

@@ -0,0 +1,313 @@
package proto
import (
"crypto/tls"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
const (
defaultInitialWindowSize = 65535 // HTTP/2 default
maxFramePayload = 16384 // HTTP/2 default max frame size
)
// H2Stream provides bidirectional HTTP/2 streaming for the Connect protocol.
// Go's net/http does not support full-duplex HTTP/2, so we use the low-level framer.
type H2Stream struct {
framer *http2.Framer
conn net.Conn
streamID uint32
mu sync.Mutex
id string // unique identifier for debugging
frameNum int64 // sequential frame counter for debugging
dataCh chan []byte
doneCh chan struct{}
err error
// Send-side flow control
sendWindow int32 // available bytes we can send on this stream
connWindow int32 // available bytes on the connection level
windowCond *sync.Cond // signaled when window is updated
windowMu sync.Mutex // protects sendWindow, connWindow
}
// ID returns the unique identifier for this stream (for logging).
func (s *H2Stream) ID() string { return s.id }
// FrameNum returns the current frame number for debugging.
func (s *H2Stream) FrameNum() int64 {
s.mu.Lock()
defer s.mu.Unlock()
return s.frameNum
}
// DialH2Stream establishes a TLS+HTTP/2 connection and opens a new stream.
func DialH2Stream(host string, headers map[string]string) (*H2Stream, error) {
tlsConn, err := tls.Dial("tcp", host+":443", &tls.Config{
NextProtos: []string{"h2"},
})
if err != nil {
return nil, fmt.Errorf("h2: TLS dial failed: %w", err)
}
if tlsConn.ConnectionState().NegotiatedProtocol != "h2" {
tlsConn.Close()
return nil, fmt.Errorf("h2: server did not negotiate h2")
}
framer := http2.NewFramer(tlsConn, tlsConn)
// Client connection preface
if _, err := tlsConn.Write([]byte(http2.ClientPreface)); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: preface write failed: %w", err)
}
// Send initial SETTINGS (tell server how much WE can receive)
if err := framer.WriteSettings(
http2.Setting{ID: http2.SettingInitialWindowSize, Val: 4 * 1024 * 1024},
http2.Setting{ID: http2.SettingMaxConcurrentStreams, Val: 100},
); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: settings write failed: %w", err)
}
// Connection-level window update (for receiving)
if err := framer.WriteWindowUpdate(0, 3*1024*1024); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: window update failed: %w", err)
}
// Read and handle initial server frames (SETTINGS, WINDOW_UPDATE)
// Track server's initial window size (how much WE can send)
serverInitialWindowSize := int32(defaultInitialWindowSize)
connWindowSize := int32(defaultInitialWindowSize) // connection-level send window
for i := 0; i < 10; i++ {
f, err := framer.ReadFrame()
if err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: initial frame read failed: %w", err)
}
switch sf := f.(type) {
case *http2.SettingsFrame:
if !sf.IsAck() {
sf.ForeachSetting(func(s http2.Setting) error {
if s.ID == http2.SettingInitialWindowSize {
serverInitialWindowSize = int32(s.Val)
log.Debugf("h2: server initial window size: %d", s.Val)
}
return nil
})
framer.WriteSettingsAck()
} else {
goto handshakeDone
}
case *http2.WindowUpdateFrame:
if sf.StreamID == 0 {
connWindowSize += int32(sf.Increment)
log.Debugf("h2: initial conn window update: +%d, total=%d", sf.Increment, connWindowSize)
}
default:
// unexpected but continue
}
}
handshakeDone:
// Build HEADERS
streamID := uint32(1)
var hdrBuf []byte
enc := hpack.NewEncoder(&sliceWriter{buf: &hdrBuf})
enc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
enc.WriteField(hpack.HeaderField{Name: ":authority", Value: host})
if p, ok := headers[":path"]; ok {
enc.WriteField(hpack.HeaderField{Name: ":path", Value: p})
}
for k, v := range headers {
if len(k) > 0 && k[0] == ':' {
continue
}
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
}
if err := framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: streamID,
BlockFragment: hdrBuf,
EndStream: false,
EndHeaders: true,
}); err != nil {
tlsConn.Close()
return nil, fmt.Errorf("h2: headers write failed: %w", err)
}
s := &H2Stream{
framer: framer,
conn: tlsConn,
streamID: streamID,
dataCh: make(chan []byte, 256),
doneCh: make(chan struct{}),
id: fmt.Sprintf("%d-%s", streamID, time.Now().Format("150405.000")),
frameNum: 0,
sendWindow: serverInitialWindowSize,
connWindow: connWindowSize,
}
s.windowCond = sync.NewCond(&s.windowMu)
go s.readLoop()
return s, nil
}
// Write sends a DATA frame on the stream, respecting flow control.
func (s *H2Stream) Write(data []byte) error {
for len(data) > 0 {
chunk := data
if len(chunk) > maxFramePayload {
chunk = data[:maxFramePayload]
}
// Wait for flow control window
s.windowMu.Lock()
for s.sendWindow <= 0 || s.connWindow <= 0 {
s.windowCond.Wait()
}
// Limit chunk to available window
allowed := int(s.sendWindow)
if int(s.connWindow) < allowed {
allowed = int(s.connWindow)
}
if len(chunk) > allowed {
chunk = chunk[:allowed]
}
s.sendWindow -= int32(len(chunk))
s.connWindow -= int32(len(chunk))
s.windowMu.Unlock()
s.mu.Lock()
err := s.framer.WriteData(s.streamID, false, chunk)
s.mu.Unlock()
if err != nil {
return err
}
data = data[len(chunk):]
}
return nil
}
// Data returns the channel of received data chunks.
func (s *H2Stream) Data() <-chan []byte { return s.dataCh }
// Done returns a channel closed when the stream ends.
func (s *H2Stream) Done() <-chan struct{} { return s.doneCh }
// Err returns the error (if any) that caused the stream to close.
// Returns nil for a clean shutdown (EOF / StreamEnded).
func (s *H2Stream) Err() error { return s.err }
// Close tears down the connection.
func (s *H2Stream) Close() {
s.conn.Close()
// Unblock any writers waiting on flow control
s.windowCond.Broadcast()
}
func (s *H2Stream) readLoop() {
defer close(s.doneCh)
defer close(s.dataCh)
for {
f, err := s.framer.ReadFrame()
if err != nil {
if err != io.EOF {
s.err = err
log.Debugf("h2stream[%s]: readLoop error: %v", s.id, err)
}
return
}
// Increment frame counter
s.mu.Lock()
s.frameNum++
s.mu.Unlock()
switch frame := f.(type) {
case *http2.DataFrame:
if frame.StreamID == s.streamID && len(frame.Data()) > 0 {
cp := make([]byte, len(frame.Data()))
copy(cp, frame.Data())
s.dataCh <- cp
// Flow control: send WINDOW_UPDATE for received data
s.mu.Lock()
s.framer.WriteWindowUpdate(0, uint32(len(cp)))
s.framer.WriteWindowUpdate(s.streamID, uint32(len(cp)))
s.mu.Unlock()
}
if frame.StreamEnded() {
return
}
case *http2.HeadersFrame:
if frame.StreamEnded() {
return
}
case *http2.RSTStreamFrame:
s.err = fmt.Errorf("h2: RST_STREAM code=%d", frame.ErrCode)
log.Debugf("h2stream[%s]: received RST_STREAM code=%d", s.id, frame.ErrCode)
return
case *http2.GoAwayFrame:
s.err = fmt.Errorf("h2: GOAWAY code=%d", frame.ErrCode)
return
case *http2.PingFrame:
if !frame.IsAck() {
s.mu.Lock()
s.framer.WritePing(true, frame.Data)
s.mu.Unlock()
}
case *http2.SettingsFrame:
if !frame.IsAck() {
// Check for window size changes
frame.ForeachSetting(func(setting http2.Setting) error {
if setting.ID == http2.SettingInitialWindowSize {
s.windowMu.Lock()
delta := int32(setting.Val) - s.sendWindow
s.sendWindow += delta
s.windowMu.Unlock()
s.windowCond.Broadcast()
}
return nil
})
s.mu.Lock()
s.framer.WriteSettingsAck()
s.mu.Unlock()
}
case *http2.WindowUpdateFrame:
// Update send-side flow control window
s.windowMu.Lock()
if frame.StreamID == 0 {
s.connWindow += int32(frame.Increment)
} else if frame.StreamID == s.streamID {
s.sendWindow += int32(frame.Increment)
}
s.windowMu.Unlock()
s.windowCond.Broadcast()
}
}
}
type sliceWriter struct{ buf *[]byte }
func (w *sliceWriter) Write(p []byte) (int, error) {
*w.buf = append(*w.buf, p...)
return len(p), nil
}

View File

@@ -305,6 +305,9 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
defer manualPromptTimer.Stop()
}
var manualInputCh <-chan string
var manualInputErrCh <-chan error
waitForCallback:
for {
select {
@@ -326,13 +329,14 @@ waitForCallback:
return nil, err
default:
}
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
if err != nil {
return nil, err
}
parsed, err := misc.ParseOAuthCallback(input)
if err != nil {
return nil, err
manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the Gemini callback URL (or press Enter to keep waiting): ")
continue
case input := <-manualInputCh:
manualInputCh = nil
manualInputErrCh = nil
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
@@ -345,6 +349,8 @@ waitForCallback:
}
authCode = parsed.Code
break waitForCallback
case errManual := <-manualInputErrCh:
return nil, errManual
case <-timeoutTimer.C:
return nil, fmt.Errorf("oauth flow timed out")
}

View File

@@ -5,8 +5,7 @@ import (
)
// newAuthManager creates a new authentication manager instance with all supported
// authenticators and a file-based token store. It initializes authenticators for
// Gemini, Codex, Claude, Qwen, IFlow, Antigravity, and GitHub Copilot providers.
// authenticators and a file-based token store.
//
// Returns:
// - *sdkAuth.Manager: A configured authentication manager instance
@@ -24,6 +23,8 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewGitHubCopilotAuthenticator(),
sdkAuth.NewKiloAuthenticator(),
sdkAuth.NewGitLabAuthenticator(),
sdkAuth.NewCodeBuddyAuthenticator(),
sdkAuth.NewCursorAuthenticator(),
)
return manager
}

View File

@@ -0,0 +1,43 @@
package cmd
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoCodeBuddyLogin triggers the browser OAuth polling flow for CodeBuddy and saves tokens.
// It initiates the OAuth authentication, displays the user code for the user to enter
// at the CodeBuddy verification URL, and waits for authorization before saving the tokens.
//
// Parameters:
// - cfg: The application configuration containing proxy and auth directory settings
// - options: Login options including browser behavior settings
func DoCodeBuddyLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
}
record, savedPath, err := manager.Login(context.Background(), "codebuddy", cfg, authOpts)
if err != nil {
log.Errorf("CodeBuddy authentication failed: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("CodeBuddy authentication successful!")
}

View File

@@ -0,0 +1,37 @@
package cmd
import (
"context"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoCursorLogin triggers the OAuth PKCE flow for Cursor and saves tokens.
func DoCursorLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: options.Prompt,
}
record, savedPath, err := manager.Login(context.Background(), "cursor", cfg, authOpts)
if err != nil {
log.Errorf("Cursor authentication failed: %v", err)
return
}
if savedPath != "" {
log.Infof("Authentication saved to %s", savedPath)
}
if record != nil && record.Label != "" {
log.Infof("Authenticated as %s", record.Label)
}
log.Info("Cursor authentication successful!")
}

View File

@@ -0,0 +1,55 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadConfigOptional_ClaudeHeaderDefaults(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.yaml")
configYAML := []byte(`
claude-header-defaults:
user-agent: " claude-cli/2.1.70 (external, cli) "
package-version: " 0.80.0 "
runtime-version: " v24.5.0 "
os: " MacOS "
arch: " arm64 "
timeout: " 900 "
stabilize-device-profile: false
`)
if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
t.Fatalf("failed to write config: %v", err)
}
cfg, err := LoadConfigOptional(configPath, false)
if err != nil {
t.Fatalf("LoadConfigOptional() error = %v", err)
}
if got := cfg.ClaudeHeaderDefaults.UserAgent; got != "claude-cli/2.1.70 (external, cli)" {
t.Fatalf("UserAgent = %q, want %q", got, "claude-cli/2.1.70 (external, cli)")
}
if got := cfg.ClaudeHeaderDefaults.PackageVersion; got != "0.80.0" {
t.Fatalf("PackageVersion = %q, want %q", got, "0.80.0")
}
if got := cfg.ClaudeHeaderDefaults.RuntimeVersion; got != "v24.5.0" {
t.Fatalf("RuntimeVersion = %q, want %q", got, "v24.5.0")
}
if got := cfg.ClaudeHeaderDefaults.OS; got != "MacOS" {
t.Fatalf("OS = %q, want %q", got, "MacOS")
}
if got := cfg.ClaudeHeaderDefaults.Arch; got != "arm64" {
t.Fatalf("Arch = %q, want %q", got, "arm64")
}
if got := cfg.ClaudeHeaderDefaults.Timeout; got != "900" {
t.Fatalf("Timeout = %q, want %q", got, "900")
}
if cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
t.Fatal("StabilizeDeviceProfile = nil, want non-nil")
}
if got := *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile; got {
t.Fatalf("StabilizeDeviceProfile = %v, want false", got)
}
}

View File

@@ -13,6 +13,7 @@ import (
"strings"
"syscall"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"gopkg.in/yaml.v3"
@@ -145,13 +146,19 @@ type Config struct {
legacyMigrationPending bool `yaml:"-" json:"-"`
}
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
// when the client does not send them. Update these when Claude Code releases a new version.
// ClaudeHeaderDefaults configures default header values injected into Claude API requests.
// In legacy mode, UserAgent/PackageVersion/RuntimeVersion/Timeout act as fallbacks when
// the client omits them, while OS/Arch remain runtime-derived. When stabilized device
// profiles are enabled, OS/Arch become the pinned platform baseline, while
// UserAgent/PackageVersion/RuntimeVersion seed the upgradeable software fingerprint.
type ClaudeHeaderDefaults struct {
UserAgent string `yaml:"user-agent" json:"user-agent"`
PackageVersion string `yaml:"package-version" json:"package-version"`
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
Timeout string `yaml:"timeout" json:"timeout"`
UserAgent string `yaml:"user-agent" json:"user-agent"`
PackageVersion string `yaml:"package-version" json:"package-version"`
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
OS string `yaml:"os" json:"os"`
Arch string `yaml:"arch" json:"arch"`
Timeout string `yaml:"timeout" json:"timeout"`
StabilizeDeviceProfile *bool `yaml:"stabilize-device-profile,omitempty" json:"stabilize-device-profile,omitempty"`
}
// CodexHeaderDefaults configures fallback header values injected into Codex
@@ -188,6 +195,9 @@ type RemoteManagement struct {
SecretKey string `yaml:"secret-key"`
// DisableControlPanel skips serving and syncing the bundled management UI when true.
DisableControlPanel bool `yaml:"disable-control-panel"`
// DisableAutoUpdatePanel disables automatic periodic background updates of the management panel asset from GitHub.
// When false (the default), the background updater remains enabled; when true, the panel is only downloaded on first access if missing.
DisableAutoUpdatePanel bool `yaml:"disable-auto-update-panel"`
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
PanelGitHubRepository string `yaml:"panel-github-repository"`
@@ -568,6 +578,10 @@ type OpenAICompatibilityModel struct {
// Alias is the model name alias that clients will use to reference this model.
Alias string `yaml:"alias" json:"alias"`
// Thinking configures the thinking/reasoning capability for this model.
// If nil, the model defaults to level-based reasoning with levels ["low", "medium", "high"].
Thinking *registry.ThinkingSupport `yaml:"thinking,omitempty" json:"thinking,omitempty"`
}
func (m OpenAICompatibilityModel) GetName() string { return m.Name }
@@ -694,6 +708,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Sanitize Codex header defaults.
cfg.SanitizeCodexHeaderDefaults()
// Sanitize Claude header defaults.
cfg.SanitizeClaudeHeaderDefaults()
// Sanitize Claude key headers
cfg.SanitizeClaudeKeys()
@@ -796,6 +813,20 @@ func (cfg *Config) SanitizeCodexHeaderDefaults() {
cfg.CodexHeaderDefaults.BetaFeatures = strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
}
// SanitizeClaudeHeaderDefaults trims surrounding whitespace from the
// configured Claude fingerprint baseline values.
func (cfg *Config) SanitizeClaudeHeaderDefaults() {
if cfg == nil {
return
}
cfg.ClaudeHeaderDefaults.UserAgent = strings.TrimSpace(cfg.ClaudeHeaderDefaults.UserAgent)
cfg.ClaudeHeaderDefaults.PackageVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.PackageVersion)
cfg.ClaudeHeaderDefaults.RuntimeVersion = strings.TrimSpace(cfg.ClaudeHeaderDefaults.RuntimeVersion)
cfg.ClaudeHeaderDefaults.OS = strings.TrimSpace(cfg.ClaudeHeaderDefaults.OS)
cfg.ClaudeHeaderDefaults.Arch = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Arch)
cfg.ClaudeHeaderDefaults.Timeout = strings.TrimSpace(cfg.ClaudeHeaderDefaults.Timeout)
}
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.

View File

@@ -31,6 +31,7 @@ const (
httpUserAgent = "CLIProxyAPI-management-updater"
managementSyncMinInterval = 30 * time.Second
updateCheckInterval = 3 * time.Hour
maxAssetDownloadSize = 50 << 20 // 10 MB safety limit for management asset downloads
)
// ManagementFileName exposes the control panel asset filename.
@@ -88,6 +89,10 @@ func runAutoUpdater(ctx context.Context) {
log.Debug("management asset auto-updater skipped: control panel disabled")
return
}
if cfg.RemoteManagement.DisableAutoUpdatePanel {
log.Debug("management asset auto-updater skipped: disable-auto-update-panel is enabled")
return
}
configPath, _ := schedulerConfigPath.Load().(string)
staticDir := StaticDir(configPath)
@@ -259,7 +264,8 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
}
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
log.Errorf("management asset digest mismatch: expected %s got %s — aborting update for safety", remoteHash, downloadedHash)
return nil, nil
}
if err = atomicWriteFile(localPath, data); err != nil {
@@ -282,6 +288,9 @@ func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, loca
return false
}
log.Warnf("management asset downloaded from fallback URL without digest verification (hash=%s) — "+
"enable verified GitHub updates by keeping disable-auto-update-panel set to false", downloadedHash)
if err = atomicWriteFile(localPath, data); err != nil {
log.WithError(err).Warn("failed to persist fallback management control panel page")
return false
@@ -392,10 +401,13 @@ func downloadAsset(ctx context.Context, client *http.Client, downloadURL string)
return nil, "", fmt.Errorf("unexpected download status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
data, err := io.ReadAll(resp.Body)
data, err := io.ReadAll(io.LimitReader(resp.Body, maxAssetDownloadSize+1))
if err != nil {
return nil, "", fmt.Errorf("read download body: %w", err)
}
if int64(len(data)) > maxAssetDownloadSize {
return nil, "", fmt.Errorf("download exceeds maximum allowed size of %d bytes", maxAssetDownloadSize)
}
sum := sha256.Sum256(data)
return data, hex.EncodeToString(sum[:]), nil

View File

@@ -30,6 +30,23 @@ type OAuthCallback struct {
ErrorDescription string
}
// AsyncPrompt runs a prompt function in a goroutine and returns channels for
// the result. The returned channels are buffered (size 1) so the goroutine can
// complete even if the caller abandons the channels.
func AsyncPrompt(promptFn func(string) (string, error), message string) (<-chan string, <-chan error) {
inputCh := make(chan string, 1)
errCh := make(chan error, 1)
go func() {
input, err := promptFn(message)
if err != nil {
errCh <- err
return
}
inputCh <- input
}()
return inputCh, errCh
}
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
// It returns nil when the input is empty.
func ParseOAuthCallback(input string) (*OAuthCallback, error) {

View File

@@ -88,6 +88,87 @@ func GetAntigravityModels() []*ModelInfo {
return cloneModelInfos(getModels().Antigravity)
}
// GetCodeBuddyModels returns the available models for CodeBuddy (Tencent).
// These models are served through the copilot.tencent.com API.
func GetCodeBuddyModels() []*ModelInfo {
now := int64(1748044800) // 2025-05-24
return []*ModelInfo{
{
ID: "glm-5.0",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "GLM-5.0",
Description: "GLM-5.0 via CodeBuddy",
ContextLength: 128000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "glm-4.7",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "GLM-4.7",
Description: "GLM-4.7 via CodeBuddy",
ContextLength: 128000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "minimax-m2.5",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "MiniMax M2.5",
Description: "MiniMax M2.5 via CodeBuddy",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "kimi-k2.5",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "Kimi K2.5",
Description: "Kimi K2.5 via CodeBuddy",
ContextLength: 128000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "deepseek-v3-2-volc",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "DeepSeek V3.2 (Volc)",
Description: "DeepSeek V3.2 via CodeBuddy (Volcano Engine)",
ContextLength: 128000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "hunyuan-2.0-thinking",
Object: "model",
Created: now,
OwnedBy: "tencent",
Type: "codebuddy",
DisplayName: "Hunyuan 2.0 Thinking",
Description: "Tencent Hunyuan 2.0 Thinking via CodeBuddy",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{ZeroAllowed: true},
SupportedEndpoints: []string{"/chat/completions"},
},
}
}
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
if len(models) == 0 {
@@ -148,11 +229,27 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
return GetAmazonQModels()
case "antigravity":
return GetAntigravityModels()
case "codebuddy":
return GetCodeBuddyModels()
case "cursor":
return GetCursorModels()
default:
return nil
}
}
// GetCursorModels returns the fallback Cursor model definitions.
func GetCursorModels() []*ModelInfo {
return []*ModelInfo{
{ID: "composer-2", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Composer 2", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
{ID: "claude-4-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 4 Sonnet", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
{ID: "claude-3.5-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 3.5 Sonnet", ContextLength: 200000, MaxCompletionTokens: 8192},
{ID: "gpt-4o", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "GPT-4o", ContextLength: 128000, MaxCompletionTokens: 16384},
{ID: "cursor-small", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Cursor Small", ContextLength: 200000, MaxCompletionTokens: 64000},
{ID: "gemini-2.5-pro", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Gemini 2.5 Pro", ContextLength: 1000000, MaxCompletionTokens: 65536, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}},
}
}
// LookupStaticModelInfo searches all static model definitions for a model by ID.
// Returns nil if no matching model is found.
func LookupStaticModelInfo(modelID string) *ModelInfo {
@@ -176,6 +273,8 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
GetKiroModels(),
GetKiloModels(),
GetAmazonQModels(),
GetCodeBuddyModels(),
GetCursorModels(),
}
for _, models := range allModels {
for _, m := range models {
@@ -365,6 +464,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.4",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.4",
Description: "OpenAI GPT-5.4 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "claude-haiku-4.5",
Object: "model",

View File

@@ -73,16 +73,16 @@ type availableModelsCacheEntry struct {
// Values are interpreted in provider-native token units.
type ThinkingSupport struct {
// Min is the minimum allowed thinking budget (inclusive).
Min int `json:"min,omitempty"`
Min int `json:"min,omitempty" yaml:"min,omitempty"`
// Max is the maximum allowed thinking budget (inclusive).
Max int `json:"max,omitempty"`
Max int `json:"max,omitempty" yaml:"max,omitempty"`
// ZeroAllowed indicates whether 0 is a valid value (to disable thinking).
ZeroAllowed bool `json:"zero_allowed,omitempty"`
ZeroAllowed bool `json:"zero_allowed,omitempty" yaml:"zero-allowed,omitempty"`
// DynamicAllowed indicates whether -1 is a valid value (dynamic thinking budget).
DynamicAllowed bool `json:"dynamic_allowed,omitempty"`
DynamicAllowed bool `json:"dynamic_allowed,omitempty" yaml:"dynamic-allowed,omitempty"`
// Levels defines discrete reasoning effort levels (e.g., "low", "medium", "high").
// When set, the model uses level-based reasoning instead of token budgets.
Levels []string `json:"levels,omitempty"`
Levels []string `json:"levels,omitempty" yaml:"levels,omitempty"`
}
// ModelRegistration tracks a model's availability

View File

@@ -164,7 +164,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, &param)
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON(out), Headers: wsResp.Headers.Clone()}
return resp, nil
}
@@ -280,7 +280,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
}
break
}
@@ -296,7 +296,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
}
reporter.publish(ctx, parseGeminiUsage(event.Payload))
return false
@@ -373,7 +373,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
}
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, resp.Body)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: translated}, nil
}
// Refresh refreshes the authentication credentials (no-op for AI Studio).

View File

@@ -308,7 +308,7 @@ attemptLoop:
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
}
@@ -512,7 +512,7 @@ attemptLoop:
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
@@ -691,31 +691,42 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
}
partsJSON, _ := json.Marshal(parts)
responseTemplate, _ = sjson.SetRaw(responseTemplate, "candidates.0.content.parts", string(partsJSON))
updatedTemplate, _ := sjson.SetRawBytes([]byte(responseTemplate), "candidates.0.content.parts", partsJSON)
responseTemplate = string(updatedTemplate)
if role != "" {
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.content.role", role)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "candidates.0.content.role", role)
responseTemplate = string(updatedTemplate)
}
if finishReason != "" {
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.finishReason", finishReason)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "candidates.0.finishReason", finishReason)
responseTemplate = string(updatedTemplate)
}
if modelVersion != "" {
responseTemplate, _ = sjson.Set(responseTemplate, "modelVersion", modelVersion)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "modelVersion", modelVersion)
responseTemplate = string(updatedTemplate)
}
if responseID != "" {
responseTemplate, _ = sjson.Set(responseTemplate, "responseId", responseID)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "responseId", responseID)
responseTemplate = string(updatedTemplate)
}
if usageRaw != "" {
responseTemplate, _ = sjson.SetRaw(responseTemplate, "usageMetadata", usageRaw)
updatedTemplate, _ = sjson.SetRawBytes([]byte(responseTemplate), "usageMetadata", []byte(usageRaw))
responseTemplate = string(updatedTemplate)
} else if !gjson.Get(responseTemplate, "usageMetadata").Exists() {
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.promptTokenCount", 0)
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.candidatesTokenCount", 0)
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.totalTokenCount", 0)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.promptTokenCount", 0)
responseTemplate = string(updatedTemplate)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.candidatesTokenCount", 0)
responseTemplate = string(updatedTemplate)
updatedTemplate, _ = sjson.SetBytes([]byte(responseTemplate), "usageMetadata.totalTokenCount", 0)
responseTemplate = string(updatedTemplate)
}
output := `{"response":{},"traceId":""}`
output, _ = sjson.SetRaw(output, "response", responseTemplate)
updatedOutput, _ := sjson.SetRawBytes([]byte(output), "response", []byte(responseTemplate))
output = string(updatedOutput)
if traceID != "" {
output, _ = sjson.Set(output, "traceId", traceID)
updatedOutput, _ = sjson.SetBytes([]byte(output), "traceId", traceID)
output = string(updatedOutput)
}
return []byte(output)
}
@@ -880,12 +891,12 @@ attemptLoop:
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), &param)
for i := range tail {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
@@ -1043,7 +1054,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
return cliproxyexecutor.Response{Payload: translated, Headers: httpResp.Header.Clone()}, nil
}
lastStatus = httpResp.StatusCode
@@ -1265,19 +1276,20 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
// if useAntigravitySchema {
// systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.role", "user")
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.parts.0.text", systemInstruction)
// payloadStr, _ = sjson.SetBytes([]byte(payloadStr), "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
// if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
// for _, partResult := range systemInstructionPartsResult.Array() {
// payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
// payloadStr, _ = sjson.SetRawBytes([]byte(payloadStr), "request.systemInstruction.parts.-1", []byte(partResult.Raw))
// }
// }
// }
if strings.Contains(modelName, "claude") {
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
updated, _ := sjson.SetBytes([]byte(payloadStr), "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
payloadStr = string(updated)
} else {
payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens")
}
@@ -1499,8 +1511,9 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
}
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
template, _ := sjson.Set(string(payload), "model", modelName)
template, _ = sjson.Set(template, "userAgent", "antigravity")
template := payload
template, _ = sjson.SetBytes(template, "model", modelName)
template, _ = sjson.SetBytes(template, "userAgent", "antigravity")
isImageModel := strings.Contains(modelName, "image")
@@ -1510,28 +1523,28 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
} else {
reqType = "agent"
}
template, _ = sjson.Set(template, "requestType", reqType)
template, _ = sjson.SetBytes(template, "requestType", reqType)
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
if projectID != "" {
template, _ = sjson.Set(template, "project", projectID)
template, _ = sjson.SetBytes(template, "project", projectID)
} else {
template, _ = sjson.Set(template, "project", generateProjectID())
template, _ = sjson.SetBytes(template, "project", generateProjectID())
}
if isImageModel {
template, _ = sjson.Set(template, "requestId", generateImageGenRequestID())
template, _ = sjson.SetBytes(template, "requestId", generateImageGenRequestID())
} else {
template, _ = sjson.Set(template, "requestId", generateRequestID())
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
template, _ = sjson.SetBytes(template, "requestId", generateRequestID())
template, _ = sjson.SetBytes(template, "request.sessionId", generateStableSessionID(payload))
}
template, _ = sjson.Delete(template, "request.safetySettings")
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() {
template, _ = sjson.SetRaw(template, "request.toolConfig", toolConfig.Raw)
template, _ = sjson.Delete(template, "toolConfig")
template, _ = sjson.DeleteBytes(template, "request.safetySettings")
if toolConfig := gjson.GetBytes(template, "toolConfig"); toolConfig.Exists() && !gjson.GetBytes(template, "request.toolConfig").Exists() {
template, _ = sjson.SetRawBytes(template, "request.toolConfig", []byte(toolConfig.Raw))
template, _ = sjson.DeleteBytes(template, "toolConfig")
}
return []byte(template)
return template
}
func generateRequestID() string {

View File

@@ -0,0 +1,383 @@
package executor
import (
"crypto/sha256"
"encoding/hex"
"net/http"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
const (
defaultClaudeFingerprintUserAgent = "claude-cli/2.1.63 (external, cli)"
defaultClaudeFingerprintPackageVersion = "0.74.0"
defaultClaudeFingerprintRuntimeVersion = "v24.3.0"
defaultClaudeFingerprintOS = "MacOS"
defaultClaudeFingerprintArch = "arm64"
claudeDeviceProfileTTL = 7 * 24 * time.Hour
claudeDeviceProfileCleanupPeriod = time.Hour
)
var (
claudeCLIVersionPattern = regexp.MustCompile(`^claude-cli/(\d+)\.(\d+)\.(\d+)`)
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
claudeDeviceProfileCacheMu sync.RWMutex
claudeDeviceProfileCacheCleanupOnce sync.Once
claudeDeviceProfileBeforeCandidateStore func(claudeDeviceProfile)
)
type claudeCLIVersion struct {
major int
minor int
patch int
}
func (v claudeCLIVersion) Compare(other claudeCLIVersion) int {
switch {
case v.major != other.major:
if v.major > other.major {
return 1
}
return -1
case v.minor != other.minor:
if v.minor > other.minor {
return 1
}
return -1
case v.patch != other.patch:
if v.patch > other.patch {
return 1
}
return -1
default:
return 0
}
}
type claudeDeviceProfile struct {
UserAgent string
PackageVersion string
RuntimeVersion string
OS string
Arch string
Version claudeCLIVersion
HasVersion bool
}
type claudeDeviceProfileCacheEntry struct {
profile claudeDeviceProfile
expire time.Time
}
func claudeDeviceProfileStabilizationEnabled(cfg *config.Config) bool {
if cfg == nil || cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile == nil {
return false
}
return *cfg.ClaudeHeaderDefaults.StabilizeDeviceProfile
}
func defaultClaudeDeviceProfile(cfg *config.Config) claudeDeviceProfile {
hdrDefault := func(cfgVal, fallback string) string {
if strings.TrimSpace(cfgVal) != "" {
return strings.TrimSpace(cfgVal)
}
return fallback
}
var hd config.ClaudeHeaderDefaults
if cfg != nil {
hd = cfg.ClaudeHeaderDefaults
}
profile := claudeDeviceProfile{
UserAgent: hdrDefault(hd.UserAgent, defaultClaudeFingerprintUserAgent),
PackageVersion: hdrDefault(hd.PackageVersion, defaultClaudeFingerprintPackageVersion),
RuntimeVersion: hdrDefault(hd.RuntimeVersion, defaultClaudeFingerprintRuntimeVersion),
OS: hdrDefault(hd.OS, defaultClaudeFingerprintOS),
Arch: hdrDefault(hd.Arch, defaultClaudeFingerprintArch),
}
if version, ok := parseClaudeCLIVersion(profile.UserAgent); ok {
profile.Version = version
profile.HasVersion = true
}
return profile
}
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
func mapStainlessOS() string {
switch runtime.GOOS {
case "darwin":
return "MacOS"
case "windows":
return "Windows"
case "linux":
return "Linux"
case "freebsd":
return "FreeBSD"
default:
return "Other::" + runtime.GOOS
}
}
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
func mapStainlessArch() string {
switch runtime.GOARCH {
case "amd64":
return "x64"
case "arm64":
return "arm64"
case "386":
return "x86"
default:
return "other::" + runtime.GOARCH
}
}
func parseClaudeCLIVersion(userAgent string) (claudeCLIVersion, bool) {
matches := claudeCLIVersionPattern.FindStringSubmatch(strings.TrimSpace(userAgent))
if len(matches) != 4 {
return claudeCLIVersion{}, false
}
major, err := strconv.Atoi(matches[1])
if err != nil {
return claudeCLIVersion{}, false
}
minor, err := strconv.Atoi(matches[2])
if err != nil {
return claudeCLIVersion{}, false
}
patch, err := strconv.Atoi(matches[3])
if err != nil {
return claudeCLIVersion{}, false
}
return claudeCLIVersion{major: major, minor: minor, patch: patch}, true
}
func shouldUpgradeClaudeDeviceProfile(candidate, current claudeDeviceProfile) bool {
if candidate.UserAgent == "" || !candidate.HasVersion {
return false
}
if current.UserAgent == "" || !current.HasVersion {
return true
}
return candidate.Version.Compare(current.Version) > 0
}
func pinClaudeDeviceProfilePlatform(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
profile.OS = baseline.OS
profile.Arch = baseline.Arch
return profile
}
// normalizeClaudeDeviceProfile keeps stabilized profiles pinned to the current
// baseline platform and enforces the baseline software fingerprint as a floor.
func normalizeClaudeDeviceProfile(profile, baseline claudeDeviceProfile) claudeDeviceProfile {
profile = pinClaudeDeviceProfilePlatform(profile, baseline)
if profile.UserAgent == "" || !profile.HasVersion || shouldUpgradeClaudeDeviceProfile(baseline, profile) {
profile.UserAgent = baseline.UserAgent
profile.PackageVersion = baseline.PackageVersion
profile.RuntimeVersion = baseline.RuntimeVersion
profile.Version = baseline.Version
profile.HasVersion = baseline.HasVersion
}
return profile
}
func extractClaudeDeviceProfile(headers http.Header, cfg *config.Config) (claudeDeviceProfile, bool) {
if headers == nil {
return claudeDeviceProfile{}, false
}
userAgent := strings.TrimSpace(headers.Get("User-Agent"))
version, ok := parseClaudeCLIVersion(userAgent)
if !ok {
return claudeDeviceProfile{}, false
}
baseline := defaultClaudeDeviceProfile(cfg)
profile := claudeDeviceProfile{
UserAgent: userAgent,
PackageVersion: firstNonEmptyHeader(headers, "X-Stainless-Package-Version", baseline.PackageVersion),
RuntimeVersion: firstNonEmptyHeader(headers, "X-Stainless-Runtime-Version", baseline.RuntimeVersion),
OS: firstNonEmptyHeader(headers, "X-Stainless-Os", baseline.OS),
Arch: firstNonEmptyHeader(headers, "X-Stainless-Arch", baseline.Arch),
Version: version,
HasVersion: true,
}
return profile, true
}
func firstNonEmptyHeader(headers http.Header, name, fallback string) string {
if headers == nil {
return fallback
}
if value := strings.TrimSpace(headers.Get(name)); value != "" {
return value
}
return fallback
}
func claudeDeviceProfileScopeKey(auth *cliproxyauth.Auth, apiKey string) string {
switch {
case auth != nil && strings.TrimSpace(auth.ID) != "":
return "auth:" + strings.TrimSpace(auth.ID)
case strings.TrimSpace(apiKey) != "":
return "api_key:" + strings.TrimSpace(apiKey)
default:
return "global"
}
}
func claudeDeviceProfileCacheKey(auth *cliproxyauth.Auth, apiKey string) string {
sum := sha256.Sum256([]byte(claudeDeviceProfileScopeKey(auth, apiKey)))
return hex.EncodeToString(sum[:])
}
func startClaudeDeviceProfileCacheCleanup() {
go func() {
ticker := time.NewTicker(claudeDeviceProfileCleanupPeriod)
defer ticker.Stop()
for range ticker.C {
purgeExpiredClaudeDeviceProfiles()
}
}()
}
func purgeExpiredClaudeDeviceProfiles() {
now := time.Now()
claudeDeviceProfileCacheMu.Lock()
for key, entry := range claudeDeviceProfileCache {
if !entry.expire.After(now) {
delete(claudeDeviceProfileCache, key)
}
}
claudeDeviceProfileCacheMu.Unlock()
}
func resolveClaudeDeviceProfile(auth *cliproxyauth.Auth, apiKey string, headers http.Header, cfg *config.Config) claudeDeviceProfile {
claudeDeviceProfileCacheCleanupOnce.Do(startClaudeDeviceProfileCacheCleanup)
cacheKey := claudeDeviceProfileCacheKey(auth, apiKey)
now := time.Now()
baseline := defaultClaudeDeviceProfile(cfg)
candidate, hasCandidate := extractClaudeDeviceProfile(headers, cfg)
if hasCandidate {
candidate = pinClaudeDeviceProfilePlatform(candidate, baseline)
}
if hasCandidate && !shouldUpgradeClaudeDeviceProfile(candidate, baseline) {
hasCandidate = false
}
claudeDeviceProfileCacheMu.RLock()
entry, hasCached := claudeDeviceProfileCache[cacheKey]
cachedValid := hasCached && entry.expire.After(now) && entry.profile.UserAgent != ""
claudeDeviceProfileCacheMu.RUnlock()
if hasCandidate {
if claudeDeviceProfileBeforeCandidateStore != nil {
claudeDeviceProfileBeforeCandidateStore(candidate)
}
claudeDeviceProfileCacheMu.Lock()
entry, hasCached = claudeDeviceProfileCache[cacheKey]
cachedValid = hasCached && entry.expire.After(now) && entry.profile.UserAgent != ""
if cachedValid {
entry.profile = normalizeClaudeDeviceProfile(entry.profile, baseline)
}
if cachedValid && !shouldUpgradeClaudeDeviceProfile(candidate, entry.profile) {
entry.expire = now.Add(claudeDeviceProfileTTL)
claudeDeviceProfileCache[cacheKey] = entry
claudeDeviceProfileCacheMu.Unlock()
return entry.profile
}
claudeDeviceProfileCache[cacheKey] = claudeDeviceProfileCacheEntry{
profile: candidate,
expire: now.Add(claudeDeviceProfileTTL),
}
claudeDeviceProfileCacheMu.Unlock()
return candidate
}
if cachedValid {
claudeDeviceProfileCacheMu.Lock()
entry = claudeDeviceProfileCache[cacheKey]
if entry.expire.After(now) && entry.profile.UserAgent != "" {
entry.profile = normalizeClaudeDeviceProfile(entry.profile, baseline)
entry.expire = now.Add(claudeDeviceProfileTTL)
claudeDeviceProfileCache[cacheKey] = entry
claudeDeviceProfileCacheMu.Unlock()
return entry.profile
}
claudeDeviceProfileCacheMu.Unlock()
}
return baseline
}
func applyClaudeDeviceProfileHeaders(r *http.Request, profile claudeDeviceProfile) {
if r == nil {
return
}
for _, headerName := range []string{
"User-Agent",
"X-Stainless-Package-Version",
"X-Stainless-Runtime-Version",
"X-Stainless-Os",
"X-Stainless-Arch",
} {
r.Header.Del(headerName)
}
r.Header.Set("User-Agent", profile.UserAgent)
r.Header.Set("X-Stainless-Package-Version", profile.PackageVersion)
r.Header.Set("X-Stainless-Runtime-Version", profile.RuntimeVersion)
r.Header.Set("X-Stainless-Os", profile.OS)
r.Header.Set("X-Stainless-Arch", profile.Arch)
}
func applyClaudeLegacyDeviceHeaders(r *http.Request, ginHeaders http.Header, cfg *config.Config) {
if r == nil {
return
}
profile := defaultClaudeDeviceProfile(cfg)
miscEnsure := func(name, fallback string) {
if strings.TrimSpace(r.Header.Get(name)) != "" {
return
}
if strings.TrimSpace(ginHeaders.Get(name)) != "" {
r.Header.Set(name, strings.TrimSpace(ginHeaders.Get(name)))
return
}
r.Header.Set(name, fallback)
}
miscEnsure("X-Stainless-Runtime-Version", profile.RuntimeVersion)
miscEnsure("X-Stainless-Package-Version", profile.PackageVersion)
miscEnsure("X-Stainless-Os", mapStainlessOS())
miscEnsure("X-Stainless-Arch", mapStainlessArch())
// Legacy mode preserves per-auth custom header overrides. By the time we get
// here, ApplyCustomHeadersFromAttrs has already populated r.Header.
if strings.TrimSpace(r.Header.Get("User-Agent")) != "" {
return
}
clientUA := ""
if ginHeaders != nil {
clientUA = strings.TrimSpace(ginHeaders.Get("User-Agent"))
}
if isClaudeCodeClient(clientUA) {
r.Header.Set("User-Agent", clientUA)
return
}
r.Header.Set("User-Agent", profile.UserAgent)
}

View File

@@ -14,7 +14,6 @@ import (
"io"
"net/http"
"net/textproto"
"runtime"
"strings"
"time"
@@ -255,7 +254,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
data,
&param,
)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -443,7 +442,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
&param,
)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
if errScan := scanner.Err(); errScan != nil {
@@ -561,7 +560,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "input_tokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
return cliproxyexecutor.Response{Payload: out, Headers: resp.Header.Clone()}, nil
}
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
@@ -767,36 +766,6 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
return body, nil
}
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
func mapStainlessOS() string {
switch runtime.GOOS {
case "darwin":
return "MacOS"
case "windows":
return "Windows"
case "linux":
return "Linux"
case "freebsd":
return "FreeBSD"
default:
return "Other::" + runtime.GOOS
}
}
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
func mapStainlessArch() string {
switch runtime.GOARCH {
case "amd64":
return "x64"
case "arm64":
return "arm64"
case "386":
return "x86"
default:
return "other::" + runtime.GOARCH
}
}
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
hdrDefault := func(cfgVal, fallback string) string {
if cfgVal != "" {
@@ -824,6 +793,11 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
ginHeaders = ginCtx.Request.Header
}
stabilizeDeviceProfile := claudeDeviceProfileStabilizationEnabled(cfg)
var deviceProfile claudeDeviceProfile
if stabilizeDeviceProfile {
deviceProfile = resolveClaudeDeviceProfile(auth, apiKey, ginHeaders, cfg)
}
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
@@ -867,25 +841,9 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
// For User-Agent, only forward the client's header if it's already a Claude Code client.
// Non-Claude-Code clients (e.g. curl, OpenAI SDKs) get the default Claude Code User-Agent
// to avoid leaking the real client identity during cloaking.
clientUA := ""
if ginHeaders != nil {
clientUA = ginHeaders.Get("User-Agent")
}
if isClaudeCodeClient(clientUA) {
r.Header.Set("User-Agent", clientUA)
} else {
r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)"))
}
r.Header.Set("Connection", "keep-alive")
if stream {
r.Header.Set("Accept", "text/event-stream")
@@ -897,13 +855,19 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
r.Header.Set("Accept", "application/json")
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
}
// Keep OS/Arch mapping dynamic (not configurable).
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
// Legacy mode keeps OS/Arch runtime-derived; stabilized mode pins OS/Arch
// to the configured baseline while still allowing newer official
// User-Agent/package/runtime tuples to upgrade the software fingerprint.
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(r, attrs)
if stabilizeDeviceProfile {
applyClaudeDeviceProfileHeaders(r, deviceProfile)
} else {
applyClaudeLegacyDeviceHeaders(r, ginHeaders, cfg)
}
// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
// may override it with a user-configured value. Compressed SSE breaks the line
// scanner regardless of user preference, so this is non-negotiable for streams.
@@ -1260,7 +1224,8 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
// TTL ordering violations with the prompt-caching-scope-2026-01-05 beta.
partJSON := part.Raw
if !part.Get("cache_control").Exists() {
partJSON, _ = sjson.Set(partJSON, "cache_control.type", "ephemeral")
updated, _ := sjson.SetBytes([]byte(partJSON), "cache_control.type", "ephemeral")
partJSON = string(updated)
}
result += "," + partJSON
}
@@ -1268,7 +1233,8 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
})
} else if system.Type == gjson.String && system.String() != "" {
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
partJSON, _ = sjson.Set(partJSON, "text", system.String())
updated, _ := sjson.SetBytes([]byte(partJSON), "text", system.String())
partJSON = string(updated)
result += "," + partJSON
}
result += "]"

View File

@@ -8,8 +8,11 @@ import (
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -19,6 +22,587 @@ import (
"github.com/tidwall/sjson"
)
func resetClaudeDeviceProfileCache() {
claudeDeviceProfileCacheMu.Lock()
claudeDeviceProfileCache = make(map[string]claudeDeviceProfileCacheEntry)
claudeDeviceProfileCacheMu.Unlock()
}
func newClaudeHeaderTestRequest(t *testing.T, incoming http.Header) *http.Request {
t.Helper()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
ginReq := httptest.NewRequest(http.MethodPost, "http://localhost/v1/messages", nil)
ginReq.Header = incoming.Clone()
ginCtx.Request = ginReq
req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil)
return req.WithContext(context.WithValue(req.Context(), "gin", ginCtx))
}
func assertClaudeFingerprint(t *testing.T, headers http.Header, userAgent, pkgVersion, runtimeVersion, osName, arch string) {
t.Helper()
if got := headers.Get("User-Agent"); got != userAgent {
t.Fatalf("User-Agent = %q, want %q", got, userAgent)
}
if got := headers.Get("X-Stainless-Package-Version"); got != pkgVersion {
t.Fatalf("X-Stainless-Package-Version = %q, want %q", got, pkgVersion)
}
if got := headers.Get("X-Stainless-Runtime-Version"); got != runtimeVersion {
t.Fatalf("X-Stainless-Runtime-Version = %q, want %q", got, runtimeVersion)
}
if got := headers.Get("X-Stainless-Os"); got != osName {
t.Fatalf("X-Stainless-Os = %q, want %q", got, osName)
}
if got := headers.Get("X-Stainless-Arch"); got != arch {
t.Fatalf("X-Stainless-Arch = %q, want %q", got, arch)
}
}
func TestApplyClaudeHeaders_UsesConfiguredBaselineFingerprint(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := true
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.70 (external, cli)",
PackageVersion: "0.80.0",
RuntimeVersion: "v24.5.0",
OS: "MacOS",
Arch: "arm64",
Timeout: "900",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-baseline",
Attributes: map[string]string{
"api_key": "key-baseline",
"header:User-Agent": "evil-client/9.9",
"header:X-Stainless-Os": "Linux",
"header:X-Stainless-Arch": "x64",
"header:X-Stainless-Package-Version": "9.9.9",
},
}
incoming := http.Header{
"User-Agent": []string{"curl/8.7.1"},
"X-Stainless-Package-Version": []string{"0.10.0"},
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
}
req := newClaudeHeaderTestRequest(t, incoming)
applyClaudeHeaders(req, auth, "key-baseline", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
if got := req.Header.Get("X-Stainless-Timeout"); got != "900" {
t.Fatalf("X-Stainless-Timeout = %q, want %q", got, "900")
}
}
func TestApplyClaudeHeaders_TracksHighestClaudeCLIFingerprint(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := true
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.60 (external, cli)",
PackageVersion: "0.70.0",
RuntimeVersion: "v22.0.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-upgrade",
Attributes: map[string]string{
"api_key": "key-upgrade",
},
}
firstReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.74.0"},
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(firstReq, auth, "key-upgrade", false, nil, cfg)
assertClaudeFingerprint(t, firstReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "MacOS", "arm64")
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"lobe-chat/1.0"},
"X-Stainless-Package-Version": []string{"0.10.0"},
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
"X-Stainless-Os": []string{"Windows"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(thirdPartyReq, auth, "key-upgrade", false, nil, cfg)
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "MacOS", "arm64")
higherReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.75.0"},
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
"X-Stainless-Os": []string{"MacOS"},
"X-Stainless-Arch": []string{"arm64"},
})
applyClaudeHeaders(higherReq, auth, "key-upgrade", false, nil, cfg)
assertClaudeFingerprint(t, higherReq.Header, "claude-cli/2.1.63 (external, cli)", "0.75.0", "v24.4.0", "MacOS", "arm64")
lowerReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.61 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.73.0"},
"X-Stainless-Runtime-Version": []string{"v24.2.0"},
"X-Stainless-Os": []string{"Windows"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(lowerReq, auth, "key-upgrade", false, nil, cfg)
assertClaudeFingerprint(t, lowerReq.Header, "claude-cli/2.1.63 (external, cli)", "0.75.0", "v24.4.0", "MacOS", "arm64")
}
func TestApplyClaudeHeaders_DoesNotDowngradeConfiguredBaselineOnFirstClaudeClient(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := true
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.70 (external, cli)",
PackageVersion: "0.80.0",
RuntimeVersion: "v24.5.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-baseline-floor",
Attributes: map[string]string{
"api_key": "key-baseline-floor",
},
}
olderClaudeReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.74.0"},
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(olderClaudeReq, auth, "key-baseline-floor", false, nil, cfg)
assertClaudeFingerprint(t, olderClaudeReq.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
newerClaudeReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.71 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.81.0"},
"X-Stainless-Runtime-Version": []string{"v24.6.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(newerClaudeReq, auth, "key-baseline-floor", false, nil, cfg)
assertClaudeFingerprint(t, newerClaudeReq.Header, "claude-cli/2.1.71 (external, cli)", "0.81.0", "v24.6.0", "MacOS", "arm64")
}
func TestApplyClaudeHeaders_UpgradesCachedSoftwareFingerprintWhenBaselineAdvances(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := true
oldCfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.70 (external, cli)",
PackageVersion: "0.80.0",
RuntimeVersion: "v24.5.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
newCfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.77 (external, cli)",
PackageVersion: "0.87.0",
RuntimeVersion: "v24.8.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-baseline-reload",
Attributes: map[string]string{
"api_key": "key-baseline-reload",
},
}
officialReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.71 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.81.0"},
"X-Stainless-Runtime-Version": []string{"v24.6.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(officialReq, auth, "key-baseline-reload", false, nil, oldCfg)
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.71 (external, cli)", "0.81.0", "v24.6.0", "MacOS", "arm64")
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"curl/8.7.1"},
"X-Stainless-Package-Version": []string{"0.10.0"},
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(thirdPartyReq, auth, "key-baseline-reload", false, nil, newCfg)
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
}
func TestApplyClaudeHeaders_LearnsOfficialFingerprintAfterCustomBaselineFallback(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := true
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "my-gateway/1.0",
PackageVersion: "custom-pkg",
RuntimeVersion: "custom-runtime",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-custom-baseline-learning",
Attributes: map[string]string{
"api_key": "key-custom-baseline-learning",
},
}
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"curl/8.7.1"},
"X-Stainless-Package-Version": []string{"0.10.0"},
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(thirdPartyReq, auth, "key-custom-baseline-learning", false, nil, cfg)
assertClaudeFingerprint(t, thirdPartyReq.Header, "my-gateway/1.0", "custom-pkg", "custom-runtime", "MacOS", "arm64")
officialReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.77 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.87.0"},
"X-Stainless-Runtime-Version": []string{"v24.8.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(officialReq, auth, "key-custom-baseline-learning", false, nil, cfg)
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
postLearningThirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"curl/8.7.1"},
"X-Stainless-Package-Version": []string{"0.10.0"},
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(postLearningThirdPartyReq, auth, "key-custom-baseline-learning", false, nil, cfg)
assertClaudeFingerprint(t, postLearningThirdPartyReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
}
func TestResolveClaudeDeviceProfile_RechecksCacheBeforeStoringCandidate(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := true
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.60 (external, cli)",
PackageVersion: "0.70.0",
RuntimeVersion: "v22.0.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-racy-upgrade",
Attributes: map[string]string{
"api_key": "key-racy-upgrade",
},
}
lowPaused := make(chan struct{})
releaseLow := make(chan struct{})
var pauseOnce sync.Once
var releaseOnce sync.Once
claudeDeviceProfileBeforeCandidateStore = func(candidate claudeDeviceProfile) {
if candidate.UserAgent != "claude-cli/2.1.62 (external, cli)" {
return
}
pauseOnce.Do(func() { close(lowPaused) })
<-releaseLow
}
t.Cleanup(func() {
claudeDeviceProfileBeforeCandidateStore = nil
releaseOnce.Do(func() { close(releaseLow) })
})
lowResultCh := make(chan claudeDeviceProfile, 1)
go func() {
lowResultCh <- resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.74.0"},
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
}, cfg)
}()
select {
case <-lowPaused:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for lower candidate to pause before storing")
}
highResult := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"claude-cli/2.1.63 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.75.0"},
"X-Stainless-Runtime-Version": []string{"v24.4.0"},
"X-Stainless-Os": []string{"MacOS"},
"X-Stainless-Arch": []string{"arm64"},
}, cfg)
releaseOnce.Do(func() { close(releaseLow) })
select {
case lowResult := <-lowResultCh:
if lowResult.UserAgent != "claude-cli/2.1.63 (external, cli)" {
t.Fatalf("lowResult.UserAgent = %q, want %q", lowResult.UserAgent, "claude-cli/2.1.63 (external, cli)")
}
if lowResult.PackageVersion != "0.75.0" {
t.Fatalf("lowResult.PackageVersion = %q, want %q", lowResult.PackageVersion, "0.75.0")
}
if lowResult.OS != "MacOS" || lowResult.Arch != "arm64" {
t.Fatalf("lowResult platform = %s/%s, want %s/%s", lowResult.OS, lowResult.Arch, "MacOS", "arm64")
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for lower candidate result")
}
if highResult.UserAgent != "claude-cli/2.1.63 (external, cli)" {
t.Fatalf("highResult.UserAgent = %q, want %q", highResult.UserAgent, "claude-cli/2.1.63 (external, cli)")
}
if highResult.OS != "MacOS" || highResult.Arch != "arm64" {
t.Fatalf("highResult platform = %s/%s, want %s/%s", highResult.OS, highResult.Arch, "MacOS", "arm64")
}
cached := resolveClaudeDeviceProfile(auth, "key-racy-upgrade", http.Header{
"User-Agent": []string{"curl/8.7.1"},
}, cfg)
if cached.UserAgent != "claude-cli/2.1.63 (external, cli)" {
t.Fatalf("cached.UserAgent = %q, want %q", cached.UserAgent, "claude-cli/2.1.63 (external, cli)")
}
if cached.PackageVersion != "0.75.0" {
t.Fatalf("cached.PackageVersion = %q, want %q", cached.PackageVersion, "0.75.0")
}
if cached.OS != "MacOS" || cached.Arch != "arm64" {
t.Fatalf("cached platform = %s/%s, want %s/%s", cached.OS, cached.Arch, "MacOS", "arm64")
}
}
func TestApplyClaudeHeaders_ThirdPartyBaselineThenOfficialUpgradeKeepsPinnedPlatform(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := true
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.70 (external, cli)",
PackageVersion: "0.80.0",
RuntimeVersion: "v24.5.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-third-party-then-official",
Attributes: map[string]string{
"api_key": "key-third-party-then-official",
},
}
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"curl/8.7.1"},
"X-Stainless-Package-Version": []string{"0.10.0"},
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(thirdPartyReq, auth, "key-third-party-then-official", false, nil, cfg)
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.70 (external, cli)", "0.80.0", "v24.5.0", "MacOS", "arm64")
officialReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.77 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.87.0"},
"X-Stainless-Runtime-Version": []string{"v24.8.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(officialReq, auth, "key-third-party-then-official", false, nil, cfg)
assertClaudeFingerprint(t, officialReq.Header, "claude-cli/2.1.77 (external, cli)", "0.87.0", "v24.8.0", "MacOS", "arm64")
}
func TestApplyClaudeHeaders_DisableDeviceProfileStabilization(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := false
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.60 (external, cli)",
PackageVersion: "0.70.0",
RuntimeVersion: "v22.0.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-disable-stability",
Attributes: map[string]string{
"api_key": "key-disable-stability",
},
}
firstReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.74.0"},
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(firstReq, auth, "key-disable-stability", false, nil, cfg)
assertClaudeFingerprint(t, firstReq.Header, "claude-cli/2.1.62 (external, cli)", "0.74.0", "v24.3.0", "Linux", "x64")
thirdPartyReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"lobe-chat/1.0"},
"X-Stainless-Package-Version": []string{"0.10.0"},
"X-Stainless-Runtime-Version": []string{"v18.0.0"},
"X-Stainless-Os": []string{"Windows"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(thirdPartyReq, auth, "key-disable-stability", false, nil, cfg)
assertClaudeFingerprint(t, thirdPartyReq.Header, "claude-cli/2.1.60 (external, cli)", "0.10.0", "v18.0.0", "Windows", "x64")
lowerReq := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.61 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.73.0"},
"X-Stainless-Runtime-Version": []string{"v24.2.0"},
"X-Stainless-Os": []string{"Windows"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(lowerReq, auth, "key-disable-stability", false, nil, cfg)
assertClaudeFingerprint(t, lowerReq.Header, "claude-cli/2.1.61 (external, cli)", "0.73.0", "v24.2.0", "Windows", "x64")
}
func TestApplyClaudeHeaders_LegacyModePreservesConfiguredUserAgentOverrideForClaudeClients(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := false
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.60 (external, cli)",
PackageVersion: "0.70.0",
RuntimeVersion: "v22.0.0",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-legacy-ua-override",
Attributes: map[string]string{
"api_key": "key-legacy-ua-override",
"header:User-Agent": "config-ua/1.0",
},
}
req := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"claude-cli/2.1.62 (external, cli)"},
"X-Stainless-Package-Version": []string{"0.74.0"},
"X-Stainless-Runtime-Version": []string{"v24.3.0"},
"X-Stainless-Os": []string{"Linux"},
"X-Stainless-Arch": []string{"x64"},
})
applyClaudeHeaders(req, auth, "key-legacy-ua-override", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "config-ua/1.0", "0.74.0", "v24.3.0", "Linux", "x64")
}
func TestApplyClaudeHeaders_LegacyModeFallsBackToRuntimeOSArchWhenMissing(t *testing.T) {
resetClaudeDeviceProfileCache()
stabilize := false
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.60 (external, cli)",
PackageVersion: "0.70.0",
RuntimeVersion: "v22.0.0",
OS: "MacOS",
Arch: "arm64",
StabilizeDeviceProfile: &stabilize,
},
}
auth := &cliproxyauth.Auth{
ID: "auth-legacy-runtime-os-arch",
Attributes: map[string]string{
"api_key": "key-legacy-runtime-os-arch",
},
}
req := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"curl/8.7.1"},
})
applyClaudeHeaders(req, auth, "key-legacy-runtime-os-arch", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
}
func TestApplyClaudeHeaders_UnsetStabilizationAlsoUsesLegacyRuntimeOSArchFallback(t *testing.T) {
resetClaudeDeviceProfileCache()
cfg := &config.Config{
ClaudeHeaderDefaults: config.ClaudeHeaderDefaults{
UserAgent: "claude-cli/2.1.60 (external, cli)",
PackageVersion: "0.70.0",
RuntimeVersion: "v22.0.0",
OS: "MacOS",
Arch: "arm64",
},
}
auth := &cliproxyauth.Auth{
ID: "auth-unset-runtime-os-arch",
Attributes: map[string]string{
"api_key": "key-unset-runtime-os-arch",
},
}
req := newClaudeHeaderTestRequest(t, http.Header{
"User-Agent": []string{"curl/8.7.1"},
})
applyClaudeHeaders(req, auth, "key-unset-runtime-os-arch", false, nil, cfg)
assertClaudeFingerprint(t, req.Header, "claude-cli/2.1.60 (external, cli)", "0.70.0", "v22.0.0", mapStainlessOS(), mapStainlessArch())
}
func TestClaudeDeviceProfileStabilizationEnabled_DefaultFalse(t *testing.T) {
if claudeDeviceProfileStabilizationEnabled(nil) {
t.Fatal("expected nil config to default to disabled stabilization")
}
if claudeDeviceProfileStabilizationEnabled(&config.Config{}) {
t.Fatal("expected unset stabilize-device-profile to default to disabled stabilization")
}
}
func TestApplyClaudeToolPrefix(t *testing.T) {
input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")

View File

@@ -0,0 +1,343 @@
package executor
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/http"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
)
const (
codeBuddyChatPath = "/v2/chat/completions"
codeBuddyAuthType = "codebuddy"
)
// CodeBuddyExecutor handles requests to the CodeBuddy API.
type CodeBuddyExecutor struct {
cfg *config.Config
}
// NewCodeBuddyExecutor creates a new CodeBuddy executor instance.
func NewCodeBuddyExecutor(cfg *config.Config) *CodeBuddyExecutor {
return &CodeBuddyExecutor{cfg: cfg}
}
// Identifier returns the unique identifier for this executor.
func (e *CodeBuddyExecutor) Identifier() string { return codeBuddyAuthType }
// codeBuddyCredentials extracts the access token and domain from auth metadata.
func codeBuddyCredentials(auth *cliproxyauth.Auth) (accessToken, userID, domain string) {
if auth == nil {
return "", "", ""
}
accessToken = metaStringValue(auth.Metadata, "access_token")
userID = metaStringValue(auth.Metadata, "user_id")
domain = metaStringValue(auth.Metadata, "domain")
if domain == "" {
domain = codebuddy.DefaultDomain
}
return
}
// PrepareRequest prepares the HTTP request before execution.
func (e *CodeBuddyExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
accessToken, userID, domain := codeBuddyCredentials(auth)
if accessToken == "" {
return fmt.Errorf("codebuddy: missing access token")
}
e.applyHeaders(req, accessToken, userID, domain)
return nil
}
// HttpRequest executes a raw HTTP request.
func (e *CodeBuddyExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("codebuddy executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
// Execute performs a non-streaming request.
func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
accessToken, userID, domain := codeBuddyCredentials(auth)
if accessToken == "" {
return resp, fmt.Errorf("codebuddy: missing access token")
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, false)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
url := codebuddy.BaseURL + codeBuddyChatPath
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
return resp, err
}
e.applyHeaders(httpReq, accessToken, userID, domain)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translated,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codebuddy executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if !isHTTPSuccess(httpResp.StatusCode) {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("codebuddy executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
body, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, body)
reporter.publish(ctx, parseOpenAIUsage(body))
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming request.
func (e *CodeBuddyExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
accessToken, userID, domain := codeBuddyCredentials(auth)
if accessToken == "" {
return nil, fmt.Errorf("codebuddy: missing access token")
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
url := codebuddy.BaseURL + codeBuddyChatPath
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
return nil, err
}
e.applyHeaders(httpReq, accessToken, userID, domain)
httpReq.Header.Set("Accept", "text/event-stream")
httpReq.Header.Set("Cache-Control", "no-cache")
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translated,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if !isHTTPSuccess(httpResp.StatusCode) {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
httpResp.Body.Close()
log.Debugf("codebuddy executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codebuddy executor: close stream body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, maxScannerBufferSize)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
reporter.ensurePublished(ctx)
}()
return &cliproxyexecutor.StreamResult{
Headers: httpResp.Header.Clone(),
Chunks: out,
}, nil
}
// Refresh exchanges the CodeBuddy refresh token for a new access token.
func (e *CodeBuddyExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
if auth == nil {
return nil, fmt.Errorf("codebuddy: missing auth")
}
refreshToken := metaStringValue(auth.Metadata, "refresh_token")
if refreshToken == "" {
log.Debugf("codebuddy executor: no refresh token available, skipping refresh")
return auth, nil
}
accessToken, userID, domain := codeBuddyCredentials(auth)
authSvc := codebuddy.NewCodeBuddyAuth(e.cfg)
storage, err := authSvc.RefreshToken(ctx, accessToken, refreshToken, userID, domain)
if err != nil {
return nil, fmt.Errorf("codebuddy: token refresh failed: %w", err)
}
updated := auth.Clone()
updated.Metadata["access_token"] = storage.AccessToken
if storage.RefreshToken != "" {
updated.Metadata["refresh_token"] = storage.RefreshToken
}
updated.Metadata["expires_in"] = storage.ExpiresIn
updated.Metadata["domain"] = storage.Domain
if storage.UserID != "" {
updated.Metadata["user_id"] = storage.UserID
}
now := time.Now()
updated.UpdatedAt = now
updated.LastRefreshedAt = now
return updated, nil
}
// CountTokens is not supported for CodeBuddy.
func (e *CodeBuddyExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, fmt.Errorf("codebuddy: count tokens not supported")
}
// applyHeaders sets required headers for CodeBuddy API requests.
func (e *CodeBuddyExecutor) applyHeaders(req *http.Request, accessToken, userID, domain string) {
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", codebuddy.UserAgent)
req.Header.Set("X-User-Id", userID)
req.Header.Set("X-Domain", domain)
req.Header.Set("X-Product", "SaaS")
req.Header.Set("X-IDE-Type", "CLI")
req.Header.Set("X-IDE-Name", "CLI")
req.Header.Set("X-IDE-Version", "2.63.2")
req.Header.Set("X-Requested-With", "XMLHttpRequest")
}

View File

@@ -0,0 +1,125 @@
package executor
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/google/uuid"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
type codexContinuity struct {
Key string
Source string
}
func metadataString(meta map[string]any, key string) string {
if len(meta) == 0 {
return ""
}
raw, ok := meta[key]
if !ok || raw == nil {
return ""
}
switch v := raw.(type) {
case string:
return strings.TrimSpace(v)
case []byte:
return strings.TrimSpace(string(v))
default:
return ""
}
}
func principalString(raw any) string {
switch v := raw.(type) {
case string:
return strings.TrimSpace(v)
case fmt.Stringer:
return strings.TrimSpace(v.String())
default:
return strings.TrimSpace(fmt.Sprintf("%v", raw))
}
}
func resolveCodexContinuity(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) codexContinuity {
if promptCacheKey := strings.TrimSpace(gjson.GetBytes(req.Payload, "prompt_cache_key").String()); promptCacheKey != "" {
return codexContinuity{Key: promptCacheKey, Source: "prompt_cache_key"}
}
if executionSession := metadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); executionSession != "" {
return codexContinuity{Key: executionSession, Source: "execution_session"}
}
if ginCtx := ginContextFrom(ctx); ginCtx != nil {
if ginCtx.Request != nil {
if v := strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key")); v != "" {
return codexContinuity{Key: v, Source: "idempotency_key"}
}
}
if v, exists := ginCtx.Get("apiKey"); exists && v != nil {
if trimmed := principalString(v); trimmed != "" {
return codexContinuity{Key: uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+trimmed)).String(), Source: "client_principal"}
}
}
}
if auth != nil {
if authID := strings.TrimSpace(auth.ID); authID != "" {
return codexContinuity{Key: uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:auth:"+authID)).String(), Source: "auth_id"}
}
}
return codexContinuity{}
}
func applyCodexContinuityBody(rawJSON []byte, continuity codexContinuity) []byte {
if continuity.Key == "" {
return rawJSON
}
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", continuity.Key)
return rawJSON
}
func applyCodexContinuityHeaders(headers http.Header, continuity codexContinuity) {
if headers == nil || continuity.Key == "" {
return
}
headers.Set("session_id", continuity.Key)
}
func logCodexRequestDiagnostics(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, headers http.Header, body []byte, continuity codexContinuity) {
if !log.IsLevelEnabled(log.DebugLevel) {
return
}
entry := logWithRequestID(ctx)
authID := ""
authFile := ""
if auth != nil {
authID = strings.TrimSpace(auth.ID)
authFile = strings.TrimSpace(auth.FileName)
}
selectedAuthID := metadataString(opts.Metadata, cliproxyexecutor.SelectedAuthMetadataKey)
executionSessionID := metadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey)
entry.Debugf(
"codex request diagnostics auth_id=%s selected_auth_id=%s auth_file=%s exec_session=%s continuity_source=%s session_id=%s prompt_cache_key=%s prompt_cache_retention=%s store=%t has_instructions=%t reasoning_effort=%s reasoning_summary=%s chatgpt_account_id=%t originator=%s model=%s source_format=%s",
authID,
selectedAuthID,
authFile,
executionSessionID,
continuity.Source,
strings.TrimSpace(headers.Get("session_id")),
gjson.GetBytes(body, "prompt_cache_key").String(),
gjson.GetBytes(body, "prompt_cache_retention").String(),
gjson.GetBytes(body, "store").Bool(),
gjson.GetBytes(body, "instructions").Exists(),
gjson.GetBytes(body, "reasoning.effort").String(),
gjson.GetBytes(body, "reasoning.summary").String(),
strings.TrimSpace(headers.Get("Chatgpt-Account-Id")) != "",
strings.TrimSpace(headers.Get("Originator")),
req.Model,
opts.SourceFormat.String(),
)
}

View File

@@ -28,8 +28,8 @@ import (
)
const (
codexClientVersion = "0.101.0"
codexUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
codexUserAgent = "codex_cli_rs/0.116.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
codexOriginator = "codex_cli_rs"
)
var dataTag = []byte("data:")
@@ -111,18 +111,19 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
}
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
if err != nil {
return resp, err
}
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -183,7 +184,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
@@ -222,11 +223,12 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
body, _ = sjson.DeleteBytes(body, "stream")
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
if err != nil {
return resp, err
}
applyCodexHeaders(httpReq, auth, apiKey, false, e.cfg)
logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -273,7 +275,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -309,19 +311,20 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
body, _ = sjson.SetBytes(body, "model", baseModel)
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
}
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
if err != nil {
return nil, err
}
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -387,7 +390,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
if errScan := scanner.Err(); errScan != nil {
@@ -415,6 +418,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
body, _ = sjson.SetBytes(body, "stream", false)
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
@@ -432,7 +436,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, []byte(usageJSON))
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: translated}, nil
}
func tokenizerForCodexModel(model string) (tokenizer.Codec, error) {
@@ -596,8 +600,9 @@ func (e *CodexExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*
return auth, nil
}
func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Format, url string, req cliproxyexecutor.Request, rawJSON []byte) (*http.Request, error) {
func (e *CodexExecutor) cacheHelper(ctx context.Context, auth *cliproxyauth.Auth, from sdktranslator.Format, url string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, rawJSON []byte) (*http.Request, codexContinuity, error) {
var cache codexCache
continuity := codexContinuity{}
if from == "claude" {
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
if userIDResult.Exists() {
@@ -610,30 +615,26 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
}
setCodexCache(key, cache)
}
continuity = codexContinuity{Key: cache.ID, Source: "claude_user_cache"}
}
} else if from == "openai-response" {
promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key")
if promptCacheKey.Exists() {
cache.ID = promptCacheKey.String()
continuity = codexContinuity{Key: cache.ID, Source: "prompt_cache_key"}
}
} else if from == "openai" {
if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" {
cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
}
continuity = resolveCodexContinuity(ctx, auth, req, opts)
cache.ID = continuity.Key
}
if cache.ID != "" {
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
}
rawJSON = applyCodexContinuityBody(rawJSON, continuity)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(rawJSON))
if err != nil {
return nil, err
return nil, continuity, err
}
if cache.ID != "" {
httpReq.Header.Set("Conversation_id", cache.ID)
httpReq.Header.Set("Session_id", cache.ID)
}
return httpReq, nil
applyCodexContinuityHeaders(httpReq.Header, continuity)
return httpReq, continuity, nil
}
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, cfg *config.Config) {
@@ -645,8 +646,10 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
ginHeaders = ginCtx.Request.Header
}
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(r.Header, ginHeaders, "Version", "")
misc.EnsureHeader(r.Header, ginHeaders, "session_id", uuid.NewString())
misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "")
misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "")
cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
ensureHeaderWithConfigPrecedence(r.Header, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
@@ -663,8 +666,12 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
isAPIKey = true
}
}
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
r.Header.Set("Originator", originator)
} else if !isAPIKey {
r.Header.Set("Originator", codexOriginator)
}
if !isAPIKey {
r.Header.Set("Originator", "codex_cli_rs")
if auth != nil && auth.Metadata != nil {
if accountID, ok := auth.Metadata["account_id"].(string); ok {
r.Header.Set("Chatgpt-Account-Id", accountID)
@@ -679,13 +686,39 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
}
func newCodexStatusErr(statusCode int, body []byte) statusErr {
err := statusErr{code: statusCode, msg: string(body)}
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
errCode := statusCode
if isCodexModelCapacityError(body) {
errCode = http.StatusTooManyRequests
}
err := statusErr{code: errCode, msg: string(body)}
if retryAfter := parseCodexRetryAfter(errCode, body, time.Now()); retryAfter != nil {
err.retryAfter = retryAfter
}
return err
}
func isCodexModelCapacityError(errorBody []byte) bool {
if len(errorBody) == 0 {
return false
}
candidates := []string{
gjson.GetBytes(errorBody, "error.message").String(),
gjson.GetBytes(errorBody, "message").String(),
string(errorBody),
}
for _, candidate := range candidates {
lower := strings.ToLower(strings.TrimSpace(candidate))
if lower == "" {
continue
}
if strings.Contains(lower, "selected model is at capacity") ||
strings.Contains(lower, "model is at capacity. please try a different model") {
return true
}
}
return false
}
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
return nil

View File

@@ -8,6 +8,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
@@ -27,7 +28,7 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
}
url := "https://example.com/responses"
httpReq, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
httpReq, _, err := executor.cacheHelper(ctx, nil, sdktranslator.FromString("openai"), url, req, cliproxyexecutor.Options{}, rawJSON)
if err != nil {
t.Fatalf("cacheHelper error: %v", err)
}
@@ -42,14 +43,14 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
if gotKey != expectedKey {
t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey)
}
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != expectedKey {
t.Fatalf("Conversation_id = %q, want %q", gotConversation, expectedKey)
if gotSession := httpReq.Header.Get("session_id"); gotSession != expectedKey {
t.Fatalf("session_id = %q, want %q", gotSession, expectedKey)
}
if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey {
t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey)
if got := httpReq.Header.Get("Conversation_id"); got != "" {
t.Fatalf("Conversation_id = %q, want empty", got)
}
httpReq2, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
httpReq2, _, err := executor.cacheHelper(ctx, nil, sdktranslator.FromString("openai"), url, req, cliproxyexecutor.Options{}, rawJSON)
if err != nil {
t.Fatalf("cacheHelper error (second call): %v", err)
}
@@ -62,3 +63,118 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
t.Fatalf("prompt_cache_key (second call) = %q, want %q", gotKey2, expectedKey)
}
}
func TestCodexExecutorCacheHelper_OpenAIResponses_PreservesPromptCacheRetention(t *testing.T) {
executor := &CodexExecutor{}
url := "https://example.com/responses"
req := cliproxyexecutor.Request{
Model: "gpt-5.3-codex",
Payload: []byte(`{"model":"gpt-5.3-codex","prompt_cache_key":"cache-key-1","prompt_cache_retention":"persistent"}`),
}
rawJSON := []byte(`{"model":"gpt-5.3-codex","stream":true,"prompt_cache_retention":"persistent"}`)
httpReq, _, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("openai-response"), url, req, cliproxyexecutor.Options{}, rawJSON)
if err != nil {
t.Fatalf("cacheHelper error: %v", err)
}
body, err := io.ReadAll(httpReq.Body)
if err != nil {
t.Fatalf("read request body: %v", err)
}
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != "cache-key-1" {
t.Fatalf("prompt_cache_key = %q, want %q", got, "cache-key-1")
}
if got := gjson.GetBytes(body, "prompt_cache_retention").String(); got != "persistent" {
t.Fatalf("prompt_cache_retention = %q, want %q", got, "persistent")
}
if got := httpReq.Header.Get("session_id"); got != "cache-key-1" {
t.Fatalf("session_id = %q, want %q", got, "cache-key-1")
}
if got := httpReq.Header.Get("Conversation_id"); got != "" {
t.Fatalf("Conversation_id = %q, want empty", got)
}
}
func TestCodexExecutorCacheHelper_OpenAIChatCompletions_UsesExecutionSessionForContinuity(t *testing.T) {
executor := &CodexExecutor{}
rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)
req := cliproxyexecutor.Request{
Model: "gpt-5.4",
Payload: []byte(`{"model":"gpt-5.4"}`),
}
opts := cliproxyexecutor.Options{Metadata: map[string]any{cliproxyexecutor.ExecutionSessionMetadataKey: "exec-session-1"}}
httpReq, _, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("openai"), "https://example.com/responses", req, opts, rawJSON)
if err != nil {
t.Fatalf("cacheHelper error: %v", err)
}
body, err := io.ReadAll(httpReq.Body)
if err != nil {
t.Fatalf("read request body: %v", err)
}
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != "exec-session-1" {
t.Fatalf("prompt_cache_key = %q, want %q", got, "exec-session-1")
}
if got := httpReq.Header.Get("session_id"); got != "exec-session-1" {
t.Fatalf("session_id = %q, want %q", got, "exec-session-1")
}
}
func TestCodexExecutorCacheHelper_OpenAIChatCompletions_FallsBackToStableAuthID(t *testing.T) {
executor := &CodexExecutor{}
rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)
req := cliproxyexecutor.Request{
Model: "gpt-5.4",
Payload: []byte(`{"model":"gpt-5.4"}`),
}
auth := &cliproxyauth.Auth{ID: "codex-auth-1", Provider: "codex"}
httpReq, _, err := executor.cacheHelper(context.Background(), auth, sdktranslator.FromString("openai"), "https://example.com/responses", req, cliproxyexecutor.Options{}, rawJSON)
if err != nil {
t.Fatalf("cacheHelper error: %v", err)
}
body, err := io.ReadAll(httpReq.Body)
if err != nil {
t.Fatalf("read request body: %v", err)
}
expected := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:auth:codex-auth-1")).String()
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != expected {
t.Fatalf("prompt_cache_key = %q, want %q", got, expected)
}
if got := httpReq.Header.Get("session_id"); got != expected {
t.Fatalf("session_id = %q, want %q", got, expected)
}
}
func TestCodexExecutorCacheHelper_ClaudePreservesCacheContinuity(t *testing.T) {
executor := &CodexExecutor{}
req := cliproxyexecutor.Request{
Model: "claude-3-7-sonnet",
Payload: []byte(`{"metadata":{"user_id":"user-1"}}`),
}
rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)
httpReq, continuity, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("claude"), "https://example.com/responses", req, cliproxyexecutor.Options{}, rawJSON)
if err != nil {
t.Fatalf("cacheHelper error: %v", err)
}
if continuity.Key == "" {
t.Fatal("continuity.Key = empty, want non-empty")
}
body, err := io.ReadAll(httpReq.Body)
if err != nil {
t.Fatalf("read request body: %v", err)
}
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != continuity.Key {
t.Fatalf("prompt_cache_key = %q, want %q", got, continuity.Key)
}
if got := httpReq.Header.Get("session_id"); got != continuity.Key {
t.Fatalf("session_id = %q, want %q", got, continuity.Key)
}
}

View File

@@ -60,6 +60,19 @@ func TestParseCodexRetryAfter(t *testing.T) {
})
}
func TestNewCodexStatusErrTreatsCapacityAsRetryableRateLimit(t *testing.T) {
body := []byte(`{"error":{"message":"Selected model is at capacity. Please try a different model."}}`)
err := newCodexStatusErr(http.StatusBadRequest, body)
if got := err.StatusCode(); got != http.StatusTooManyRequests {
t.Fatalf("status code = %d, want %d", got, http.StatusTooManyRequests)
}
if err.RetryAfter() != nil {
t.Fatalf("expected nil explicit retryAfter for capacity fallback, got %v", *err.RetryAfter())
}
}
func itoa(v int64) string {
return strconv.FormatInt(v, 10)
}

View File

@@ -178,7 +178,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
@@ -190,7 +189,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
return resp, err
}
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
body, wsHeaders, continuity := applyCodexPromptCacheHeaders(ctx, auth, from, req, opts, body)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
var authID, authLabel, authType, authValue string
@@ -209,6 +208,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
}
wsReqBody := buildCodexWebsocketRequestBody(body)
logCodexRequestDiagnostics(ctx, auth, req, opts, wsHeaders, body, continuity)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
@@ -343,7 +343,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
}
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: out}
return resp, nil
}
}
@@ -385,7 +385,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
return nil, err
}
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
body, wsHeaders, continuity := applyCodexPromptCacheHeaders(ctx, auth, from, req, opts, body)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
var authID, authLabel, authType, authValue string
@@ -403,6 +403,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
}
wsReqBody := buildCodexWebsocketRequestBody(body)
logCodexRequestDiagnostics(ctx, auth, req, opts, wsHeaders, body, continuity)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
@@ -592,7 +593,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
line := encodeCodexWebsocketAsSSE(payload)
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, body, body, line, &param)
for i := range chunks {
if !send(cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}) {
if !send(cliproxyexecutor.StreamChunk{Payload: chunks[i]}) {
terminateReason = "context_done"
terminateErr = ctx.Err()
return
@@ -761,13 +762,14 @@ func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
return parsed.String(), nil
}
func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecutor.Request, rawJSON []byte) ([]byte, http.Header) {
func applyCodexPromptCacheHeaders(ctx context.Context, auth *cliproxyauth.Auth, from sdktranslator.Format, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, rawJSON []byte) ([]byte, http.Header, codexContinuity) {
headers := http.Header{}
if len(rawJSON) == 0 {
return rawJSON, headers
return rawJSON, headers, codexContinuity{}
}
var cache codexCache
continuity := codexContinuity{}
if from == "claude" {
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
if userIDResult.Exists() {
@@ -781,20 +783,22 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
}
setCodexCache(key, cache)
}
continuity = codexContinuity{Key: cache.ID, Source: "claude_user_cache"}
}
} else if from == "openai-response" {
if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() {
cache.ID = promptCacheKey.String()
continuity = codexContinuity{Key: cache.ID, Source: "prompt_cache_key"}
}
} else if from == "openai" {
continuity = resolveCodexContinuity(ctx, auth, req, opts)
cache.ID = continuity.Key
}
if cache.ID != "" {
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
headers.Set("Conversation_id", cache.ID)
headers.Set("Session_id", cache.ID)
}
rawJSON = applyCodexContinuityBody(rawJSON, continuity)
applyCodexContinuityHeaders(headers, continuity)
return rawJSON, headers
return rawJSON, headers, continuity
}
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header {
@@ -814,9 +818,10 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
misc.EnsureHeader(headers, ginHeaders, "x-client-request-id", "")
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
misc.EnsureHeader(headers, ginHeaders, "Version", "")
misc.EnsureHeader(headers, ginHeaders, "Version", codexClientVersion)
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
if betaHeader == "" && ginHeaders != nil {
betaHeader = strings.TrimSpace(ginHeaders.Get("OpenAI-Beta"))
@@ -825,7 +830,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
betaHeader = codexResponsesWebsocketBetaHeaderValue
}
headers.Set("OpenAI-Beta", betaHeader)
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(headers, ginHeaders, "session_id", uuid.NewString())
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
isAPIKey := false
@@ -834,8 +839,12 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
isAPIKey = true
}
}
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
headers.Set("Originator", originator)
} else if !isAPIKey {
headers.Set("Originator", codexOriginator)
}
if !isAPIKey {
headers.Set("Originator", "codex_cli_rs")
if auth != nil && auth.Metadata != nil {
if accountID, ok := auth.Metadata["account_id"].(string); ok {
if trimmed := strings.TrimSpace(accountID); trimmed != "" {

View File

@@ -9,7 +9,9 @@ import (
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
@@ -32,6 +34,49 @@ func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T)
}
}
func TestApplyCodexPromptCacheHeaders_PreservesPromptCacheRetention(t *testing.T) {
req := cliproxyexecutor.Request{
Model: "gpt-5-codex",
Payload: []byte(`{"prompt_cache_key":"cache-key-1","prompt_cache_retention":"persistent"}`),
}
body := []byte(`{"model":"gpt-5-codex","stream":true,"prompt_cache_retention":"persistent"}`)
updatedBody, headers, _ := applyCodexPromptCacheHeaders(context.Background(), nil, sdktranslator.FromString("openai-response"), req, cliproxyexecutor.Options{}, body)
if got := gjson.GetBytes(updatedBody, "prompt_cache_key").String(); got != "cache-key-1" {
t.Fatalf("prompt_cache_key = %q, want %q", got, "cache-key-1")
}
if got := gjson.GetBytes(updatedBody, "prompt_cache_retention").String(); got != "persistent" {
t.Fatalf("prompt_cache_retention = %q, want %q", got, "persistent")
}
if got := headers.Get("session_id"); got != "cache-key-1" {
t.Fatalf("session_id = %q, want %q", got, "cache-key-1")
}
if got := headers.Get("Conversation_id"); got != "" {
t.Fatalf("Conversation_id = %q, want empty", got)
}
}
func TestApplyCodexPromptCacheHeaders_ClaudePreservesContinuity(t *testing.T) {
req := cliproxyexecutor.Request{
Model: "claude-3-7-sonnet",
Payload: []byte(`{"metadata":{"user_id":"user-1"}}`),
}
body := []byte(`{"model":"gpt-5.4","stream":true}`)
updatedBody, headers, continuity := applyCodexPromptCacheHeaders(context.Background(), nil, sdktranslator.FromString("claude"), req, cliproxyexecutor.Options{}, body)
if continuity.Key == "" {
t.Fatal("continuity.Key = empty, want non-empty")
}
if got := gjson.GetBytes(updatedBody, "prompt_cache_key").String(); got != continuity.Key {
t.Fatalf("prompt_cache_key = %q, want %q", got, continuity.Key)
}
if got := headers.Get("session_id"); got != continuity.Key {
t.Fatalf("session_id = %q, want %q", got, continuity.Key)
}
}
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
@@ -41,9 +86,46 @@ func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T)
if got := headers.Get("User-Agent"); got != codexUserAgent {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
}
if got := headers.Get("Version"); got != "" {
t.Fatalf("Version = %q, want empty", got)
}
if got := headers.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)
}
if got := headers.Get("X-Codex-Turn-Metadata"); got != "" {
t.Fatalf("X-Codex-Turn-Metadata = %q, want empty", got)
}
if got := headers.Get("X-Client-Request-Id"); got != "" {
t.Fatalf("X-Client-Request-Id = %q, want empty", got)
}
}
func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing.T) {
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
ctx := contextWithGinHeaders(map[string]string{
"Originator": "Codex Desktop",
"Version": "0.115.0-alpha.27",
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
})
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", nil)
if got := headers.Get("Originator"); got != "Codex Desktop" {
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
}
if got := headers.Get("Version"); got != "0.115.0-alpha.27" {
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
}
if got := headers.Get("X-Codex-Turn-Metadata"); got != `{"turn_id":"turn-1"}` {
t.Fatalf("X-Codex-Turn-Metadata = %s, want %s", got, `{"turn_id":"turn-1"}`)
}
if got := headers.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
}
}
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
@@ -177,6 +259,57 @@ func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
}
}
func TestApplyCodexHeadersPassesThroughClientIdentityHeaders(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
if err != nil {
t.Fatalf("NewRequest() error = %v", err)
}
auth := &cliproxyauth.Auth{
Provider: "codex",
Metadata: map[string]any{"email": "user@example.com"},
}
req = req.WithContext(contextWithGinHeaders(map[string]string{
"Originator": "Codex Desktop",
"Version": "0.115.0-alpha.27",
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
}))
applyCodexHeaders(req, auth, "oauth-token", true, nil)
if got := req.Header.Get("Originator"); got != "Codex Desktop" {
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
}
if got := req.Header.Get("Version"); got != "0.115.0-alpha.27" {
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
}
if got := req.Header.Get("X-Codex-Turn-Metadata"); got != `{"turn_id":"turn-1"}` {
t.Fatalf("X-Codex-Turn-Metadata = %s, want %s", got, `{"turn_id":"turn-1"}`)
}
if got := req.Header.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
}
}
func TestApplyCodexHeadersDoesNotInjectClientOnlyHeadersByDefault(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
if err != nil {
t.Fatalf("NewRequest() error = %v", err)
}
applyCodexHeaders(req, nil, "oauth-token", true, nil)
if got := req.Header.Get("Version"); got != "" {
t.Fatalf("Version = %q, want empty", got)
}
if got := req.Header.Get("X-Codex-Turn-Metadata"); got != "" {
t.Fatalf("X-Codex-Turn-Metadata = %q, want empty", got)
}
if got := req.Header.Get("X-Client-Request-Id"); got != "" {
t.Fatalf("X-Client-Request-Id = %q, want empty", got)
}
}
func contextWithGinHeaders(headers map[string]string) context.Context {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()

File diff suppressed because it is too large Load Diff

View File

@@ -224,7 +224,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
reporter.publish(ctx, parseGeminiCLIUsage(data))
var param any
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -401,14 +401,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
}
}
}
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
@@ -430,12 +430,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
}
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
}
}(httpResp, append([]byte(nil), payload...), attemptModel)
@@ -544,7 +544,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil
}
lastStatus = resp.StatusCode
lastBody = append([]byte(nil), data...)
@@ -811,18 +811,18 @@ func fixGeminiCLIImageAspectRatio(modelName string, rawJSON []byte) []byte {
if !hasInlineData {
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
emptyImagePart := `{"inlineData":{"mime_type":"image/png","data":""}}`
emptyImagePart, _ = sjson.Set(emptyImagePart, "inlineData.data", emptyImageBase64ed)
newPartsJson := `[]`
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", emptyImagePart)
emptyImagePart := []byte(`{"inlineData":{"mime_type":"image/png","data":""}}`)
emptyImagePart, _ = sjson.SetBytes(emptyImagePart, "inlineData.data", emptyImageBase64ed)
newPartsJson := []byte(`[]`)
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(`{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`))
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", emptyImagePart)
parts := contentArray[0].Get("parts").Array()
for j := 0; j < len(parts); j++ {
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", parts[j].Raw)
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(parts[j].Raw))
}
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", []byte(newPartsJson))
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", newPartsJson)
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
}
}

View File

@@ -205,7 +205,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -321,12 +321,12 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
@@ -415,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
return cliproxyexecutor.Response{Payload: translated, Headers: resp.Header.Clone()}, nil
}
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
@@ -527,18 +527,18 @@ func fixGeminiImageAspectRatio(modelName string, rawJSON []byte) []byte {
if !hasInlineData {
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
emptyImagePart := `{"inlineData":{"mime_type":"image/png","data":""}}`
emptyImagePart, _ = sjson.Set(emptyImagePart, "inlineData.data", emptyImageBase64ed)
newPartsJson := `[]`
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", emptyImagePart)
emptyImagePart := []byte(`{"inlineData":{"mime_type":"image/png","data":""}}`)
emptyImagePart, _ = sjson.SetBytes(emptyImagePart, "inlineData.data", emptyImageBase64ed)
newPartsJson := []byte(`[]`)
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(`{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`))
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", emptyImagePart)
parts := contentArray[0].Get("parts").Array()
for j := 0; j < len(parts); j++ {
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", parts[j].Raw)
newPartsJson, _ = sjson.SetRawBytes(newPartsJson, "-1", []byte(parts[j].Raw))
}
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", []byte(newPartsJson))
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", newPartsJson)
rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
}
}

View File

@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
to := sdktranslator.FromString("gemini")
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -524,7 +524,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -636,12 +636,12 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
@@ -760,12 +760,12 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
@@ -857,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
}
// countTokensWithAPIKey handles token counting using API key credentials.
@@ -941,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
}
// vertexCreds extracts project, location and raw service account JSON from auth metadata.

View File

@@ -221,13 +221,13 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
}
var param any
converted := ""
var converted []byte
if useResponses && from.String() == "claude" {
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
} else {
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
}
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
resp = cliproxyexecutor.Response{Payload: converted}
reporter.ensurePublished(ctx)
return resp, nil
}
@@ -374,14 +374,14 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
}
}
var chunks []string
var chunks [][]byte
if useResponses && from.String() == "claude" {
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), &param)
} else {
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
}
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(chunks[i])}
}
}
@@ -577,9 +577,33 @@ func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model
return true
}
baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName)
if info := registry.GetGlobalRegistry().GetModelInfo(baseModel, githubCopilotAuthType); info != nil {
return len(info.SupportedEndpoints) > 0 && !containsEndpoint(info.SupportedEndpoints, githubCopilotChatPath) && containsEndpoint(info.SupportedEndpoints, githubCopilotResponsesPath)
}
if info := lookupGitHubCopilotStaticModelInfo(baseModel); info != nil {
return len(info.SupportedEndpoints) > 0 && !containsEndpoint(info.SupportedEndpoints, githubCopilotChatPath) && containsEndpoint(info.SupportedEndpoints, githubCopilotResponsesPath)
}
return strings.Contains(baseModel, "codex")
}
func lookupGitHubCopilotStaticModelInfo(model string) *registry.ModelInfo {
for _, info := range registry.GetStaticModelDefinitionsByChannel(githubCopilotAuthType) {
if info != nil && strings.EqualFold(info.ID, model) {
return info
}
}
return nil
}
func containsEndpoint(endpoints []string, endpoint string) bool {
for _, item := range endpoints {
if item == endpoint {
return true
}
}
return false
}
// flattenAssistantContent converts assistant message content from array format
// to a joined string. GitHub Copilot requires assistant content as a string;
// sending it as an array causes Claude models to re-answer all previous prompts.
@@ -977,7 +1001,7 @@ type githubCopilotResponsesStreamState struct {
ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
}
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) []byte {
root := gjson.ParseBytes(data)
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("id").String())
@@ -1067,10 +1091,10 @@ func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
} else {
out, _ = sjson.Set(out, "stop_reason", "end_turn")
}
return out
return []byte(out)
}
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string {
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) [][]byte {
if *param == nil {
*param = &githubCopilotResponsesStreamState{
TextBlockIndex: -1,
@@ -1092,7 +1116,10 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
}
event := gjson.GetBytes(payload, "type").String()
results := make([]string, 0, 4)
results := make([][]byte, 0, 4)
appendResult := func(chunk string) {
results = append(results, []byte(chunk))
}
ensureMessageStart := func() {
if state.MessageStarted {
return
@@ -1100,7 +1127,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String())
messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String())
results = append(results, "event: message_start\ndata: "+messageStart+"\n\n")
appendResult("event: message_start\ndata: " + messageStart + "\n\n")
state.MessageStarted = true
}
startTextBlockIfNeeded := func() {
@@ -1113,7 +1140,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
}
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex)
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
appendResult("event: content_block_start\ndata: " + contentBlockStart + "\n\n")
state.TextBlockStarted = true
}
stopTextBlockIfNeeded := func() {
@@ -1122,7 +1149,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
}
contentBlockStop := `{"type":"content_block_stop","index":0}`
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
appendResult("event: content_block_stop\ndata: " + contentBlockStop + "\n\n")
state.TextBlockStarted = false
state.TextBlockIndex = -1
}
@@ -1152,7 +1179,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex)
contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n")
appendResult("event: content_block_delta\ndata: " + contentDelta + "\n\n")
}
case "response.reasoning_summary_part.added":
ensureMessageStart()
@@ -1161,7 +1188,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
state.NextContentIndex++
thinkingStart := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
thinkingStart, _ = sjson.Set(thinkingStart, "index", state.ReasoningIndex)
results = append(results, "event: content_block_start\ndata: "+thinkingStart+"\n\n")
appendResult("event: content_block_start\ndata: " + thinkingStart + "\n\n")
case "response.reasoning_summary_text.delta":
if state.ReasoningActive {
delta := gjson.GetBytes(payload, "delta").String()
@@ -1169,14 +1196,14 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
thinkingDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
thinkingDelta, _ = sjson.Set(thinkingDelta, "index", state.ReasoningIndex)
thinkingDelta, _ = sjson.Set(thinkingDelta, "delta.thinking", delta)
results = append(results, "event: content_block_delta\ndata: "+thinkingDelta+"\n\n")
appendResult("event: content_block_delta\ndata: " + thinkingDelta + "\n\n")
}
}
case "response.reasoning_summary_part.done":
if state.ReasoningActive {
thinkingStop := `{"type":"content_block_stop","index":0}`
thinkingStop, _ = sjson.Set(thinkingStop, "index", state.ReasoningIndex)
results = append(results, "event: content_block_stop\ndata: "+thinkingStop+"\n\n")
appendResult("event: content_block_stop\ndata: " + thinkingStop + "\n\n")
state.ReasoningActive = false
}
case "response.output_item.added":
@@ -1204,7 +1231,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index)
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID)
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name)
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
appendResult("event: content_block_start\ndata: " + contentBlockStart + "\n\n")
case "response.output_item.delta":
item := gjson.GetBytes(payload, "item")
if item.Get("type").String() != "function_call" {
@@ -1224,7 +1251,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
appendResult("event: content_block_delta\ndata: " + inputDelta + "\n\n")
case "response.function_call_arguments.delta":
// Copilot sends tool call arguments via this event type (not response.output_item.delta).
// Data format: {"delta":"...", "item_id":"...", "output_index":N, ...}
@@ -1241,7 +1268,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
appendResult("event: content_block_delta\ndata: " + inputDelta + "\n\n")
case "response.output_item.done":
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
break
@@ -1252,7 +1279,7 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
}
contentBlockStop := `{"type":"content_block_stop","index":0}`
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
appendResult("event: content_block_stop\ndata: " + contentBlockStop + "\n\n")
case "response.completed":
ensureMessageStart()
stopTextBlockIfNeeded()
@@ -1276,8 +1303,8 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
if cachedTokens > 0 {
messageDelta, _ = sjson.Set(messageDelta, "usage.cache_read_input_tokens", cachedTokens)
}
results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n")
results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
appendResult("event: message_delta\ndata: " + messageDelta + "\n\n")
appendResult("event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
state.MessageStopSent = true
}
}

View File

@@ -5,6 +5,7 @@ import (
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
@@ -70,6 +71,29 @@ func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
}
}
func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing.T) {
t.Parallel()
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
t.Fatal("expected responses-only registry model to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
t.Parallel()
reg := registry.GetGlobalRegistry()
clientID := "github-copilot-test-client"
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{{
ID: "gpt-5.4",
SupportedEndpoints: []string{"/chat/completions", "/responses"},
}})
defer reg.UnregisterClient(clientID)
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
t.Parallel()
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {

View File

@@ -30,12 +30,20 @@ const (
gitLabChatEndpoint = "/api/v4/chat/completions"
gitLabCodeSuggestionsEndpoint = "/api/v4/code_suggestions/completions"
gitLabSSEStreamingHeader = "X-Supports-Sse-Streaming"
gitLabContext1MBeta = "context-1m-2025-08-07"
gitLabNativeUserAgent = "CLIProxyAPIPlus/GitLab-Duo"
)
type GitLabExecutor struct {
cfg *config.Config
}
type gitLabCatalogModel struct {
ID string
DisplayName string
Provider string
}
type gitLabPrompt struct {
Instruction string
FileName string
@@ -53,6 +61,23 @@ type gitLabOpenAIStreamState struct {
Finished bool
}
var gitLabAgenticCatalog = []gitLabCatalogModel{
{ID: "duo-chat-gpt-5-1", DisplayName: "GitLab Duo (GPT-5.1)", Provider: "openai"},
{ID: "duo-chat-opus-4-6", DisplayName: "GitLab Duo (Claude Opus 4.6)", Provider: "anthropic"},
{ID: "duo-chat-opus-4-5", DisplayName: "GitLab Duo (Claude Opus 4.5)", Provider: "anthropic"},
{ID: "duo-chat-sonnet-4-6", DisplayName: "GitLab Duo (Claude Sonnet 4.6)", Provider: "anthropic"},
{ID: "duo-chat-sonnet-4-5", DisplayName: "GitLab Duo (Claude Sonnet 4.5)", Provider: "anthropic"},
{ID: "duo-chat-gpt-5-mini", DisplayName: "GitLab Duo (GPT-5 Mini)", Provider: "openai"},
{ID: "duo-chat-gpt-5-2", DisplayName: "GitLab Duo (GPT-5.2)", Provider: "openai"},
{ID: "duo-chat-gpt-5-2-codex", DisplayName: "GitLab Duo (GPT-5.2 Codex)", Provider: "openai"},
{ID: "duo-chat-gpt-5-codex", DisplayName: "GitLab Duo (GPT-5 Codex)", Provider: "openai"},
{ID: "duo-chat-haiku-4-5", DisplayName: "GitLab Duo (Claude Haiku 4.5)", Provider: "anthropic"},
}
var gitLabModelAliases = map[string]string{
"duo-chat-haiku-4-6": "duo-chat-haiku-4-5",
}
func NewGitLabExecutor(cfg *config.Config) *GitLabExecutor {
return &GitLabExecutor{cfg: cfg}
}
@@ -249,12 +274,12 @@ func (e *GitLabExecutor) nativeGateway(
auth *cliproxyauth.Auth,
req cliproxyexecutor.Request,
) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth, cliproxyexecutor.Request, bool) {
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth); ok {
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth, req.Model); ok {
nativeReq := req
nativeReq.Model = gitLabResolvedModel(auth, req.Model)
return NewClaudeExecutor(e.cfg), nativeAuth, nativeReq, true
}
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth); ok {
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth, req.Model); ok {
nativeReq := req
nativeReq.Model = gitLabResolvedModel(auth, req.Model)
return NewCodexExecutor(e.cfg), nativeAuth, nativeReq, true
@@ -263,10 +288,10 @@ func (e *GitLabExecutor) nativeGateway(
}
func (e *GitLabExecutor) nativeGatewayHTTP(auth *cliproxyauth.Auth) (cliproxyauth.ProviderExecutor, *cliproxyauth.Auth) {
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth); ok {
if nativeAuth, ok := buildGitLabAnthropicGatewayAuth(auth, ""); ok {
return NewClaudeExecutor(e.cfg), nativeAuth
}
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth); ok {
if nativeAuth, ok := buildGitLabOpenAIGatewayAuth(auth, ""); ok {
return NewCodexExecutor(e.cfg), nativeAuth
}
return nil, nil
@@ -664,7 +689,7 @@ func applyGitLabRequestHeaders(req *http.Request, auth *cliproxyauth.Auth) {
if auth != nil {
util.ApplyCustomHeadersFromAttrs(req, auth.Attributes)
}
for key, value := range gitLabGatewayHeaders(auth) {
for key, value := range gitLabGatewayHeaders(auth, "") {
if key == "" || value == "" {
continue
}
@@ -672,34 +697,40 @@ func applyGitLabRequestHeaders(req *http.Request, auth *cliproxyauth.Auth) {
}
}
func gitLabGatewayHeaders(auth *cliproxyauth.Auth) map[string]string {
if auth == nil || auth.Metadata == nil {
return nil
}
raw, ok := auth.Metadata["duo_gateway_headers"]
if !ok {
return nil
}
func gitLabGatewayHeaders(auth *cliproxyauth.Auth, targetProvider string) map[string]string {
out := make(map[string]string)
switch typed := raw.(type) {
case map[string]string:
for key, value := range typed {
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key != "" && value != "" {
out[key] = value
if auth != nil && auth.Metadata != nil {
raw, ok := auth.Metadata["duo_gateway_headers"]
if ok {
switch typed := raw.(type) {
case map[string]string:
for key, value := range typed {
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key != "" && value != "" {
out[key] = value
}
}
case map[string]any:
for key, value := range typed {
key = strings.TrimSpace(key)
if key == "" {
continue
}
strValue := strings.TrimSpace(fmt.Sprint(value))
if strValue != "" {
out[key] = strValue
}
}
}
}
case map[string]any:
for key, value := range typed {
key = strings.TrimSpace(key)
if key == "" {
continue
}
strValue := strings.TrimSpace(fmt.Sprint(value))
if strValue != "" {
out[key] = strValue
}
}
if _, ok := out["User-Agent"]; !ok {
out["User-Agent"] = gitLabNativeUserAgent
}
if strings.EqualFold(strings.TrimSpace(targetProvider), "openai") {
if _, ok := out["anthropic-beta"]; !ok {
out["anthropic-beta"] = gitLabContext1MBeta
}
}
if len(out) == 0 {
@@ -989,8 +1020,8 @@ func gitLabUsage(model string, translatedReq []byte, text string) (int64, int64)
return promptTokens, int64(completionCount)
}
func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, bool) {
if !gitLabUsesAnthropicGateway(auth) {
func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth, requestedModel string) (*cliproxyauth.Auth, bool) {
if !gitLabUsesAnthropicGateway(auth, requestedModel) {
return nil, false
}
baseURL := gitLabAnthropicGatewayBaseURL(auth)
@@ -1006,7 +1037,8 @@ func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Aut
}
nativeAuth.Attributes["api_key"] = token
nativeAuth.Attributes["base_url"] = baseURL
for key, value := range gitLabGatewayHeaders(auth) {
nativeAuth.Attributes["gitlab_duo_force_context_1m"] = "true"
for key, value := range gitLabGatewayHeaders(auth, "anthropic") {
if key == "" || value == "" {
continue
}
@@ -1015,8 +1047,8 @@ func buildGitLabAnthropicGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Aut
return nativeAuth, true
}
func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, bool) {
if !gitLabUsesOpenAIGateway(auth) {
func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth, requestedModel string) (*cliproxyauth.Auth, bool) {
if !gitLabUsesOpenAIGateway(auth, requestedModel) {
return nil, false
}
baseURL := gitLabOpenAIGatewayBaseURL(auth)
@@ -1032,7 +1064,7 @@ func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth,
}
nativeAuth.Attributes["api_key"] = token
nativeAuth.Attributes["base_url"] = baseURL
for key, value := range gitLabGatewayHeaders(auth) {
for key, value := range gitLabGatewayHeaders(auth, "openai") {
if key == "" || value == "" {
continue
}
@@ -1041,34 +1073,41 @@ func buildGitLabOpenAIGatewayAuth(auth *cliproxyauth.Auth) (*cliproxyauth.Auth,
return nativeAuth, true
}
func gitLabUsesAnthropicGateway(auth *cliproxyauth.Auth) bool {
func gitLabUsesAnthropicGateway(auth *cliproxyauth.Auth, requestedModel string) bool {
if auth == nil || auth.Metadata == nil {
return false
}
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
if provider == "" {
modelName := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_name"))
provider = inferGitLabProviderFromModel(modelName)
}
provider := gitLabGatewayProvider(auth, requestedModel)
return provider == "anthropic" &&
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
}
func gitLabUsesOpenAIGateway(auth *cliproxyauth.Auth) bool {
func gitLabUsesOpenAIGateway(auth *cliproxyauth.Auth, requestedModel string) bool {
if auth == nil || auth.Metadata == nil {
return false
}
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
if provider == "" {
modelName := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_name"))
provider = inferGitLabProviderFromModel(modelName)
}
provider := gitLabGatewayProvider(auth, requestedModel)
return provider == "openai" &&
gitLabMetadataString(auth.Metadata, "duo_gateway_base_url") != "" &&
gitLabMetadataString(auth.Metadata, "duo_gateway_token") != ""
}
func gitLabGatewayProvider(auth *cliproxyauth.Auth, requestedModel string) string {
modelName := strings.TrimSpace(gitLabResolvedModel(auth, requestedModel))
if provider := inferGitLabProviderFromModel(modelName); provider != "" {
return provider
}
if auth == nil || auth.Metadata == nil {
return ""
}
provider := strings.ToLower(gitLabMetadataString(auth.Metadata, "model_provider"))
if provider == "" {
provider = inferGitLabProviderFromModel(gitLabMetadataString(auth.Metadata, "model_name"))
}
return provider
}
func inferGitLabProviderFromModel(model string) string {
model = strings.ToLower(strings.TrimSpace(model))
switch {
@@ -1151,6 +1190,9 @@ func gitLabBaseURL(auth *cliproxyauth.Auth) string {
func gitLabResolvedModel(auth *cliproxyauth.Auth, requested string) string {
requested = strings.TrimSpace(thinking.ParseSuffix(requested).ModelName)
if requested != "" && !strings.EqualFold(requested, "gitlab-duo") {
if mapped, ok := gitLabModelAliases[strings.ToLower(requested)]; ok && strings.TrimSpace(mapped) != "" {
return mapped
}
return requested
}
if auth != nil && auth.Metadata != nil {
@@ -1277,8 +1319,8 @@ func gitLabAuthKind(method string) string {
}
func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
models := make([]*registry.ModelInfo, 0, 4)
seen := make(map[string]struct{}, 4)
models := make([]*registry.ModelInfo, 0, len(gitLabAgenticCatalog)+4)
seen := make(map[string]struct{}, len(gitLabAgenticCatalog)+4)
addModel := func(id, displayName, provider string) {
id = strings.TrimSpace(id)
if id == "" {
@@ -1302,6 +1344,18 @@ func GitLabModelsFromAuth(auth *cliproxyauth.Auth) []*registry.ModelInfo {
}
addModel("gitlab-duo", "GitLab Duo", "gitlab")
for _, model := range gitLabAgenticCatalog {
addModel(model.ID, model.DisplayName, model.Provider)
}
for alias, upstream := range gitLabModelAliases {
target := strings.TrimSpace(upstream)
displayName := "GitLab Duo Alias"
provider := strings.TrimSpace(inferGitLabProviderFromModel(target))
if provider != "" {
displayName = fmt.Sprintf("GitLab Duo Alias (%s)", provider)
}
addModel(alias, displayName, provider)
}
if auth == nil {
return models
}

View File

@@ -217,6 +217,69 @@ func TestGitLabExecutorExecuteUsesOpenAIGateway(t *testing.T) {
}
}
func TestGitLabExecutorExecuteUsesRequestedModelToSelectOpenAIGateway(t *testing.T) {
var gotAuthHeader, gotRealmHeader, gotBetaHeader, gotUserAgent string
var gotPath string
var gotModel string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotAuthHeader = r.Header.Get("Authorization")
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
gotBetaHeader = r.Header.Get("anthropic-beta")
gotUserAgent = r.Header.Get("User-Agent")
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\"}}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from explicit openai model\"}\n\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from explicit openai model\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
}))
defer srv.Close()
exec := NewGitLabExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "gitlab",
Metadata: map[string]any{
"duo_gateway_base_url": srv.URL,
"duo_gateway_token": "gateway-token",
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
}
req := cliproxyexecutor.Request{
Model: "duo-chat-gpt-5-codex",
Payload: []byte(`{"model":"duo-chat-gpt-5-codex","messages":[{"role":"user","content":"hello"}]}`),
}
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if gotPath != "/v1/proxy/openai/v1/responses" {
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
}
if gotAuthHeader != "Bearer gateway-token" {
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
}
if gotRealmHeader != "saas" {
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
}
if gotBetaHeader != gitLabContext1MBeta {
t.Fatalf("anthropic-beta = %q, want %q", gotBetaHeader, gitLabContext1MBeta)
}
if gotUserAgent != gitLabNativeUserAgent {
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
}
if gotModel != "duo-chat-gpt-5-codex" {
t.Fatalf("model = %q, want duo-chat-gpt-5-codex", gotModel)
}
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from explicit openai model" {
t.Fatalf("expected explicit openai model response, got %q payload=%s", got, string(resp.Payload))
}
}
func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
@@ -251,13 +314,12 @@ func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
ID: "gitlab-auth.json",
Provider: "gitlab",
Metadata: map[string]any{
"base_url": srv.URL,
"access_token": "oauth-access",
"refresh_token": "oauth-refresh",
"oauth_client_id": "client-id",
"oauth_client_secret": "client-secret",
"auth_method": "oauth",
"oauth_expires_at": "2000-01-01T00:00:00Z",
"base_url": srv.URL,
"access_token": "oauth-access",
"refresh_token": "oauth-refresh",
"oauth_client_id": "client-id",
"auth_method": "oauth",
"oauth_expires_at": "2000-01-01T00:00:00Z",
},
}
@@ -397,9 +459,11 @@ func TestGitLabExecutorExecuteStreamFallsBackToSyntheticChat(t *testing.T) {
}
func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
var gotPath string
var gotPath, gotBetaHeader, gotUserAgent string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotBetaHeader = r.Header.Get("Anthropic-Beta")
gotUserAgent = r.Header.Get("User-Agent")
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("event: message_start\n"))
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
@@ -441,6 +505,12 @@ func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
if gotPath != "/v1/proxy/anthropic/v1/messages" {
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
}
if !strings.Contains(gotBetaHeader, gitLabContext1MBeta) {
t.Fatalf("Anthropic-Beta = %q, want to contain %q", gotBetaHeader, gitLabContext1MBeta)
}
if gotUserAgent != gitLabNativeUserAgent {
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
}
if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") {
t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n"))
}

View File

@@ -169,7 +169,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -281,7 +281,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
if errScan := scanner.Err(); errScan != nil {
@@ -315,7 +315,7 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
usageJSON := buildOpenAIUsageJSON(count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: translated}, nil
}
// Refresh refreshes OAuth tokens or cookie-based API keys and updates the stored API key.

View File

@@ -161,7 +161,7 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -271,12 +271,12 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range doneChunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)

View File

@@ -89,6 +89,13 @@ var endpointAliases = map[string]string{
"cli": "amazonq",
}
func enqueueTranslatedSSE(out chan<- cliproxyexecutor.StreamChunk, chunk []byte) {
if len(chunk) == 0 {
return
}
out <- cliproxyexecutor.StreamChunk{Payload: append(bytes.Clone(chunk), '\n', '\n')}
}
// retryConfig holds configuration for socket retry logic.
// Based on kiro2Api Python implementation patterns.
type retryConfig struct {
@@ -2573,9 +2580,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.ToolUseID, currentToolUse.Name)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
// Send tool input as delta
@@ -2583,18 +2588,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
// Close block
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
hasToolUses = true
@@ -2664,9 +2665,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
msgStart := kiroclaude.BuildClaudeMessageStartEvent(model, totalUsage.InputTokens)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
messageStartSent = true
}
@@ -2916,9 +2915,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
pingEvent := kiroclaude.BuildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
lastReportedOutputTokens = currentOutputTokens
@@ -2939,17 +2936,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processText, contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
continue
@@ -2978,18 +2971,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
// Send thinking delta
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
accumulatedThinkingContent.WriteString(thinkingText)
}
@@ -2998,9 +2987,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
isThinkingBlockOpen = false
}
@@ -3029,17 +3016,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(processContent, thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
accumulatedThinkingContent.WriteString(processContent)
}
@@ -3058,9 +3041,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
isThinkingBlockOpen = false
}
@@ -3071,18 +3052,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
// Send text delta
claudeEvent := kiroclaude.BuildClaudeStreamEvent(textBefore, contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
// Close text block before entering thinking
@@ -3090,9 +3067,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
isTextBlockOpen = false
}
@@ -3120,17 +3095,13 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processContent, contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
}
@@ -3158,9 +3129,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
isTextBlockOpen = false
}
@@ -3171,9 +3140,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
// Send input_json_delta with the tool input
@@ -3186,9 +3153,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
}
@@ -3197,9 +3162,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
@@ -3239,9 +3202,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
isTextBlockOpen = false
}
@@ -3254,9 +3215,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
@@ -3264,9 +3223,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
// Accumulate for token counting
@@ -3298,9 +3255,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
isTextBlockOpen = false
}
@@ -3310,9 +3265,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
if tu.Input != nil {
@@ -3323,9 +3276,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
}
@@ -3333,9 +3284,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
@@ -3522,9 +3471,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
}
@@ -3609,18 +3556,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
msgDelta := kiroclaude.BuildClaudeMessageDeltaEvent(stopReason, totalUsage)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
// Send message_stop event separately
msgStop := kiroclaude.BuildClaudeMessageStopOnlyEvent()
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
enqueueTranslatedSSE(out, chunk)
}
// reporter.publish is called via defer
}

View File

@@ -172,7 +172,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
// Translate response back to source format when needed
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -290,7 +290,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
// Pass through translator; it yields one or more chunks for the target schema.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
if errScan := scanner.Err(); errScan != nil {
@@ -330,7 +330,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
usageJSON := buildOpenAIUsageJSON(count)
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: []byte(translatedUsage)}, nil
return cliproxyexecutor.Response{Payload: translatedUsage}, nil
}
// Refresh is a no-op for API-key based compatibility providers.

View File

@@ -305,7 +305,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -421,12 +421,12 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range doneChunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
@@ -461,7 +461,7 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
usageJSON := buildOpenAIUsageJSON(count)
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: translated}, nil
}
func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {

View File

@@ -73,17 +73,7 @@ func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
return
}
r.once.Do(func() {
usage.PublishRecord(ctx, usage.Record{
Provider: r.provider,
Model: r.model,
Source: r.source,
APIKey: r.apiKey,
AuthID: r.authID,
AuthIndex: r.authIndex,
RequestedAt: r.requestedAt,
Failed: failed,
Detail: detail,
})
usage.PublishRecord(ctx, r.buildRecord(detail, failed))
})
}
@@ -96,20 +86,39 @@ func (r *usageReporter) ensurePublished(ctx context.Context) {
return
}
r.once.Do(func() {
usage.PublishRecord(ctx, usage.Record{
Provider: r.provider,
Model: r.model,
Source: r.source,
APIKey: r.apiKey,
AuthID: r.authID,
AuthIndex: r.authIndex,
RequestedAt: r.requestedAt,
Failed: false,
Detail: usage.Detail{},
})
usage.PublishRecord(ctx, r.buildRecord(usage.Detail{}, false))
})
}
func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record {
if r == nil {
return usage.Record{Detail: detail, Failed: failed}
}
return usage.Record{
Provider: r.provider,
Model: r.model,
Source: r.source,
APIKey: r.apiKey,
AuthID: r.authID,
AuthIndex: r.authIndex,
RequestedAt: r.requestedAt,
Latency: r.latency(),
Failed: failed,
Detail: detail,
}
}
func (r *usageReporter) latency() time.Duration {
if r == nil || r.requestedAt.IsZero() {
return 0
}
latency := time.Since(r.requestedAt)
if latency < 0 {
return 0
}
return latency
}
func apiKeyFromContext(ctx context.Context) string {
if ctx == nil {
return ""

View File

@@ -1,6 +1,11 @@
package executor
import "testing"
import (
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
func TestParseOpenAIUsageChatCompletions(t *testing.T) {
data := []byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`)
@@ -41,3 +46,19 @@ func TestParseOpenAIUsageResponses(t *testing.T) {
t.Fatalf("reasoning tokens = %d, want %d", detail.ReasoningTokens, 9)
}
}
func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
reporter := &usageReporter{
provider: "openai",
model: "gpt-5.4",
requestedAt: time.Now().Add(-1500 * time.Millisecond),
}
record := reporter.buildRecord(usage.Detail{TotalTokens: 3}, false)
if record.Latency < time.Second {
t.Fatalf("latency = %v, want >= 1s", record.Latency)
}
if record.Latency > 3*time.Second {
t.Fatalf("latency = %v, want <= 3s", record.Latency)
}
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -39,35 +40,39 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
rawJSON := inputRawJSON
// system instruction
systemInstructionJSON := ""
var systemInstructionJSON []byte
hasSystemInstruction := false
systemResult := gjson.GetBytes(rawJSON, "system")
if systemResult.IsArray() {
systemResults := systemResult.Array()
systemInstructionJSON = `{"role":"user","parts":[]}`
systemInstructionJSON = []byte(`{"role":"user","parts":[]}`)
for i := 0; i < len(systemResults); i++ {
systemPromptResult := systemResults[i]
systemTypePromptResult := systemPromptResult.Get("type")
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
systemPrompt := systemPromptResult.Get("text").String()
partJSON := `{}`
partJSON := []byte(`{}`)
if systemPrompt != "" {
partJSON, _ = sjson.Set(partJSON, "text", systemPrompt)
partJSON, _ = sjson.SetBytes(partJSON, "text", systemPrompt)
}
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", partJSON)
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", partJSON)
hasSystemInstruction = true
}
}
} else if systemResult.Type == gjson.String {
systemInstructionJSON = `{"role":"user","parts":[{"text":""}]}`
systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.0.text", systemResult.String())
systemInstructionJSON = []byte(`{"role":"user","parts":[{"text":""}]}`)
systemInstructionJSON, _ = sjson.SetBytes(systemInstructionJSON, "parts.0.text", systemResult.String())
hasSystemInstruction = true
}
// contents
contentsJSON := "[]"
contentsJSON := []byte(`[]`)
hasContents := false
// tool_use_id → tool_name lookup, populated incrementally during the main loop.
// Claude's tool_result references tool_use by ID; Gemini requires functionResponse.name.
toolNameByID := make(map[string]string)
messagesResult := gjson.GetBytes(rawJSON, "messages")
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
@@ -83,8 +88,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if role == "assistant" {
role = "model"
}
clientContentJSON := `{"role":"","parts":[]}`
clientContentJSON, _ = sjson.Set(clientContentJSON, "role", role)
clientContentJSON := []byte(`{"role":"","parts":[]}`)
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "role", role)
contentsResult := messageResult.Get("content")
if contentsResult.IsArray() {
contentResults := contentsResult.Array()
@@ -143,15 +148,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
// Valid signature, send as thought block
partJSON := `{}`
partJSON, _ = sjson.Set(partJSON, "thought", true)
if thinkingText != "" {
partJSON, _ = sjson.Set(partJSON, "text", thinkingText)
}
// Always include "text" field — Google Antigravity API requires it
// even for redacted thinking where the text is empty.
partJSON := []byte(`{}`)
partJSON, _ = sjson.SetBytes(partJSON, "thought", true)
partJSON, _ = sjson.SetBytes(partJSON, "text", thinkingText)
if signature != "" {
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature)
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", signature)
}
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String()
// Skip empty text parts to avoid Gemini API error:
@@ -159,17 +164,21 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if prompt == "" {
continue
}
partJSON := `{}`
partJSON, _ = sjson.Set(partJSON, "text", prompt)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
partJSON := []byte(`{}`)
partJSON, _ = sjson.SetBytes(partJSON, "text", prompt)
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
// NOTE: Do NOT inject dummy thinking blocks here.
// Antigravity API validates signatures, so dummy values are rejected.
functionName := contentResult.Get("name").String()
functionName := util.SanitizeFunctionName(contentResult.Get("name").String())
argsResult := contentResult.Get("input")
functionID := contentResult.Get("id").String()
if functionID != "" && functionName != "" {
toolNameByID[functionID] = functionName
}
// Handle both object and string input formats
var argsRaw string
if argsResult.IsObject() {
@@ -183,138 +192,147 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
if argsRaw != "" {
partJSON := `{}`
partJSON := []byte(`{}`)
// Use skip_thought_signature_validator for tool calls without valid thinking signature
// This is the approach used in opencode-google-antigravity-auth for Gemini
// and also works for Claude through Antigravity API
const skipSentinel = "skip_thought_signature_validator"
if cache.HasValidSignature(modelName, currentMessageThinkingSignature) {
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", currentMessageThinkingSignature)
} else {
// No valid signature - use skip sentinel to bypass validation
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", skipSentinel)
partJSON, _ = sjson.SetBytes(partJSON, "thoughtSignature", skipSentinel)
}
if functionID != "" {
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID)
partJSON, _ = sjson.SetBytes(partJSON, "functionCall.id", functionID)
}
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsRaw)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
partJSON, _ = sjson.SetBytes(partJSON, "functionCall.name", functionName)
partJSON, _ = sjson.SetRawBytes(partJSON, "functionCall.args", []byte(argsRaw))
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" {
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-2], "-")
funcName, ok := toolNameByID[toolCallID]
if !ok {
// Fallback: derive a semantic name from the ID by stripping
// the last two dash-separated segments (e.g. "get_weather-call-123" → "get_weather").
// Only use the raw ID as a last resort when the heuristic produces an empty string.
parts := strings.Split(toolCallID, "-")
if len(parts) > 2 {
funcName = strings.Join(parts[:len(parts)-2], "-")
}
if funcName == "" {
funcName = toolCallID
}
log.Warnf("antigravity claude request: tool_result references unknown tool_use_id=%s, derived function name=%s", toolCallID, funcName)
}
functionResponseResult := contentResult.Get("content")
functionResponseJSON := `{}`
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "id", toolCallID)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "name", funcName)
functionResponseJSON := []byte(`{}`)
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "id", toolCallID)
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "name", util.SanitizeFunctionName(funcName))
responseData := ""
if functionResponseResult.Type == gjson.String {
responseData = functionResponseResult.String()
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", responseData)
} else if functionResponseResult.IsArray() {
frResults := functionResponseResult.Array()
nonImageCount := 0
lastNonImageRaw := ""
filteredJSON := "[]"
imagePartsJSON := "[]"
filteredJSON := []byte(`[]`)
imagePartsJSON := []byte(`[]`)
for _, fr := range frResults {
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
inlineDataJSON := `{}`
inlineDataJSON := []byte(`{}`)
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
}
if data := fr.Get("source.data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
}
imagePartJSON := `{}`
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
imagePartJSON := []byte(`{}`)
imagePartJSON, _ = sjson.SetRawBytes(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON, _ = sjson.SetRawBytes(imagePartsJSON, "-1", imagePartJSON)
continue
}
nonImageCount++
lastNonImageRaw = fr.Raw
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
filteredJSON, _ = sjson.SetRawBytes(filteredJSON, "-1", []byte(fr.Raw))
}
if nonImageCount == 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(lastNonImageRaw))
} else if nonImageCount > 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", filteredJSON)
} else {
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
}
// Place image data inside functionResponse.parts as inlineData
// instead of as sibling parts in the outer content, to avoid
// base64 data bloating the text context.
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
if gjson.GetBytes(imagePartsJSON, "#").Int() > 0 {
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "parts", imagePartsJSON)
}
} else if functionResponseResult.IsObject() {
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
inlineDataJSON := `{}`
inlineDataJSON := []byte(`{}`)
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
}
if data := functionResponseResult.Get("source.data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
}
imagePartJSON := `{}`
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON := "[]"
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
imagePartJSON := []byte(`{}`)
imagePartJSON, _ = sjson.SetRawBytes(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON := []byte(`[]`)
imagePartsJSON, _ = sjson.SetRawBytes(imagePartsJSON, "-1", imagePartJSON)
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "parts", imagePartsJSON)
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
} else {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(functionResponseResult.Raw))
}
} else if functionResponseResult.Raw != "" {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
functionResponseJSON, _ = sjson.SetRawBytes(functionResponseJSON, "response.result", []byte(functionResponseResult.Raw))
} else {
// Content field is missing entirely — .Raw is empty which
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "response.result", "")
}
partJSON := `{}`
partJSON, _ = sjson.SetRaw(partJSON, "functionResponse", functionResponseJSON)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
partJSON := []byte(`{}`)
partJSON, _ = sjson.SetRawBytes(partJSON, "functionResponse", functionResponseJSON)
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" {
sourceResult := contentResult.Get("source")
if sourceResult.Get("type").String() == "base64" {
inlineDataJSON := `{}`
inlineDataJSON := []byte(`{}`)
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "mimeType", mimeType)
}
if data := sourceResult.Get("data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
inlineDataJSON, _ = sjson.SetBytes(inlineDataJSON, "data", data)
}
partJSON := `{}`
partJSON, _ = sjson.SetRaw(partJSON, "inlineData", inlineDataJSON)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
partJSON := []byte(`{}`)
partJSON, _ = sjson.SetRawBytes(partJSON, "inlineData", inlineDataJSON)
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
}
}
}
// Reorder parts for 'model' role to ensure thinking block is first
if role == "model" {
partsResult := gjson.Get(clientContentJSON, "parts")
partsResult := gjson.GetBytes(clientContentJSON, "parts")
if partsResult.IsArray() {
parts := partsResult.Array()
var thinkingParts []gjson.Result
@@ -336,7 +354,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
for _, p := range otherParts {
newParts = append(newParts, p.Value())
}
clientContentJSON, _ = sjson.Set(clientContentJSON, "parts", newParts)
clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
}
}
}
@@ -344,33 +362,33 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Skip messages with empty parts array to avoid Gemini API error:
// "required oneof field 'data' must have one initialized field"
partsCheck := gjson.Get(clientContentJSON, "parts")
partsCheck := gjson.GetBytes(clientContentJSON, "parts")
if !partsCheck.IsArray() || len(partsCheck.Array()) == 0 {
continue
}
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
contentsJSON, _ = sjson.SetRawBytes(contentsJSON, "-1", clientContentJSON)
hasContents = true
} else if contentsResult.Type == gjson.String {
prompt := contentsResult.String()
partJSON := `{}`
partJSON := []byte(`{}`)
if prompt != "" {
partJSON, _ = sjson.Set(partJSON, "text", prompt)
partJSON, _ = sjson.SetBytes(partJSON, "text", prompt)
}
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
clientContentJSON, _ = sjson.SetRawBytes(clientContentJSON, "parts.-1", partJSON)
contentsJSON, _ = sjson.SetRawBytes(contentsJSON, "-1", clientContentJSON)
hasContents = true
}
}
}
// tools
toolsJSON := ""
var toolsJSON []byte
toolDeclCount := 0
allowedToolKeys := []string{"name", "description", "behavior", "parameters", "parametersJsonSchema", "response", "responseJsonSchema"}
toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() {
toolsJSON = `[{"functionDeclarations":[]}]`
toolsJSON = []byte(`[{"functionDeclarations":[]}]`)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
@@ -378,23 +396,24 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
// Sanitize the input schema for Antigravity API compatibility
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
for toolKey := range gjson.Parse(tool).Map() {
tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
for toolKey := range gjson.ParseBytes(tool).Map() {
if util.InArray(allowedToolKeys, toolKey) {
continue
}
tool, _ = sjson.Delete(tool, toolKey)
tool, _ = sjson.DeleteBytes(tool, toolKey)
}
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "0.functionDeclarations.-1", tool)
toolDeclCount++
}
}
}
// Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}`
out, _ = sjson.Set(out, "model", modelName)
out := []byte(`{"model":"","request":{"contents":[]}}`)
out, _ = sjson.SetBytes(out, "model", modelName)
// Inject interleaved thinking hint when both tools and thinking are active
hasTools := toolDeclCount > 0
@@ -408,27 +427,27 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if hasSystemInstruction {
// Append hint as a new part to existing system instruction
hintPart := `{"text":""}`
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
hintPart := []byte(`{"text":""}`)
hintPart, _ = sjson.SetBytes(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", hintPart)
} else {
// Create new system instruction with hint
systemInstructionJSON = `{"role":"user","parts":[]}`
hintPart := `{"text":""}`
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
systemInstructionJSON = []byte(`{"role":"user","parts":[]}`)
hintPart := []byte(`{"text":""}`)
hintPart, _ = sjson.SetBytes(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRawBytes(systemInstructionJSON, "parts.-1", hintPart)
hasSystemInstruction = true
}
}
if hasSystemInstruction {
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON)
out, _ = sjson.SetRawBytes(out, "request.systemInstruction", systemInstructionJSON)
}
if hasContents {
out, _ = sjson.SetRaw(out, "request.contents", contentsJSON)
out, _ = sjson.SetRawBytes(out, "request.contents", contentsJSON)
}
if toolDeclCount > 0 {
out, _ = sjson.SetRaw(out, "request.tools", toolsJSON)
out, _ = sjson.SetRawBytes(out, "request.tools", toolsJSON)
}
// tool_choice
@@ -445,15 +464,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
switch toolChoiceType {
case "auto":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
case "none":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
case "any":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
case "tool":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
if toolChoiceName != "" {
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)})
}
}
}
@@ -464,8 +483,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
case "enabled":
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := int(b.Int())
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
}
case "adaptive", "auto":
// For adaptive thinking:
@@ -477,28 +496,27 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
effort = strings.ToLower(strings.TrimSpace(v.String()))
}
if effort != "" {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
} else {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
}
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
}
}
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.temperature", v.Num)
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", v.Num)
}
if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topP", v.Num)
out, _ = sjson.SetBytes(out, "request.generationConfig.topP", v.Num)
}
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num)
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", v.Num)
}
if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.maxOutputTokens", v.Num)
out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", v.Num)
}
outBytes := []byte(out)
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")
out = common.AttachDefaultSafetySettings(out, "request.safetySettings")
return outBytes
return out
}

View File

@@ -365,6 +365,17 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "get_weather-call-123",
"name": "get_weather",
"input": {"location": "Paris"}
}
]
},
{
"role": "user",
"content": [
@@ -382,13 +393,177 @@ func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
outputStr := string(output)
// Check function response conversion
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
if !funcResp.Exists() {
t.Error("functionResponse should exist")
}
if funcResp.Get("id").String() != "get_weather-call-123" {
t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String())
}
if funcResp.Get("name").String() != "get_weather" {
t.Errorf("Expected function name 'get_weather', got '%s'", funcResp.Get("name").String())
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_TouluFormat(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-haiku-4-5-20251001",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
"name": "Glob",
"input": {"pattern": "**/*.py"}
},
{
"type": "tool_use",
"id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
"name": "Bash",
"input": {"command": "ls"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
"content": "file1.py\nfile2.py"
},
{
"type": "tool_result",
"tool_use_id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
"content": "total 10"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
outputStr := string(output)
funcResp0 := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
if !funcResp0.Exists() {
t.Fatal("first functionResponse should exist")
}
if got := funcResp0.Get("name").String(); got != "Glob" {
t.Errorf("Expected name 'Glob' for toolu_ format, got '%s'", got)
}
funcResp1 := gjson.Get(outputStr, "request.contents.1.parts.1.functionResponse")
if !funcResp1.Exists() {
t.Fatal("second functionResponse should exist")
}
if got := funcResp1.Get("name").String(); got != "Bash" {
t.Errorf("Expected name 'Bash' for toolu_ format, got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_CustomFormat(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-haiku-4-5-20251001",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "Read-1773420180464065165-1327",
"name": "Read",
"input": {"file_path": "/tmp/test.py"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Read-1773420180464065165-1327",
"content": "file content here"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
outputStr := string(output)
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
if got := funcResp.Get("name").String(); got != "Read" {
t.Errorf("Expected name 'Read', got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_Heuristic(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "get_weather-call-123",
"content": "22C sunny"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
if got := funcResp.Get("name").String(); got != "get_weather" {
t.Errorf("Expected heuristic-derived name 'get_weather', got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_RawID(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
"content": "result data"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
got := funcResp.Get("name").String()
if got == "" {
t.Error("functionResponse.name must not be empty")
}
if got != "toolu_tool-48fca351f12844eabf49dad8b63886d2" {
t.Errorf("Expected raw ID as last-resort name, got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {

View File

@@ -15,6 +15,7 @@ import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
@@ -43,6 +44,10 @@ type Params struct {
// Signature caching support
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
// Reverse map: sanitized Gemini function name → original Claude tool name.
// Populated lazily on the first response chunk from the original request JSON.
ToolNameMap map[string]string
}
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
@@ -63,13 +68,14 @@ var toolUseIDCounter uint64
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of bytes, each containing a Claude Code-compatible SSE payload.
func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &Params{
HasFirstResponse: false,
ResponseType: 0,
ResponseIndex: 0,
ToolNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
}
}
modelName := gjson.GetBytes(requestRawJSON, "model").String()
@@ -77,44 +83,44 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
params := (*param).(*Params)
if bytes.Equal(rawJSON, []byte("[DONE]")) {
output := ""
output := make([]byte, 0, 256)
// Only send final events if we have actually output content
if params.HasContent {
appendFinalEvents(params, &output, true)
return []string{
output + "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n\n",
}
output = translatorcommon.AppendSSEEventString(output, "message_stop", `{"type":"message_stop"}`, 3)
return [][]byte{output}
}
return []string{}
return [][]byte{}
}
output := ""
output := make([]byte, 0, 1024)
appendEvent := func(event, payload string) {
output = translatorcommon.AppendSSEEventString(output, event, payload, 3)
}
// Initialize the streaming session with a message_start event
// This is only sent for the very first response chunk to establish the streaming session
if !params.HasFirstResponse {
output = "event: message_start\n"
// Create the initial message structure with default values according to Claude Code API specification
// This follows the Claude Code API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
messageStartTemplate := []byte(`{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`)
// Use cpaUsageMetadata within the message_start event for Claude.
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
}
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
}
// Override default values with actual response metadata if available from the Gemini CLI response
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.model", modelVersionResult.String())
}
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
messageStartTemplate, _ = sjson.SetBytes(messageStartTemplate, "message.id", responseIDResult.String())
}
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
appendEvent("message_start", string(messageStartTemplate))
params.HasFirstResponse = true
}
@@ -144,15 +150,13 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
params.CurrentThinkingText.Reset()
}
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex)), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
appendEvent("content_block_delta", string(data))
params.HasContent = true
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
params.CurrentThinkingText.WriteString(partTextResult.String())
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex)), "delta.thinking", partTextResult.String())
appendEvent("content_block_delta", string(data))
params.HasContent = true
} else {
// Transition from another state to thinking
@@ -163,19 +167,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
params.ResponseIndex++
}
// Start a new thinking content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex))
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex)), "delta.thinking", partTextResult.String())
appendEvent("content_block_delta", string(data))
params.ResponseType = 2 // Set state to thinking
params.HasContent = true
// Start accumulating thinking text for signature caching
@@ -188,9 +187,8 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Process regular text content (user-visible output)
// Continue existing text block if already in content state
if params.ResponseType == 1 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex)), "delta.text", partTextResult.String())
appendEvent("content_block_delta", string(data))
params.HasContent = true
} else {
// Transition from another state to text content
@@ -201,19 +199,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
params.ResponseIndex++
}
if partTextResult.String() != "" {
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
appendEvent("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex))
data, _ := sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex)), "delta.text", partTextResult.String())
appendEvent("content_block_delta", string(data))
params.ResponseType = 1 // Set state to content
params.HasContent = true
}
@@ -224,14 +217,12 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Handle function/tool calls from the AI model
// This processes tool usage requests and formats them for Claude Code API compatibility
params.HasToolUse = true
fcName := functionCallResult.Get("name").String()
fcName := util.RestoreSanitizedToolName(params.ToolNameMap, functionCallResult.Get("name").String())
// Handle state transitions when switching to function calls
// Close any existing function call block first
if params.ResponseType == 3 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
params.ResponseIndex++
params.ResponseType = 0
}
@@ -245,26 +236,21 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Close any other existing content block
if params.ResponseType != 0 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
appendEvent("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex))
params.ResponseIndex++
}
// Start a new tool use content block
// This creates the structure for a function call in Claude Code format
output = output + "event: content_block_start\n"
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data := []byte(fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex))
data, _ = sjson.SetBytes(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.SetBytes(data, "content_block.name", fcName)
appendEvent("content_block_start", string(data))
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n"
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
data, _ = sjson.SetBytes([]byte(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex)), "delta.partial_json", fcArgsResult.Raw)
appendEvent("content_block_delta", string(data))
}
params.ResponseType = 3
params.HasContent = true
@@ -296,10 +282,10 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
appendFinalEvents(params, &output, false)
}
return []string{output}
return [][]byte{output}
}
func appendFinalEvents(params *Params, output *string, force bool) {
func appendFinalEvents(params *Params, output *[]byte, force bool) {
if params.HasSentFinalEvents {
return
}
@@ -314,9 +300,7 @@ func appendFinalEvents(params *Params, output *string, force bool) {
}
if params.ResponseType != 0 {
*output = *output + "event: content_block_stop\n"
*output = *output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
*output = *output + "\n\n\n"
*output = translatorcommon.AppendSSEEventString(*output, "content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, params.ResponseIndex), 3)
params.ResponseType = 0
}
@@ -329,18 +313,16 @@ func appendFinalEvents(params *Params, output *string, force bool) {
}
}
*output = *output + "event: message_delta\n"
*output = *output + "data: "
delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
delta := []byte(fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens))
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
if params.CachedTokenCount > 0 {
var err error
delta, err = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
delta, err = sjson.SetBytes(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
if err != nil {
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
}
}
*output = *output + delta + "\n\n\n"
*output = translatorcommon.AppendSSEEventString(*output, "message_delta", string(delta), 3)
params.HasSentFinalEvents = true
}
@@ -369,9 +351,9 @@ func resolveStopReason(params *Params) string {
// - param: A pointer to a parameter object for the conversion.
//
// Returns:
// - string: A Claude-compatible JSON response.
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
_ = originalRequestRawJSON
// - []byte: A Claude-compatible JSON response.
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
toolNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
modelName := gjson.GetBytes(requestRawJSON, "model").String()
root := gjson.ParseBytes(rawJSON)
@@ -388,15 +370,15 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
}
}
responseJSON := `{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
responseJSON, _ = sjson.Set(responseJSON, "id", root.Get("response.responseId").String())
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
responseJSON := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
responseJSON, _ = sjson.SetBytes(responseJSON, "id", root.Get("response.responseId").String())
responseJSON, _ = sjson.SetBytes(responseJSON, "model", root.Get("response.modelVersion").String())
responseJSON, _ = sjson.SetBytes(responseJSON, "usage.input_tokens", promptTokens)
responseJSON, _ = sjson.SetBytes(responseJSON, "usage.output_tokens", outputTokens)
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
if cachedTokens > 0 {
var err error
responseJSON, err = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
responseJSON, err = sjson.SetBytes(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
if err != nil {
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
}
@@ -407,7 +389,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
if contentArrayInitialized {
return
}
responseJSON, _ = sjson.SetRaw(responseJSON, "content", "[]")
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content", []byte("[]"))
contentArrayInitialized = true
}
@@ -423,9 +405,9 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
return
}
ensureContentArray()
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", textBuilder.String())
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
block := []byte(`{"type":"text","text":""}`)
block, _ = sjson.SetBytes(block, "text", textBuilder.String())
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", block)
textBuilder.Reset()
}
@@ -434,12 +416,12 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
return
}
ensureContentArray()
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
block := []byte(`{"type":"thinking","thinking":""}`)
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
if thinkingSignature != "" {
block, _ = sjson.Set(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
block, _ = sjson.SetBytes(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
}
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", block)
thinkingBuilder.Reset()
thinkingSignature = ""
}
@@ -473,18 +455,18 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
flushText()
hasToolCall = true
name := functionCall.Get("name").String()
name := util.RestoreSanitizedToolName(toolNameMap, functionCall.Get("name").String())
toolIDCounter++
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "name", name)
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) && args.IsObject() {
toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw)
toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(args.Raw))
}
ensureContentArray()
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", toolBlock)
responseJSON, _ = sjson.SetRawBytes(responseJSON, "content.-1", toolBlock)
continue
}
}
@@ -508,17 +490,17 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
}
}
}
responseJSON, _ = sjson.Set(responseJSON, "stop_reason", stopReason)
responseJSON, _ = sjson.SetBytes(responseJSON, "stop_reason", stopReason)
if promptTokens == 0 && outputTokens == 0 {
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
responseJSON, _ = sjson.Delete(responseJSON, "usage")
responseJSON, _ = sjson.DeleteBytes(responseJSON, "usage")
}
}
return responseJSON
}
func ClaudeTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"input_tokens":%d}`, count)
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.ClaudeInputTokensJSON(count)
}

View File

@@ -34,10 +34,10 @@ import (
// - []byte: The transformed request data in Gemini API format
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON
template := ""
template = `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
template, _ = sjson.Set(template, "model", modelName)
template := `{"project":"","request":{},"model":""}`
templateBytes, _ := sjson.SetRawBytes([]byte(template), "request", rawJSON)
templateBytes, _ = sjson.SetBytes(templateBytes, "model", modelName)
template = string(templateBytes)
template, _ = sjson.Delete(template, "request.model")
template, errFixCLIToolResponse := fixCLIToolResponse(template)
@@ -47,7 +47,8 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
systemInstructionResult := gjson.Get(template, "request.system_instruction")
if systemInstructionResult.Exists() {
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
templateBytes, _ = sjson.SetRawBytes([]byte(template), "request.systemInstruction", []byte(systemInstructionResult.Raw))
template = string(templateBytes)
template, _ = sjson.Delete(template, "request.system_instruction")
}
rawJSON = []byte(template)
@@ -149,7 +150,8 @@ func parseFunctionResponseRaw(response gjson.Result, fallbackName string) string
raw := response.Raw
name := response.Get("functionResponse.name").String()
if strings.TrimSpace(name) == "" && fallbackName != "" {
raw, _ = sjson.Set(raw, "functionResponse.name", fallbackName)
updated, _ := sjson.SetBytes([]byte(raw), "functionResponse.name", fallbackName)
raw = string(updated)
}
return raw
}
@@ -157,27 +159,27 @@ func parseFunctionResponseRaw(response gjson.Result, fallbackName string) string
log.Debugf("parse function response failed, using fallback")
funcResp := response.Get("functionResponse")
if funcResp.Exists() {
fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
fr := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
name := funcResp.Get("name").String()
if strings.TrimSpace(name) == "" {
name = fallbackName
}
fr, _ = sjson.Set(fr, "functionResponse.name", name)
fr, _ = sjson.Set(fr, "functionResponse.response.result", funcResp.Get("response").String())
fr, _ = sjson.SetBytes(fr, "functionResponse.name", name)
fr, _ = sjson.SetBytes(fr, "functionResponse.response.result", funcResp.Get("response").String())
if id := funcResp.Get("id").String(); id != "" {
fr, _ = sjson.Set(fr, "functionResponse.id", id)
fr, _ = sjson.SetBytes(fr, "functionResponse.id", id)
}
return fr
return string(fr)
}
useName := fallbackName
if useName == "" {
useName = "unknown"
}
fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
fr, _ = sjson.Set(fr, "functionResponse.name", useName)
fr, _ = sjson.Set(fr, "functionResponse.response.result", response.String())
return fr
fr := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
fr, _ = sjson.SetBytes(fr, "functionResponse.name", useName)
fr, _ = sjson.SetBytes(fr, "functionResponse.response.result", response.String())
return string(fr)
}
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
@@ -204,7 +206,7 @@ func fixCLIToolResponse(input string) (string, error) {
}
// Initialize data structures for processing and grouping
contentsWrapper := `{"contents":[]}`
contentsWrapper := []byte(`{"contents":[]}`)
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched
@@ -237,16 +239,16 @@ func fixCLIToolResponse(input string) (string, error) {
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content
functionResponseContent := `{"parts":[],"role":"function"}`
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
for ri, response := range groupResponses {
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
if partRaw != "" {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(partRaw))
}
}
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
}
}
@@ -269,7 +271,7 @@ func fixCLIToolResponse(input string) (string, error) {
log.Warnf("failed to parse model content")
return true
}
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
// Create a new group for tracking responses
group := &FunctionCallGroup{
@@ -283,7 +285,7 @@ func fixCLIToolResponse(input string) (string, error) {
log.Warnf("failed to parse content")
return true
}
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
}
} else {
// Non-model content (user, etc.)
@@ -291,7 +293,7 @@ func fixCLIToolResponse(input string) (string, error) {
log.Warnf("failed to parse content")
return true
}
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
}
return true
@@ -303,23 +305,22 @@ func fixCLIToolResponse(input string) (string, error) {
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
functionResponseContent := `{"parts":[],"role":"function"}`
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
for ri, response := range groupResponses {
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
if partRaw != "" {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(partRaw))
}
}
if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
}
}
}
// Update the original JSON with the new contents
result := input
result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
result, _ := sjson.SetRawBytes([]byte(input), "request.contents", []byte(gjson.GetBytes(contentsWrapper, "contents").Raw))
return result, nil
return string(result), nil
}

View File

@@ -8,8 +8,8 @@ package gemini
import (
"bytes"
"context"
"fmt"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -29,8 +29,8 @@ import (
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - []string: The transformed request data in Gemini API format
func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string {
// - [][]byte: The transformed response data in Gemini API format.
func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) [][]byte {
if bytes.HasPrefix(rawJSON, []byte("data:")) {
rawJSON = bytes.TrimSpace(rawJSON[5:])
}
@@ -44,22 +44,22 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR
chunk = restoreUsageMetadata(chunk)
}
} else {
chunkTemplate := "[]"
chunkTemplate := []byte("[]")
responseResult := gjson.ParseBytes(chunk)
if responseResult.IsArray() {
responseResultItems := responseResult.Array()
for i := 0; i < len(responseResultItems); i++ {
responseResultItem := responseResultItems[i]
if responseResultItem.Get("response").Exists() {
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
chunkTemplate, _ = sjson.SetRawBytes(chunkTemplate, "-1", []byte(responseResultItem.Get("response").Raw))
}
}
}
chunk = []byte(chunkTemplate)
chunk = chunkTemplate
}
return []string{string(chunk)}
return [][]byte{chunk}
}
return []string{}
return [][]byte{}
}
// ConvertAntigravityResponseToGeminiNonStream converts a non-streaming Gemini CLI request to a non-streaming Gemini response.
@@ -73,18 +73,18 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: A Gemini-compatible JSON response containing the response data
func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// - []byte: A Gemini-compatible JSON response containing the response data.
func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
chunk := restoreUsageMetadata([]byte(responseResult.Raw))
return string(chunk)
return chunk
}
return string(rawJSON)
return rawJSON
}
func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
func GeminiTokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.GeminiTokenCountJSON(count)
}
// restoreUsageMetadata renames cpaUsageMetadata back to usageMetadata.

View File

@@ -59,8 +59,8 @@ func TestConvertAntigravityResponseToGeminiNonStream(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertAntigravityResponseToGeminiNonStream(context.Background(), "", nil, nil, tt.input, nil)
if result != tt.expected {
t.Errorf("ConvertAntigravityResponseToGeminiNonStream() = %s, want %s", result, tt.expected)
if string(result) != tt.expected {
t.Errorf("ConvertAntigravityResponseToGeminiNonStream() = %s, want %s", string(result), tt.expected)
}
})
}
@@ -87,8 +87,8 @@ func TestConvertAntigravityResponseToGeminiStream(t *testing.T) {
if len(results) != 1 {
t.Fatalf("expected 1 result, got %d", len(results))
}
if results[0] != tt.expected {
t.Errorf("ConvertAntigravityResponseToGemini() = %s, want %s", results[0], tt.expected)
if string(results[0]) != tt.expected {
t.Errorf("ConvertAntigravityResponseToGemini() = %s, want %s", string(results[0]), tt.expected)
}
})
}

View File

@@ -286,7 +286,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
continue
}
fid := tc.Get("id").String()
fname := tc.Get("function.name").String()
fname := util.SanitizeFunctionName(tc.Get("function.name").String())
fargs := tc.Get("function.arguments").String()
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
@@ -309,7 +309,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
for _, fid := range fIDs {
if name, ok := tcID2Name[fid]; ok {
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", util.SanitizeFunctionName(name))
resp := toolResponses[fid]
if resp == "" {
resp = "{}"
@@ -354,33 +354,39 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if errRename != nil {
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
fnRawBytes, errSet := sjson.SetBytes([]byte(fnRaw), "parametersJsonSchema.type", "object")
if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
fnRaw = string(fnRawBytes)
fnRawBytes, errSet = sjson.SetRawBytes([]byte(fnRaw), "parametersJsonSchema.properties", []byte(`{}`))
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw = string(fnRawBytes)
} else {
fnRaw = renamed
}
} else {
var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
fnRawBytes, errSet := sjson.SetBytes([]byte(fnRaw), "parametersJsonSchema.type", "object")
if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
fnRaw = string(fnRawBytes)
fnRawBytes, errSet = sjson.SetRawBytes([]byte(fnRaw), "parametersJsonSchema.properties", []byte(`{}`))
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw = string(fnRawBytes)
}
fnRaw, _ = sjson.Delete(fnRaw, "strict")
fnRawBytes := []byte(fnRaw)
fnRawBytes, _ = sjson.SetBytes(fnRawBytes, "name", util.SanitizeFunctionName(fn.Get("name").String()))
fnRaw, _ = sjson.Delete(string(fnRawBytes), "strict")
if !hasFunction {
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
}

View File

@@ -13,6 +13,7 @@ import (
"sync/atomic"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
@@ -26,6 +27,7 @@ type convertCliResponseToOpenAIChatParams struct {
FunctionIndex int
SawToolCall bool // Tracks if any tool call was seen in the entire stream
UpstreamFinishReason string // Caches the upstream finish reason for final chunk
SanitizedNameMap map[string]string
}
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
@@ -44,25 +46,29 @@ var functionCallIDCounter uint64
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of OpenAI-compatible JSON responses
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &convertCliResponseToOpenAIChatParams{
UnixTimestamp: 0,
FunctionIndex: 0,
UnixTimestamp: 0,
FunctionIndex: 0,
SanitizedNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
}
}
if (*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap == nil {
(*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
}
if bytes.Equal(rawJSON, []byte("[DONE]")) {
return []string{}
return [][]byte{}
}
// Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
// Extract and set the model version.
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
template, _ = sjson.SetBytes(template, "model", modelVersionResult.String())
}
// Extract and set the creation timestamp.
@@ -71,14 +77,14 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
if err == nil {
(*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
} else {
template, _ = sjson.Set(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
template, _ = sjson.SetBytes(template, "created", (*param).(*convertCliResponseToOpenAIChatParams).UnixTimestamp)
}
// Extract and set the response ID.
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
template, _ = sjson.Set(template, "id", responseIDResult.String())
template, _ = sjson.SetBytes(template, "id", responseIDResult.String())
}
// Cache the finish reason - do NOT set it in output yet (will be set on final chunk)
@@ -90,21 +96,21 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
template, _ = sjson.SetBytes(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
// Include cached token count if present (indicates prompt caching is working)
if cachedTokenCount > 0 {
var err error
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
template, err = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
if err != nil {
log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
}
@@ -141,33 +147,33 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", textContent)
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", textContent)
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", textContent)
template, _ = sjson.SetBytes(template, "choices.0.delta.content", textContent)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
(*param).(*convertCliResponseToOpenAIChatParams).SawToolCall = true // Persist across chunks
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
toolCallsResult := gjson.GetBytes(template, "choices.0.delta.tool_calls")
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
if toolCallsResult.Exists() && toolCallsResult.IsArray() {
functionCallIndex = len(toolCallsResult.Array())
} else {
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
}
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
functionCallTemplate := []byte(`{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`)
fcName := util.RestoreSanitizedToolName((*param).(*convertCliResponseToOpenAIChatParams).SanitizedNameMap, functionCallResult.Get("name").String())
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "index", functionCallIndex)
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
} else if inlineDataResult.Exists() {
data := inlineDataResult.Get("data").String()
if data == "" {
@@ -181,16 +187,16 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagesResult := gjson.Get(template, "choices.0.delta.images")
imagesResult := gjson.GetBytes(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images", []byte(`[]`))
}
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
imageIndex := len(gjson.GetBytes(template, "choices.0.delta.images").Array())
imagePayload := []byte(`{"type":"image_url","image_url":{"url":""}}`)
imagePayload, _ = sjson.SetBytes(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.SetBytes(imagePayload, "image_url.url", imageURL)
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRawBytes(template, "choices.0.delta.images.-1", imagePayload)
}
}
}
@@ -212,11 +218,11 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
} else {
finishReason = "stop"
}
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
}
return []string{template}
return [][]byte{template}
}
// ConvertAntigravityResponseToOpenAINonStream converts a non-streaming Gemini CLI response to a non-streaming OpenAI response.
@@ -231,11 +237,11 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
// - param: A pointer to a parameter object for the conversion
//
// Returns:
// - string: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertAntigravityResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertAntigravityResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
return ConvertGeminiResponseToOpenAINonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, []byte(responseResult.Raw), param)
}
return ""
return []byte{}
}

View File

@@ -19,7 +19,7 @@ func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
if len(result1) != 1 {
t.Fatalf("Expected 1 result from chunk1, got %d", len(result1))
}
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason")
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
t.Errorf("Expected finish_reason to be null in chunk1, got: %v", fr1.String())
}
@@ -33,13 +33,13 @@ func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
if len(result2) != 1 {
t.Fatalf("Expected 1 result from chunk2, got %d", len(result2))
}
fr2 := gjson.Get(result2[0], "choices.0.finish_reason").String()
fr2 := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
if fr2 != "tool_calls" {
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr2)
}
// Verify native_finish_reason is lowercase upstream value
nfr2 := gjson.Get(result2[0], "choices.0.native_finish_reason").String()
nfr2 := gjson.GetBytes(result2[0], "choices.0.native_finish_reason").String()
if nfr2 != "stop" {
t.Errorf("Expected native_finish_reason 'stop', got: %s", nfr2)
}
@@ -58,7 +58,7 @@ func TestFinishReasonStopForNormalText(t *testing.T) {
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param)
// Verify finish_reason is "stop" (no tool calls were made)
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
if fr != "stop" {
t.Errorf("Expected finish_reason 'stop', got: %s", fr)
}
@@ -77,7 +77,7 @@ func TestFinishReasonMaxTokens(t *testing.T) {
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param)
// Verify finish_reason is "max_tokens"
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
if fr != "max_tokens" {
t.Errorf("Expected finish_reason 'max_tokens', got: %s", fr)
}
@@ -96,7 +96,7 @@ func TestToolCallTakesPriorityOverMaxTokens(t *testing.T) {
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param)
// Verify finish_reason is "tool_calls" (takes priority over max_tokens)
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
fr := gjson.GetBytes(result2[0], "choices.0.finish_reason").String()
if fr != "tool_calls" {
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr)
}
@@ -111,7 +111,7 @@ func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, &param)
// Verify no finish_reason on intermediate chunk
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
fr1 := gjson.GetBytes(result1[0], "choices.0.finish_reason")
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr1)
}
@@ -121,7 +121,7 @@ func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, &param)
// Verify no finish_reason on intermediate chunk
fr2 := gjson.Get(result2[0], "choices.0.finish_reason")
fr2 := gjson.GetBytes(result2[0], "choices.0.finish_reason")
if fr2.Exists() && fr2.String() != "" && fr2.Type.String() != "Null" {
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr2)
}

View File

@@ -7,7 +7,7 @@ import (
"github.com/tidwall/gjson"
)
func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
rawJSON = []byte(responseResult.Raw)
@@ -15,7 +15,7 @@ func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName
return ConvertGeminiResponseToOpenAIResponses(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
func ConvertAntigravityResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
func ConvertAntigravityResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
responseResult := gjson.GetBytes(rawJSON, "response")
if responseResult.Exists() {
rawJSON = []byte(responseResult.Raw)

View File

@@ -8,7 +8,7 @@ import (
"context"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini"
"github.com/tidwall/sjson"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
)
// ConvertClaudeResponseToGeminiCLI converts Claude Code streaming response format to Gemini CLI format.
@@ -23,15 +23,13 @@ import (
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response wrapped in a response object
func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of Gemini-compatible JSON responses wrapped in a response object
func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
outputs := ConvertClaudeResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
// Wrap each converted response in a "response" object to match Gemini CLI API structure
newOutputs := make([]string, 0)
newOutputs := make([][]byte, 0, len(outputs))
for i := 0; i < len(outputs); i++ {
json := `{"response": {}}`
output, _ := sjson.SetRaw(json, "response", outputs[i])
newOutputs = append(newOutputs, output)
newOutputs = append(newOutputs, translatorcommon.WrapGeminiCLIResponse(outputs[i]))
}
return newOutputs
}
@@ -47,15 +45,13 @@ func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, ori
// - param: A pointer to a parameter object for the conversion
//
// Returns:
// - string: A Gemini-compatible JSON response wrapped in a response object
func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
strJSON := ConvertClaudeResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
// - []byte: A Gemini-compatible JSON response wrapped in a response object
func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
out := ConvertClaudeResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
// Wrap the converted response in a "response" object to match Gemini CLI API structure
json := `{"response": {}}`
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
return strJSON
return translatorcommon.WrapGeminiCLIResponse(out)
}
func GeminiCLITokenCount(ctx context.Context, count int64) string {
func GeminiCLITokenCount(ctx context.Context, count int64) []byte {
return GeminiTokenCount(ctx, count)
}

View File

@@ -63,7 +63,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
// Base Claude message payload
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID)
out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
root := gjson.ParseBytes(rawJSON)
@@ -87,20 +87,20 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
var pendingToolIDs []string
// Model mapping to specify which Claude Code model to use
out, _ = sjson.Set(out, "model", modelName)
out, _ = sjson.SetBytes(out, "model", modelName)
// Generation config extraction from Gemini format
if genConfig := root.Get("generationConfig"); genConfig.Exists() {
// Max output tokens configuration
if maxTokens := genConfig.Get("maxOutputTokens"); maxTokens.Exists() {
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
out, _ = sjson.SetBytes(out, "max_tokens", maxTokens.Int())
}
// Temperature setting for controlling response randomness
if temp := genConfig.Get("temperature"); temp.Exists() {
out, _ = sjson.Set(out, "temperature", temp.Float())
out, _ = sjson.SetBytes(out, "temperature", temp.Float())
} else if topP := genConfig.Get("topP"); topP.Exists() {
// Top P setting for nucleus sampling (filtered out if temperature is set)
out, _ = sjson.Set(out, "top_p", topP.Float())
out, _ = sjson.SetBytes(out, "top_p", topP.Float())
}
// Stop sequences configuration for custom termination conditions
if stopSeqs := genConfig.Get("stopSequences"); stopSeqs.Exists() && stopSeqs.IsArray() {
@@ -110,7 +110,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
return true
})
if len(stopSequences) > 0 {
out, _ = sjson.Set(out, "stop_sequences", stopSequences)
out, _ = sjson.SetBytes(out, "stop_sequences", stopSequences)
}
}
// Include thoughts configuration for reasoning process visibility
@@ -132,30 +132,30 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
switch level {
case "":
case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.DeleteBytes(out, "output_config.effort")
default:
if mapped, ok := thinking.MapToClaudeEffort(level, supportsMax); ok {
level = mapped
}
out, _ = sjson.Set(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Set(out, "output_config.effort", level)
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "output_config.effort", level)
}
} else {
switch level {
case "":
case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
case "auto":
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
default:
if budget, ok := thinking.ConvertLevelToBudget(level); ok {
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
}
}
}
@@ -169,37 +169,37 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
if supportsAdaptive {
switch budget {
case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.DeleteBytes(out, "output_config.effort")
default:
level, ok := thinking.ConvertBudgetToLevel(budget)
if ok {
if mapped, okM := thinking.MapToClaudeEffort(level, supportsMax); okM {
level = mapped
}
out, _ = sjson.Set(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Set(out, "output_config.effort", level)
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "output_config.effort", level)
}
}
} else {
switch budget {
case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
case -1:
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
default:
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
}
}
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
}
}
}
@@ -220,9 +220,9 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
})
if systemText.Len() > 0 {
// Create system message in Claude Code format
systemMessage := `{"role":"user","content":[{"type":"text","text":""}]}`
systemMessage, _ = sjson.Set(systemMessage, "content.0.text", systemText.String())
out, _ = sjson.SetRaw(out, "messages.-1", systemMessage)
systemMessage := []byte(`{"role":"user","content":[{"type":"text","text":""}]}`)
systemMessage, _ = sjson.SetBytes(systemMessage, "content.0.text", systemText.String())
out, _ = sjson.SetRawBytes(out, "messages.-1", systemMessage)
}
}
}
@@ -245,42 +245,42 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
}
// Create message structure in Claude Code format
msg := `{"role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
msg := []byte(`{"role":"","content":[]}`)
msg, _ = sjson.SetBytes(msg, "role", role)
if parts := content.Get("parts"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
// Text content conversion
if text := part.Get("text"); text.Exists() {
textContent := `{"type":"text","text":""}`
textContent, _ = sjson.Set(textContent, "text", text.String())
msg, _ = sjson.SetRaw(msg, "content.-1", textContent)
textContent := []byte(`{"type":"text","text":""}`)
textContent, _ = sjson.SetBytes(textContent, "text", text.String())
msg, _ = sjson.SetRawBytes(msg, "content.-1", textContent)
return true
}
// Function call (from model/assistant) conversion to tool use
if fc := part.Get("functionCall"); fc.Exists() && role == "assistant" {
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
// Generate a unique tool ID and enqueue it for later matching
// with the corresponding functionResponse
toolID := genToolCallID()
pendingToolIDs = append(pendingToolIDs, toolID)
toolUse, _ = sjson.Set(toolUse, "id", toolID)
toolUse, _ = sjson.SetBytes(toolUse, "id", toolID)
if name := fc.Get("name"); name.Exists() {
toolUse, _ = sjson.Set(toolUse, "name", name.String())
toolUse, _ = sjson.SetBytes(toolUse, "name", name.String())
}
if args := fc.Get("args"); args.Exists() && args.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", args.Raw)
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(args.Raw))
}
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
msg, _ = sjson.SetRawBytes(msg, "content.-1", toolUse)
return true
}
// Function response (from user) conversion to tool result
if fr := part.Get("functionResponse"); fr.Exists() {
toolResult := `{"type":"tool_result","tool_use_id":"","content":""}`
toolResult := []byte(`{"type":"tool_result","tool_use_id":"","content":""}`)
// Attach the oldest queued tool_id to pair the response
// with its call. If the queue is empty, generate a new id.
@@ -293,41 +293,41 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
// Fallback: generate new ID if no pending tool_use found
toolID = genToolCallID()
}
toolResult, _ = sjson.Set(toolResult, "tool_use_id", toolID)
toolResult, _ = sjson.SetBytes(toolResult, "tool_use_id", toolID)
// Extract result content from the function response
if result := fr.Get("response.result"); result.Exists() {
toolResult, _ = sjson.Set(toolResult, "content", result.String())
toolResult, _ = sjson.SetBytes(toolResult, "content", result.String())
} else if response := fr.Get("response"); response.Exists() {
toolResult, _ = sjson.Set(toolResult, "content", response.Raw)
toolResult, _ = sjson.SetBytes(toolResult, "content", response.Raw)
}
msg, _ = sjson.SetRaw(msg, "content.-1", toolResult)
msg, _ = sjson.SetRawBytes(msg, "content.-1", toolResult)
return true
}
// Image content (inline_data) conversion to Claude Code format
if inlineData := part.Get("inline_data"); inlineData.Exists() {
imageContent := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
imageContent := []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
if mimeType := inlineData.Get("mime_type"); mimeType.Exists() {
imageContent, _ = sjson.Set(imageContent, "source.media_type", mimeType.String())
imageContent, _ = sjson.SetBytes(imageContent, "source.media_type", mimeType.String())
}
if data := inlineData.Get("data"); data.Exists() {
imageContent, _ = sjson.Set(imageContent, "source.data", data.String())
imageContent, _ = sjson.SetBytes(imageContent, "source.data", data.String())
}
msg, _ = sjson.SetRaw(msg, "content.-1", imageContent)
msg, _ = sjson.SetRawBytes(msg, "content.-1", imageContent)
return true
}
// File data conversion to text content with file info
if fileData := part.Get("file_data"); fileData.Exists() {
// For file data, we'll convert to text content with file info
textContent := `{"type":"text","text":""}`
textContent := []byte(`{"type":"text","text":""}`)
fileInfo := "File: " + fileData.Get("file_uri").String()
if mimeType := fileData.Get("mime_type"); mimeType.Exists() {
fileInfo += " (Type: " + mimeType.String() + ")"
}
textContent, _ = sjson.Set(textContent, "text", fileInfo)
msg, _ = sjson.SetRaw(msg, "content.-1", textContent)
textContent, _ = sjson.SetBytes(textContent, "text", fileInfo)
msg, _ = sjson.SetRawBytes(msg, "content.-1", textContent)
return true
}
@@ -336,8 +336,8 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
}
// Only add message if it has content
if contentArray := gjson.Get(msg, "content"); contentArray.Exists() && len(contentArray.Array()) > 0 {
out, _ = sjson.SetRaw(out, "messages.-1", msg)
if contentArray := gjson.GetBytes(msg, "content"); contentArray.Exists() && len(contentArray.Array()) > 0 {
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
}
return true
@@ -351,29 +351,29 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
tools.ForEach(func(_, tool gjson.Result) bool {
if funcDecls := tool.Get("functionDeclarations"); funcDecls.Exists() && funcDecls.IsArray() {
funcDecls.ForEach(func(_, funcDecl gjson.Result) bool {
anthropicTool := `{"name":"","description":"","input_schema":{}}`
anthropicTool := []byte(`{"name":"","description":"","input_schema":{}}`)
if name := funcDecl.Get("name"); name.Exists() {
anthropicTool, _ = sjson.Set(anthropicTool, "name", name.String())
anthropicTool, _ = sjson.SetBytes(anthropicTool, "name", name.String())
}
if desc := funcDecl.Get("description"); desc.Exists() {
anthropicTool, _ = sjson.Set(anthropicTool, "description", desc.String())
anthropicTool, _ = sjson.SetBytes(anthropicTool, "description", desc.String())
}
if params := funcDecl.Get("parameters"); params.Exists() {
// Clean up the parameters schema for Claude Code compatibility
cleaned := params.Raw
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
cleaned, _ = sjson.Set(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", cleaned)
cleaned := []byte(params.Raw)
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
cleaned, _ = sjson.SetBytes(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", cleaned)
} else if params = funcDecl.Get("parametersJsonSchema"); params.Exists() {
// Clean up the parameters schema for Claude Code compatibility
cleaned := params.Raw
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
cleaned, _ = sjson.Set(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", cleaned)
cleaned := []byte(params.Raw)
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
cleaned, _ = sjson.SetBytes(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", cleaned)
}
anthropicTools = append(anthropicTools, gjson.Parse(anthropicTool).Value())
anthropicTools = append(anthropicTools, gjson.ParseBytes(anthropicTool).Value())
return true
})
}
@@ -381,7 +381,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
})
if len(anthropicTools) > 0 {
out, _ = sjson.Set(out, "tools", anthropicTools)
out, _ = sjson.SetBytes(out, "tools", anthropicTools)
}
}
@@ -391,27 +391,27 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
if mode := funcCalling.Get("mode"); mode.Exists() {
switch mode.String() {
case "AUTO":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
case "NONE":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"none"}`)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"none"}`))
case "ANY":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
}
}
}
}
// Stream setting configuration
out, _ = sjson.Set(out, "stream", stream)
out, _ = sjson.SetBytes(out, "stream", stream)
// Convert tool parameter types to lowercase for Claude Code compatibility
var pathsToLower []string
toolsResult := gjson.Get(out, "tools")
toolsResult := gjson.GetBytes(out, "tools")
util.Walk(toolsResult, "", "type", &pathsToLower)
for _, p := range pathsToLower {
fullPath := fmt.Sprintf("tools.%s", p)
out, _ = sjson.Set(out, fullPath, strings.ToLower(gjson.Get(out, fullPath).String()))
out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(gjson.GetBytes(out, fullPath).String()))
}
return []byte(out)
return out
}

View File

@@ -9,10 +9,10 @@ import (
"bufio"
"bytes"
"context"
"fmt"
"strings"
"time"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -30,7 +30,7 @@ type ConvertAnthropicResponseToGeminiParams struct {
Model string
CreatedAt int64
ResponseID string
LastStorageOutput string
LastStorageOutput []byte
IsStreaming bool
// Streaming state for tool_use assembly
@@ -52,8 +52,8 @@ type ConvertAnthropicResponseToGeminiParams struct {
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response
func ConvertClaudeResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of Gemini-compatible JSON responses
func ConvertClaudeResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &ConvertAnthropicResponseToGeminiParams{
Model: modelName,
@@ -63,7 +63,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
}
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
return [][]byte{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
@@ -71,24 +71,24 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
eventType := root.Get("type").String()
// Base Gemini response template with default values
template := `{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
// Set model version
if (*param).(*ConvertAnthropicResponseToGeminiParams).Model != "" {
// Map Claude model names back to Gemini model names
template, _ = sjson.Set(template, "modelVersion", (*param).(*ConvertAnthropicResponseToGeminiParams).Model)
template, _ = sjson.SetBytes(template, "modelVersion", (*param).(*ConvertAnthropicResponseToGeminiParams).Model)
}
// Set response ID and creation time
if (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID != "" {
template, _ = sjson.Set(template, "responseId", (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID)
template, _ = sjson.SetBytes(template, "responseId", (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID)
}
// Set creation time to current time if not provided
if (*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt == 0 {
(*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt = time.Now().Unix()
}
template, _ = sjson.Set(template, "createTime", time.Unix((*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
template, _ = sjson.SetBytes(template, "createTime", time.Unix((*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
switch eventType {
case "message_start":
@@ -97,7 +97,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
(*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID = message.Get("id").String()
(*param).(*ConvertAnthropicResponseToGeminiParams).Model = message.Get("model").String()
}
return []string{}
return [][]byte{}
case "content_block_start":
// Start of a content block - record tool_use name by index for functionCall assembly
@@ -112,7 +112,7 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
}
}
}
return []string{}
return [][]byte{}
case "content_block_delta":
// Handle content delta (text, thinking, or tool use arguments)
@@ -123,16 +123,16 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
case "text_delta":
// Regular text content delta for normal response text
if text := delta.Get("text"); text.Exists() && text.String() != "" {
textPart := `{"text":""}`
textPart, _ = sjson.Set(textPart, "text", text.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", textPart)
textPart := []byte(`{"text":""}`)
textPart, _ = sjson.SetBytes(textPart, "text", text.String())
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", textPart)
}
case "thinking_delta":
// Thinking/reasoning content delta for models with reasoning capabilities
if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
thinkingPart := `{"thought":true,"text":""}`
thinkingPart, _ = sjson.Set(thinkingPart, "text", text.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", thinkingPart)
thinkingPart := []byte(`{"thought":true,"text":""}`)
thinkingPart, _ = sjson.SetBytes(thinkingPart, "text", text.String())
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", thinkingPart)
}
case "input_json_delta":
// Tool use input delta - accumulate partial_json by index for later assembly at content_block_stop
@@ -149,10 +149,10 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
if pj := delta.Get("partial_json"); pj.Exists() {
b.WriteString(pj.String())
}
return []string{}
return [][]byte{}
}
}
return []string{template}
return [][]byte{template}
case "content_block_stop":
// End of content block - finalize tool calls if any
@@ -170,16 +170,16 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
}
}
if name != "" || argsTrim != "" {
functionCall := `{"functionCall":{"name":"","args":{}}}`
functionCall := []byte(`{"functionCall":{"name":"","args":{}}}`)
if name != "" {
functionCall, _ = sjson.Set(functionCall, "functionCall.name", name)
functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", name)
}
if argsTrim != "" {
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsTrim)
functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsTrim))
}
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", functionCall)
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
(*param).(*ConvertAnthropicResponseToGeminiParams).LastStorageOutput = template
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall)
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
(*param).(*ConvertAnthropicResponseToGeminiParams).LastStorageOutput = append([]byte(nil), template...)
// cleanup used state for this index
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs != nil {
delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs, idx)
@@ -187,9 +187,9 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames != nil {
delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames, idx)
}
return []string{template}
return [][]byte{template}
}
return []string{}
return [][]byte{}
case "message_delta":
// Handle message-level changes (like stop reason and usage information)
@@ -197,15 +197,15 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
if stopReason := delta.Get("stop_reason"); stopReason.Exists() {
switch stopReason.String() {
case "end_turn":
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
case "tool_use":
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
case "max_tokens":
template, _ = sjson.Set(template, "candidates.0.finishReason", "MAX_TOKENS")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "MAX_TOKENS")
case "stop_sequence":
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
default:
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
}
}
}
@@ -216,35 +216,35 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
outputTokens := usage.Get("output_tokens").Int()
// Set basic usage metadata according to Gemini API specification
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens)
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens)
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", inputTokens+outputTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", inputTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", outputTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", inputTokens+outputTokens)
// Add cache-related token counts if present (Claude Code API cache fields)
if cacheCreationTokens := usage.Get("cache_creation_input_tokens"); cacheCreationTokens.Exists() {
template, _ = sjson.Set(template, "usageMetadata.cachedContentTokenCount", cacheCreationTokens.Int())
template, _ = sjson.SetBytes(template, "usageMetadata.cachedContentTokenCount", cacheCreationTokens.Int())
}
if cacheReadTokens := usage.Get("cache_read_input_tokens"); cacheReadTokens.Exists() {
// Add cache read tokens to cached content count
existingCacheTokens := usage.Get("cache_creation_input_tokens").Int()
totalCacheTokens := existingCacheTokens + cacheReadTokens.Int()
template, _ = sjson.Set(template, "usageMetadata.cachedContentTokenCount", totalCacheTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.cachedContentTokenCount", totalCacheTokens)
}
// Add thinking tokens if present (for models with reasoning capabilities)
if thinkingTokens := usage.Get("thinking_tokens"); thinkingTokens.Exists() {
template, _ = sjson.Set(template, "usageMetadata.thoughtsTokenCount", thinkingTokens.Int())
template, _ = sjson.SetBytes(template, "usageMetadata.thoughtsTokenCount", thinkingTokens.Int())
}
// Set traffic type (required by Gemini API)
template, _ = sjson.Set(template, "usageMetadata.trafficType", "PROVISIONED_THROUGHPUT")
template, _ = sjson.SetBytes(template, "usageMetadata.trafficType", "PROVISIONED_THROUGHPUT")
}
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
return []string{template}
return [][]byte{template}
case "message_stop":
// Final message with usage information - no additional output needed
return []string{}
return [][]byte{}
case "error":
// Handle error responses and convert to Gemini error format
errorMsg := root.Get("error.message").String()
@@ -253,13 +253,13 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
}
// Create error response in Gemini format
errorResponse := `{"error":{"code":400,"message":"","status":"INVALID_ARGUMENT"}}`
errorResponse, _ = sjson.Set(errorResponse, "error.message", errorMsg)
return []string{errorResponse}
errorResponse := []byte(`{"error":{"code":400,"message":"","status":"INVALID_ARGUMENT"}}`)
errorResponse, _ = sjson.SetBytes(errorResponse, "error.message", errorMsg)
return [][]byte{errorResponse}
default:
// Unknown event type, return empty response
return []string{}
return [][]byte{}
}
}
@@ -275,13 +275,13 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: A Gemini-compatible JSON response containing all message content and metadata
func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// - []byte: A Gemini-compatible JSON response containing all message content and metadata
func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
// Base Gemini response template for non-streaming with default values
template := `{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
// Set model version
template, _ = sjson.Set(template, "modelVersion", modelName)
template, _ = sjson.SetBytes(template, "modelVersion", modelName)
streamingEvents := make([][]byte, 0)
@@ -304,15 +304,15 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
Model: modelName,
CreatedAt: 0,
ResponseID: "",
LastStorageOutput: "",
LastStorageOutput: nil,
IsStreaming: false,
ToolUseNames: nil,
ToolUseArgs: nil,
}
// Process each streaming event and collect parts
var allParts []string
var finalUsageJSON string
var allParts [][]byte
var finalUsageJSON []byte
var responseID string
var createdAt int64
@@ -360,15 +360,15 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
case "text_delta":
// Process regular text content
if text := delta.Get("text"); text.Exists() && text.String() != "" {
partJSON := `{"text":""}`
partJSON, _ = sjson.Set(partJSON, "text", text.String())
partJSON := []byte(`{"text":""}`)
partJSON, _ = sjson.SetBytes(partJSON, "text", text.String())
allParts = append(allParts, partJSON)
}
case "thinking_delta":
// Process reasoning/thinking content
if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
partJSON := `{"thought":true,"text":""}`
partJSON, _ = sjson.Set(partJSON, "text", text.String())
partJSON := []byte(`{"thought":true,"text":""}`)
partJSON, _ = sjson.SetBytes(partJSON, "text", text.String())
allParts = append(allParts, partJSON)
}
case "input_json_delta":
@@ -402,12 +402,12 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
}
}
if name != "" || argsTrim != "" {
functionCallJSON := `{"functionCall":{"name":"","args":{}}}`
functionCallJSON := []byte(`{"functionCall":{"name":"","args":{}}}`)
if name != "" {
functionCallJSON, _ = sjson.Set(functionCallJSON, "functionCall.name", name)
functionCallJSON, _ = sjson.SetBytes(functionCallJSON, "functionCall.name", name)
}
if argsTrim != "" {
functionCallJSON, _ = sjson.SetRaw(functionCallJSON, "functionCall.args", argsTrim)
functionCallJSON, _ = sjson.SetRawBytes(functionCallJSON, "functionCall.args", []byte(argsTrim))
}
allParts = append(allParts, functionCallJSON)
// cleanup used state for this index
@@ -422,35 +422,35 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
case "message_delta":
// Extract final usage information using sjson for token counts and metadata
if usage := root.Get("usage"); usage.Exists() {
usageJSON := `{}`
usageJSON := []byte(`{}`)
// Basic token counts for prompt and completion
inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int()
// Set basic usage metadata according to Gemini API specification
usageJSON, _ = sjson.Set(usageJSON, "promptTokenCount", inputTokens)
usageJSON, _ = sjson.Set(usageJSON, "candidatesTokenCount", outputTokens)
usageJSON, _ = sjson.Set(usageJSON, "totalTokenCount", inputTokens+outputTokens)
usageJSON, _ = sjson.SetBytes(usageJSON, "promptTokenCount", inputTokens)
usageJSON, _ = sjson.SetBytes(usageJSON, "candidatesTokenCount", outputTokens)
usageJSON, _ = sjson.SetBytes(usageJSON, "totalTokenCount", inputTokens+outputTokens)
// Add cache-related token counts if present (Claude Code API cache fields)
if cacheCreationTokens := usage.Get("cache_creation_input_tokens"); cacheCreationTokens.Exists() {
usageJSON, _ = sjson.Set(usageJSON, "cachedContentTokenCount", cacheCreationTokens.Int())
usageJSON, _ = sjson.SetBytes(usageJSON, "cachedContentTokenCount", cacheCreationTokens.Int())
}
if cacheReadTokens := usage.Get("cache_read_input_tokens"); cacheReadTokens.Exists() {
// Add cache read tokens to cached content count
existingCacheTokens := usage.Get("cache_creation_input_tokens").Int()
totalCacheTokens := existingCacheTokens + cacheReadTokens.Int()
usageJSON, _ = sjson.Set(usageJSON, "cachedContentTokenCount", totalCacheTokens)
usageJSON, _ = sjson.SetBytes(usageJSON, "cachedContentTokenCount", totalCacheTokens)
}
// Add thinking tokens if present (for models with reasoning capabilities)
if thinkingTokens := usage.Get("thinking_tokens"); thinkingTokens.Exists() {
usageJSON, _ = sjson.Set(usageJSON, "thoughtsTokenCount", thinkingTokens.Int())
usageJSON, _ = sjson.SetBytes(usageJSON, "thoughtsTokenCount", thinkingTokens.Int())
}
// Set traffic type (required by Gemini API)
usageJSON, _ = sjson.Set(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT")
usageJSON, _ = sjson.SetBytes(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT")
finalUsageJSON = usageJSON
}
@@ -459,10 +459,10 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
// Set response metadata
if responseID != "" {
template, _ = sjson.Set(template, "responseId", responseID)
template, _ = sjson.SetBytes(template, "responseId", responseID)
}
if createdAt > 0 {
template, _ = sjson.Set(template, "createTime", time.Unix(createdAt, 0).Format(time.RFC3339Nano))
template, _ = sjson.SetBytes(template, "createTime", time.Unix(createdAt, 0).Format(time.RFC3339Nano))
}
// Consolidate consecutive text parts and thinking parts for cleaner output
@@ -470,35 +470,35 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
// Set the consolidated parts array
if len(consolidatedParts) > 0 {
partsJSON := "[]"
partsJSON := []byte(`[]`)
for _, partJSON := range consolidatedParts {
partsJSON, _ = sjson.SetRaw(partsJSON, "-1", partJSON)
partsJSON, _ = sjson.SetRawBytes(partsJSON, "-1", partJSON)
}
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", partsJSON)
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts", partsJSON)
}
// Set usage metadata
if finalUsageJSON != "" {
template, _ = sjson.SetRaw(template, "usageMetadata", finalUsageJSON)
if len(finalUsageJSON) > 0 {
template, _ = sjson.SetRawBytes(template, "usageMetadata", finalUsageJSON)
}
return template
}
func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
func GeminiTokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.GeminiTokenCountJSON(count)
}
// consolidateParts merges consecutive text parts and thinking parts to create a cleaner response.
// This function processes the parts array to combine adjacent text elements and thinking elements
// into single consolidated parts, which results in a more readable and efficient response structure.
// Tool calls and other non-text parts are preserved as separate elements.
func consolidateParts(parts []string) []string {
func consolidateParts(parts [][]byte) [][]byte {
if len(parts) == 0 {
return parts
}
var consolidated []string
var consolidated [][]byte
var currentTextPart strings.Builder
var currentThoughtPart strings.Builder
var hasText, hasThought bool
@@ -506,8 +506,8 @@ func consolidateParts(parts []string) []string {
flushText := func() {
// Flush accumulated text content to the consolidated parts array
if hasText && currentTextPart.Len() > 0 {
textPartJSON := `{"text":""}`
textPartJSON, _ = sjson.Set(textPartJSON, "text", currentTextPart.String())
textPartJSON := []byte(`{"text":""}`)
textPartJSON, _ = sjson.SetBytes(textPartJSON, "text", currentTextPart.String())
consolidated = append(consolidated, textPartJSON)
currentTextPart.Reset()
hasText = false
@@ -517,8 +517,8 @@ func consolidateParts(parts []string) []string {
flushThought := func() {
// Flush accumulated thinking content to the consolidated parts array
if hasThought && currentThoughtPart.Len() > 0 {
thoughtPartJSON := `{"thought":true,"text":""}`
thoughtPartJSON, _ = sjson.Set(thoughtPartJSON, "text", currentThoughtPart.String())
thoughtPartJSON := []byte(`{"thought":true,"text":""}`)
thoughtPartJSON, _ = sjson.SetBytes(thoughtPartJSON, "text", currentThoughtPart.String())
consolidated = append(consolidated, thoughtPartJSON)
currentThoughtPart.Reset()
hasThought = false
@@ -526,7 +526,7 @@ func consolidateParts(parts []string) []string {
}
for _, partJSON := range parts {
part := gjson.Parse(partJSON)
part := gjson.ParseBytes(partJSON)
if !part.Exists() || !part.IsObject() {
// Flush any pending parts and add this non-text part
flushText()

View File

@@ -61,7 +61,7 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
// Base Claude Code API template with default max_tokens value
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID)
out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
root := gjson.ParseBytes(rawJSON)
@@ -79,20 +79,20 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
if supportsAdaptive {
switch effort {
case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.DeleteBytes(out, "output_config.effort")
case "auto":
out, _ = sjson.Set(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort")
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.DeleteBytes(out, "output_config.effort")
default:
if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
effort = mapped
}
out, _ = sjson.Set(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Set(out, "output_config.effort", effort)
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "output_config.effort", effort)
}
} else {
// Legacy/manual thinking (budget_tokens).
@@ -100,13 +100,13 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
if ok {
switch budget {
case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
case -1:
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
default:
if budget > 0 {
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
}
}
}
@@ -128,19 +128,19 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
}
// Model mapping to specify which Claude Code model to use
out, _ = sjson.Set(out, "model", modelName)
out, _ = sjson.SetBytes(out, "model", modelName)
// Max tokens configuration with fallback to default value
if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
out, _ = sjson.SetBytes(out, "max_tokens", maxTokens.Int())
}
// Temperature setting for controlling response randomness
if temp := root.Get("temperature"); temp.Exists() {
out, _ = sjson.Set(out, "temperature", temp.Float())
out, _ = sjson.SetBytes(out, "temperature", temp.Float())
} else if topP := root.Get("top_p"); topP.Exists() {
// Top P setting for nucleus sampling (filtered out if temperature is set)
out, _ = sjson.Set(out, "top_p", topP.Float())
out, _ = sjson.SetBytes(out, "top_p", topP.Float())
}
// Stop sequences configuration for custom termination conditions
@@ -152,60 +152,53 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
return true
})
if len(stopSequences) > 0 {
out, _ = sjson.Set(out, "stop_sequences", stopSequences)
out, _ = sjson.SetBytes(out, "stop_sequences", stopSequences)
}
} else {
out, _ = sjson.Set(out, "stop_sequences", []string{stop.String()})
out, _ = sjson.SetBytes(out, "stop_sequences", []string{stop.String()})
}
}
// Stream configuration to enable or disable streaming responses
out, _ = sjson.Set(out, "stream", stream)
out, _ = sjson.SetBytes(out, "stream", stream)
// Process messages and transform them to Claude Code format
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
messageIndex := 0
systemMessageIndex := -1
messages.ForEach(func(_, message gjson.Result) bool {
role := message.Get("role").String()
contentResult := message.Get("content")
switch role {
case "system":
if systemMessageIndex == -1 {
systemMsg := `{"role":"user","content":[]}`
out, _ = sjson.SetRaw(out, "messages.-1", systemMsg)
systemMessageIndex = messageIndex
messageIndex++
}
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", contentResult.String())
out, _ = sjson.SetRaw(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.SetBytes(textPart, "text", contentResult.String())
out, _ = sjson.SetRawBytes(out, "system.-1", textPart)
} else if contentResult.Exists() && contentResult.IsArray() {
contentResult.ForEach(func(_, part gjson.Result) bool {
if part.Get("type").String() == "text" {
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
out, _ = sjson.SetRaw(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.SetBytes(textPart, "text", part.Get("text").String())
out, _ = sjson.SetRawBytes(out, "system.-1", textPart)
}
return true
})
}
case "user", "assistant":
msg := `{"role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
msg := []byte(`{"role":"","content":[]}`)
msg, _ = sjson.SetBytes(msg, "role", role)
// Handle content based on its type (string or array)
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
part := `{"type":"text","text":""}`
part, _ = sjson.Set(part, "text", contentResult.String())
msg, _ = sjson.SetRaw(msg, "content.-1", part)
part := []byte(`{"type":"text","text":""}`)
part, _ = sjson.SetBytes(part, "text", contentResult.String())
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
} else if contentResult.Exists() && contentResult.IsArray() {
contentResult.ForEach(func(_, part gjson.Result) bool {
claudePart := convertOpenAIContentPartToClaudePart(part)
if claudePart != "" {
msg, _ = sjson.SetRaw(msg, "content.-1", claudePart)
msg, _ = sjson.SetRawBytes(msg, "content.-1", []byte(claudePart))
}
return true
})
@@ -221,9 +214,9 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
}
function := toolCall.Get("function")
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUse, _ = sjson.Set(toolUse, "id", toolCallID)
toolUse, _ = sjson.Set(toolUse, "name", function.Get("name").String())
toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolUse, _ = sjson.SetBytes(toolUse, "id", toolCallID)
toolUse, _ = sjson.SetBytes(toolUse, "name", function.Get("name").String())
// Parse arguments for the tool call
if args := function.Get("arguments"); args.Exists() {
@@ -231,24 +224,24 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
if argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(argsJSON.Raw))
} else {
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
}
} else {
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
}
} else {
toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte("{}"))
}
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
msg, _ = sjson.SetRawBytes(msg, "content.-1", toolUse)
}
return true
})
}
out, _ = sjson.SetRaw(out, "messages.-1", msg)
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
messageIndex++
case "tool":
@@ -256,19 +249,29 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
toolCallID := message.Get("tool_call_id").String()
toolContentResult := message.Get("content")
msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
msg := []byte(`{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`)
msg, _ = sjson.SetBytes(msg, "content.0.tool_use_id", toolCallID)
toolResultContent, toolResultContentRaw := convertOpenAIToolResultContent(toolContentResult)
if toolResultContentRaw {
msg, _ = sjson.SetRaw(msg, "content.0.content", toolResultContent)
msg, _ = sjson.SetRawBytes(msg, "content.0.content", []byte(toolResultContent))
} else {
msg, _ = sjson.Set(msg, "content.0.content", toolResultContent)
msg, _ = sjson.SetBytes(msg, "content.0.content", toolResultContent)
}
out, _ = sjson.SetRaw(out, "messages.-1", msg)
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
messageIndex++
}
return true
})
// Preserve a minimal conversational turn for system-only inputs.
// Claude payloads with top-level system instructions but no messages are risky for downstream validation.
if messageIndex == 0 {
system := gjson.GetBytes(out, "system")
if system.Exists() && system.IsArray() && len(system.Array()) > 0 {
fallbackMsg := []byte(`{"role":"user","content":[{"type":"text","text":""}]}`)
out, _ = sjson.SetRawBytes(out, "messages.-1", fallbackMsg)
}
}
}
// Tools mapping: OpenAI tools -> Claude Code tools
@@ -277,25 +280,25 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
tools.ForEach(func(_, tool gjson.Result) bool {
if tool.Get("type").String() == "function" {
function := tool.Get("function")
anthropicTool := `{"name":"","description":""}`
anthropicTool, _ = sjson.Set(anthropicTool, "name", function.Get("name").String())
anthropicTool, _ = sjson.Set(anthropicTool, "description", function.Get("description").String())
anthropicTool := []byte(`{"name":"","description":""}`)
anthropicTool, _ = sjson.SetBytes(anthropicTool, "name", function.Get("name").String())
anthropicTool, _ = sjson.SetBytes(anthropicTool, "description", function.Get("description").String())
// Convert parameters schema for the tool
if parameters := function.Get("parameters"); parameters.Exists() {
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", []byte(parameters.Raw))
} else if parameters := function.Get("parametersJsonSchema"); parameters.Exists() {
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
anthropicTool, _ = sjson.SetRawBytes(anthropicTool, "input_schema", []byte(parameters.Raw))
}
out, _ = sjson.SetRaw(out, "tools.-1", anthropicTool)
out, _ = sjson.SetRawBytes(out, "tools.-1", anthropicTool)
hasAnthropicTools = true
}
return true
})
if !hasAnthropicTools {
out, _ = sjson.Delete(out, "tools")
out, _ = sjson.DeleteBytes(out, "tools")
}
}
@@ -308,31 +311,31 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
case "none":
// Don't set tool_choice, Claude Code will not use tools
case "auto":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
case "required":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
}
case gjson.JSON:
// Specific tool choice mapping
if toolChoice.Get("type").String() == "function" {
functionName := toolChoice.Get("function.name").String()
toolChoiceJSON := `{"type":"tool","name":""}`
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", functionName)
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
toolChoiceJSON := []byte(`{"type":"tool","name":""}`)
toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", functionName)
out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
}
default:
}
}
return []byte(out)
return out
}
func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
switch part.Get("type").String() {
case "text":
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
return textPart
textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.SetBytes(textPart, "text", part.Get("text").String())
return string(textPart)
case "image_url":
return convertOpenAIImageURLToClaudePart(part.Get("image_url.url").String())
@@ -345,10 +348,10 @@ func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
data := fileData[commaIdx+1:]
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
docPart, _ = sjson.Set(docPart, "source.data", data)
return docPart
docPart := []byte(`{"type":"document","source":{"type":"base64","media_type":"","data":""}}`)
docPart, _ = sjson.SetBytes(docPart, "source.media_type", mediaType)
docPart, _ = sjson.SetBytes(docPart, "source.data", data)
return string(docPart)
}
}
}
@@ -373,15 +376,15 @@ func convertOpenAIImageURLToClaudePart(imageURL string) string {
mediaType = "application/octet-stream"
}
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
imagePart, _ = sjson.Set(imagePart, "source.data", parts[1])
return imagePart
imagePart := []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
imagePart, _ = sjson.SetBytes(imagePart, "source.media_type", mediaType)
imagePart, _ = sjson.SetBytes(imagePart, "source.data", parts[1])
return string(imagePart)
}
imagePart := `{"type":"image","source":{"type":"url","url":""}}`
imagePart, _ = sjson.Set(imagePart, "source.url", imageURL)
return imagePart
imagePart := []byte(`{"type":"image","source":{"type":"url","url":""}}`)
imagePart, _ = sjson.SetBytes(imagePart, "source.url", imageURL)
return string(imagePart)
}
func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
@@ -394,28 +397,28 @@ func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
}
if content.IsArray() {
claudeContent := "[]"
claudeContent := []byte("[]")
partCount := 0
content.ForEach(func(_, part gjson.Result) bool {
if part.Type == gjson.String {
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.String())
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", textPart)
textPart := []byte(`{"type":"text","text":""}`)
textPart, _ = sjson.SetBytes(textPart, "text", part.String())
claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", textPart)
partCount++
return true
}
claudePart := convertOpenAIContentPartToClaudePart(part)
if claudePart != "" {
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", []byte(claudePart))
partCount++
}
return true
})
if partCount > 0 || len(content.Array()) == 0 {
return claudeContent, true
return string(claudeContent), true
}
return content.Raw, false
@@ -424,9 +427,9 @@ func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
if content.IsObject() {
claudePart := convertOpenAIContentPartToClaudePart(content)
if claudePart != "" {
claudeContent := "[]"
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
return claudeContent, true
claudeContent := []byte("[]")
claudeContent, _ = sjson.SetRawBytes(claudeContent, "-1", []byte(claudePart))
return string(claudeContent), true
}
return content.Raw, false
}

View File

@@ -135,3 +135,111 @@ func TestConvertOpenAIRequestToClaude_ToolResultURLImageOnly(t *testing.T) {
t.Fatalf("Unexpected image URL: %q", got)
}
}
func TestConvertOpenAIRequestToClaude_SystemRoleBecomesTopLevelSystem(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
system := resultJSON.Get("system")
if !system.IsArray() {
t.Fatalf("Expected top-level system array, got %s", system.Raw)
}
if len(system.Array()) != 1 {
t.Fatalf("Expected 1 system block, got %d. System: %s", len(system.Array()), system.Raw)
}
if got := system.Get("0.type").String(); got != "text" {
t.Fatalf("Expected system block type %q, got %q", "text", got)
}
if got := system.Get("0.text").String(); got != "You are a helpful assistant." {
t.Fatalf("Expected system text %q, got %q", "You are a helpful assistant.", got)
}
messages := resultJSON.Get("messages").Array()
if len(messages) != 1 {
t.Fatalf("Expected 1 non-system message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if got := messages[0].Get("role").String(); got != "user" {
t.Fatalf("Expected remaining message role %q, got %q", "user", got)
}
if got := messages[0].Get("content.0.text").String(); got != "Hello" {
t.Fatalf("Expected user text %q, got %q", "Hello", got)
}
}
func TestConvertOpenAIRequestToClaude_MultipleSystemMessagesMergedIntoTopLevelSystem(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{"role": "system", "content": "Rule 1"},
{"role": "system", "content": [{"type": "text", "text": "Rule 2"}]},
{"role": "user", "content": "Hello"}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
system := resultJSON.Get("system").Array()
if len(system) != 2 {
t.Fatalf("Expected 2 system blocks, got %d. System: %s", len(system), resultJSON.Get("system").Raw)
}
if got := system[0].Get("text").String(); got != "Rule 1" {
t.Fatalf("Expected first system text %q, got %q", "Rule 1", got)
}
if got := system[1].Get("text").String(); got != "Rule 2" {
t.Fatalf("Expected second system text %q, got %q", "Rule 2", got)
}
messages := resultJSON.Get("messages").Array()
if len(messages) != 1 {
t.Fatalf("Expected 1 non-system message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if got := messages[0].Get("role").String(); got != "user" {
t.Fatalf("Expected remaining message role %q, got %q", "user", got)
}
if got := messages[0].Get("content.0.text").String(); got != "Hello" {
t.Fatalf("Expected user text %q, got %q", "Hello", got)
}
}
func TestConvertOpenAIRequestToClaude_SystemOnlyInputKeepsFallbackUserMessage(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{"role": "system", "content": "You are a helpful assistant."}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
system := resultJSON.Get("system").Array()
if len(system) != 1 {
t.Fatalf("Expected 1 system block, got %d. System: %s", len(system), resultJSON.Get("system").Raw)
}
if got := system[0].Get("text").String(); got != "You are a helpful assistant." {
t.Fatalf("Expected system text %q, got %q", "You are a helpful assistant.", got)
}
messages := resultJSON.Get("messages").Array()
if len(messages) != 1 {
t.Fatalf("Expected 1 fallback message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if got := messages[0].Get("role").String(); got != "user" {
t.Fatalf("Expected fallback message role %q, got %q", "user", got)
}
if got := messages[0].Get("content.0.type").String(); got != "text" {
t.Fatalf("Expected fallback content type %q, got %q", "text", got)
}
if got := messages[0].Get("content.0.text").String(); got != "" {
t.Fatalf("Expected fallback text %q, got %q", "", got)
}
}

View File

@@ -36,6 +36,18 @@ type ToolCallAccumulator struct {
Arguments strings.Builder
}
func calculateClaudeUsageTokens(usage gjson.Result) (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
inputTokens := usage.Get("input_tokens").Int()
completionTokens = usage.Get("output_tokens").Int()
cachedTokens = usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
promptTokens = inputTokens + cacheCreationInputTokens + cachedTokens
totalTokens = promptTokens + completionTokens
return promptTokens, completionTokens, totalTokens, cachedTokens
}
// ConvertClaudeResponseToOpenAI converts Claude Code streaming response format to OpenAI Chat Completions format.
// This function processes various Claude Code event types and transforms them into OpenAI-compatible JSON responses.
// It handles text content, tool calls, reasoning content, and usage metadata, outputting responses that match
@@ -48,8 +60,8 @@ type ToolCallAccumulator struct {
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of OpenAI-compatible JSON responses
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
var localParam any
if param == nil {
param = &localParam
@@ -63,7 +75,7 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
}
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
return [][]byte{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
@@ -71,19 +83,19 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
eventType := root.Get("type").String()
// Base OpenAI streaming response template
template := `{"id":"","object":"chat.completion.chunk","created":0,"model":"","choices":[{"index":0,"delta":{},"finish_reason":null}]}`
template := []byte(`{"id":"","object":"chat.completion.chunk","created":0,"model":"","choices":[{"index":0,"delta":{},"finish_reason":null}]}`)
// Set model
if modelName != "" {
template, _ = sjson.Set(template, "model", modelName)
template, _ = sjson.SetBytes(template, "model", modelName)
}
// Set response ID and creation time
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID != "" {
template, _ = sjson.Set(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
}
if (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt > 0 {
template, _ = sjson.Set(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
}
switch eventType {
@@ -93,19 +105,19 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
(*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID = message.Get("id").String()
(*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt = time.Now().Unix()
template, _ = sjson.Set(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
template, _ = sjson.Set(template, "model", modelName)
template, _ = sjson.Set(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertAnthropicResponseToOpenAIParams).ResponseID)
template, _ = sjson.SetBytes(template, "model", modelName)
template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertAnthropicResponseToOpenAIParams).CreatedAt)
// Set initial role to assistant for the response
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
// Initialize tool calls accumulator for tracking tool call progress
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator == nil {
(*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
}
}
return []string{template}
return [][]byte{template}
case "content_block_start":
// Start of a content block (text, tool use, or reasoning)
@@ -128,10 +140,10 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
}
// Don't output anything yet - wait for complete tool call
return []string{}
return [][]byte{}
}
}
return []string{}
return [][]byte{}
case "content_block_delta":
// Handle content delta (text, tool use arguments, or reasoning content)
@@ -143,13 +155,13 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
case "text_delta":
// Text content delta - send incremental text updates
if text := delta.Get("text"); text.Exists() {
template, _ = sjson.Set(template, "choices.0.delta.content", text.String())
template, _ = sjson.SetBytes(template, "choices.0.delta.content", text.String())
hasContent = true
}
case "thinking_delta":
// Accumulate reasoning/thinking content
if thinking := delta.Get("thinking"); thinking.Exists() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", thinking.String())
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", thinking.String())
hasContent = true
}
case "input_json_delta":
@@ -163,13 +175,13 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
}
}
// Don't output anything yet - wait for complete tool call
return []string{}
return [][]byte{}
}
}
if hasContent {
return []string{template}
return [][]byte{template}
} else {
return []string{}
return [][]byte{}
}
case "content_block_stop":
@@ -182,63 +194,60 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
if arguments == "" {
arguments = "{}"
}
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.index", index)
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.id", accumulator.ID)
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.type", "function")
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name)
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.arguments", arguments)
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.index", index)
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.id", accumulator.ID)
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.type", "function")
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name)
template, _ = sjson.SetBytes(template, "choices.0.delta.tool_calls.0.function.arguments", arguments)
// Clean up the accumulator for this index
delete((*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator, index)
return []string{template}
return [][]byte{template}
}
}
return []string{}
return [][]byte{}
case "message_delta":
// Handle message-level changes including stop reason and usage
if delta := root.Get("delta"); delta.Exists() {
if stopReason := delta.Get("stop_reason"); stopReason.Exists() {
(*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason = mapAnthropicStopReasonToOpenAI(stopReason.String())
template, _ = sjson.Set(template, "choices.0.finish_reason", (*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason)
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", (*param).(*ConvertAnthropicResponseToOpenAIParams).FinishReason)
}
}
// Handle usage information for token counts
if usage := root.Get("usage"); usage.Exists() {
inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int()
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokens)
template, _ = sjson.SetBytes(template, "usage.completion_tokens", completionTokens)
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokens)
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
}
return []string{template}
return [][]byte{template}
case "message_stop":
// Final message event - no additional output needed
return []string{}
return [][]byte{}
case "ping":
// Ping events for keeping connection alive - no output needed
return []string{}
return [][]byte{}
case "error":
// Error event - format and return error response
if errorData := root.Get("error"); errorData.Exists() {
errorJSON := `{"error":{"message":"","type":""}}`
errorJSON, _ = sjson.Set(errorJSON, "error.message", errorData.Get("message").String())
errorJSON, _ = sjson.Set(errorJSON, "error.type", errorData.Get("type").String())
return []string{errorJSON}
errorJSON := []byte(`{"error":{"message":"","type":""}}`)
errorJSON, _ = sjson.SetBytes(errorJSON, "error.message", errorData.Get("message").String())
errorJSON, _ = sjson.SetBytes(errorJSON, "error.type", errorData.Get("type").String())
return [][]byte{errorJSON}
}
return []string{}
return [][]byte{}
default:
// Unknown event type - ignore
return []string{}
return [][]byte{}
}
}
@@ -270,8 +279,8 @@ func mapAnthropicStopReasonToOpenAI(anthropicReason string) string {
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
chunks := make([][]byte, 0)
lines := bytes.Split(rawJSON, []byte("\n"))
@@ -283,7 +292,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
}
// Base OpenAI non-streaming response template
out := `{"id":"","object":"chat.completion","created":0,"model":"","choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
out := []byte(`{"id":"","object":"chat.completion","created":0,"model":"","choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`)
var messageID string
var model string
@@ -366,32 +375,29 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
}
}
if usage := root.Get("usage"); usage.Exists() {
inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int()
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", inputTokens+outputTokens)
out, _ = sjson.Set(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
}
}
}
// Set basic response fields including message ID, creation time, and model
out, _ = sjson.Set(out, "id", messageID)
out, _ = sjson.Set(out, "created", createdAt)
out, _ = sjson.Set(out, "model", model)
out, _ = sjson.SetBytes(out, "id", messageID)
out, _ = sjson.SetBytes(out, "created", createdAt)
out, _ = sjson.SetBytes(out, "model", model)
// Set message content by combining all text parts
messageContent := strings.Join(contentParts, "")
out, _ = sjson.Set(out, "choices.0.message.content", messageContent)
out, _ = sjson.SetBytes(out, "choices.0.message.content", messageContent)
// Add reasoning content if available (following OpenAI reasoning format)
if len(reasoningParts) > 0 {
reasoningContent := strings.Join(reasoningParts, "")
// Add reasoning as a separate field in the message
out, _ = sjson.Set(out, "choices.0.message.reasoning", reasoningContent)
out, _ = sjson.SetBytes(out, "choices.0.message.reasoning", reasoningContent)
}
// Set tool calls if any were accumulated during processing
@@ -417,19 +423,19 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
namePath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.name", toolCallsCount)
argumentsPath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.arguments", toolCallsCount)
out, _ = sjson.Set(out, idPath, accumulator.ID)
out, _ = sjson.Set(out, typePath, "function")
out, _ = sjson.Set(out, namePath, accumulator.Name)
out, _ = sjson.Set(out, argumentsPath, arguments)
out, _ = sjson.SetBytes(out, idPath, accumulator.ID)
out, _ = sjson.SetBytes(out, typePath, "function")
out, _ = sjson.SetBytes(out, namePath, accumulator.Name)
out, _ = sjson.SetBytes(out, argumentsPath, arguments)
toolCallsCount++
}
if toolCallsCount > 0 {
out, _ = sjson.Set(out, "choices.0.finish_reason", "tool_calls")
out, _ = sjson.SetBytes(out, "choices.0.finish_reason", "tool_calls")
} else {
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
out, _ = sjson.SetBytes(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
}
} else {
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
out, _ = sjson.SetBytes(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
}
return out

View File

@@ -0,0 +1,58 @@
package chat_completions
import (
"context"
"testing"
"github.com/tidwall/gjson"
)
func TestConvertClaudeResponseToOpenAI_StreamUsageIncludesCachedTokens(t *testing.T) {
ctx := context.Background()
var param any
out := ConvertClaudeResponseToOpenAI(
ctx,
"claude-opus-4-6",
nil,
nil,
[]byte(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":13,"output_tokens":4,"cache_read_input_tokens":22000,"cache_creation_input_tokens":31}}`),
&param,
)
if len(out) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(out))
}
if gotPromptTokens := gjson.GetBytes(out[0], "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
}
if gotCompletionTokens := gjson.GetBytes(out[0], "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
}
if gotTotalTokens := gjson.GetBytes(out[0], "usage.total_tokens").Int(); gotTotalTokens != 22048 {
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
}
if gotCachedTokens := gjson.GetBytes(out[0], "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
}
}
func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *testing.T) {
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\"}}\n" +
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":13,\"output_tokens\":4,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}\n")
out := ConvertClaudeResponseToOpenAINonStream(context.Background(), "", nil, nil, rawJSON, nil)
if gotPromptTokens := gjson.GetBytes(out, "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
}
if gotCompletionTokens := gjson.GetBytes(out, "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
}
if gotTotalTokens := gjson.GetBytes(out, "usage.total_tokens").Int(); gotTotalTokens != 22048 {
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
}
if gotCachedTokens := gjson.GetBytes(out, "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
}
}

View File

@@ -49,7 +49,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
// Base Claude message payload
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID)
out := []byte(fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID))
root := gjson.ParseBytes(rawJSON)
@@ -67,20 +67,20 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if supportsAdaptive {
switch effort {
case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.DeleteBytes(out, "output_config.effort")
case "auto":
out, _ = sjson.Set(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort")
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.DeleteBytes(out, "output_config.effort")
default:
if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
effort = mapped
}
out, _ = sjson.Set(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Set(out, "output_config.effort", effort)
out, _ = sjson.SetBytes(out, "thinking.type", "adaptive")
out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
out, _ = sjson.SetBytes(out, "output_config.effort", effort)
}
} else {
// Legacy/manual thinking (budget_tokens).
@@ -88,13 +88,13 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if ok {
switch budget {
case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.SetBytes(out, "thinking.type", "disabled")
case -1:
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
default:
if budget > 0 {
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
out, _ = sjson.SetBytes(out, "thinking.type", "enabled")
out, _ = sjson.SetBytes(out, "thinking.budget_tokens", budget)
}
}
}
@@ -114,15 +114,15 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
}
// Model
out, _ = sjson.Set(out, "model", modelName)
out, _ = sjson.SetBytes(out, "model", modelName)
// Max tokens
if mot := root.Get("max_output_tokens"); mot.Exists() {
out, _ = sjson.Set(out, "max_tokens", mot.Int())
out, _ = sjson.SetBytes(out, "max_tokens", mot.Int())
}
// Stream
out, _ = sjson.Set(out, "stream", stream)
out, _ = sjson.SetBytes(out, "stream", stream)
// instructions -> as a leading message (use role user for Claude API compatibility)
instructionsText := ""
@@ -130,9 +130,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if instr := root.Get("instructions"); instr.Exists() && instr.Type == gjson.String {
instructionsText = instr.String()
if instructionsText != "" {
sysMsg := `{"role":"user","content":""}`
sysMsg, _ = sjson.Set(sysMsg, "content", instructionsText)
out, _ = sjson.SetRaw(out, "messages.-1", sysMsg)
sysMsg := []byte(`{"role":"user","content":""}`)
sysMsg, _ = sjson.SetBytes(sysMsg, "content", instructionsText)
out, _ = sjson.SetRawBytes(out, "messages.-1", sysMsg)
}
}
@@ -156,9 +156,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
}
instructionsText = builder.String()
if instructionsText != "" {
sysMsg := `{"role":"user","content":""}`
sysMsg, _ = sjson.Set(sysMsg, "content", instructionsText)
out, _ = sjson.SetRaw(out, "messages.-1", sysMsg)
sysMsg := []byte(`{"role":"user","content":""}`)
sysMsg, _ = sjson.SetBytes(sysMsg, "content", instructionsText)
out, _ = sjson.SetRawBytes(out, "messages.-1", sysMsg)
extractedFromSystem = true
}
}
@@ -193,9 +193,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if t := part.Get("text"); t.Exists() {
txt := t.String()
textAggregate.WriteString(txt)
contentPart := `{"type":"text","text":""}`
contentPart, _ = sjson.Set(contentPart, "text", txt)
partsJSON = append(partsJSON, contentPart)
contentPart := []byte(`{"type":"text","text":""}`)
contentPart, _ = sjson.SetBytes(contentPart, "text", txt)
partsJSON = append(partsJSON, string(contentPart))
}
if ptype == "input_text" {
role = "user"
@@ -208,7 +208,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
url = part.Get("url").String()
}
if url != "" {
var contentPart string
var contentPart []byte
if strings.HasPrefix(url, "data:") {
trimmed := strings.TrimPrefix(url, "data:")
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
@@ -221,16 +221,16 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
data = mediaAndData[1]
}
if data != "" {
contentPart = `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.Set(contentPart, "source.data", data)
contentPart = []byte(`{"type":"image","source":{"type":"base64","media_type":"","data":""}}`)
contentPart, _ = sjson.SetBytes(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.SetBytes(contentPart, "source.data", data)
}
} else {
contentPart = `{"type":"image","source":{"type":"url","url":""}}`
contentPart, _ = sjson.Set(contentPart, "source.url", url)
contentPart = []byte(`{"type":"image","source":{"type":"url","url":""}}`)
contentPart, _ = sjson.SetBytes(contentPart, "source.url", url)
}
if contentPart != "" {
partsJSON = append(partsJSON, contentPart)
if len(contentPart) > 0 {
partsJSON = append(partsJSON, string(contentPart))
if role == "" {
role = "user"
}
@@ -252,10 +252,10 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
data = mediaAndData[1]
}
}
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.Set(contentPart, "source.data", data)
partsJSON = append(partsJSON, contentPart)
contentPart := []byte(`{"type":"document","source":{"type":"base64","media_type":"","data":""}}`)
contentPart, _ = sjson.SetBytes(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.SetBytes(contentPart, "source.data", data)
partsJSON = append(partsJSON, string(contentPart))
if role == "" {
role = "user"
}
@@ -280,24 +280,24 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
}
if len(partsJSON) > 0 {
msg := `{"role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
msg := []byte(`{"role":"","content":[]}`)
msg, _ = sjson.SetBytes(msg, "role", role)
if len(partsJSON) == 1 && !hasImage && !hasFile {
// Preserve legacy behavior for single text content
msg, _ = sjson.Delete(msg, "content")
msg, _ = sjson.DeleteBytes(msg, "content")
textPart := gjson.Parse(partsJSON[0])
msg, _ = sjson.Set(msg, "content", textPart.Get("text").String())
msg, _ = sjson.SetBytes(msg, "content", textPart.Get("text").String())
} else {
for _, partJSON := range partsJSON {
msg, _ = sjson.SetRaw(msg, "content.-1", partJSON)
msg, _ = sjson.SetRawBytes(msg, "content.-1", []byte(partJSON))
}
}
out, _ = sjson.SetRaw(out, "messages.-1", msg)
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
} else if textAggregate.Len() > 0 || role == "system" {
msg := `{"role":"","content":""}`
msg, _ = sjson.Set(msg, "role", role)
msg, _ = sjson.Set(msg, "content", textAggregate.String())
out, _ = sjson.SetRaw(out, "messages.-1", msg)
msg := []byte(`{"role":"","content":""}`)
msg, _ = sjson.SetBytes(msg, "role", role)
msg, _ = sjson.SetBytes(msg, "content", textAggregate.String())
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
}
case "function_call":
@@ -309,31 +309,31 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
name := item.Get("name").String()
argsStr := item.Get("arguments").String()
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUse, _ = sjson.Set(toolUse, "id", callID)
toolUse, _ = sjson.Set(toolUse, "name", name)
toolUse := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolUse, _ = sjson.SetBytes(toolUse, "id", callID)
toolUse, _ = sjson.SetBytes(toolUse, "name", name)
if argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
toolUse, _ = sjson.SetRawBytes(toolUse, "input", []byte(argsJSON.Raw))
}
}
asst := `{"role":"assistant","content":[]}`
asst, _ = sjson.SetRaw(asst, "content.-1", toolUse)
out, _ = sjson.SetRaw(out, "messages.-1", asst)
asst := []byte(`{"role":"assistant","content":[]}`)
asst, _ = sjson.SetRawBytes(asst, "content.-1", toolUse)
out, _ = sjson.SetRawBytes(out, "messages.-1", asst)
case "function_call_output":
// Map to user tool_result
callID := item.Get("call_id").String()
outputStr := item.Get("output").String()
toolResult := `{"type":"tool_result","tool_use_id":"","content":""}`
toolResult, _ = sjson.Set(toolResult, "tool_use_id", callID)
toolResult, _ = sjson.Set(toolResult, "content", outputStr)
toolResult := []byte(`{"type":"tool_result","tool_use_id":"","content":""}`)
toolResult, _ = sjson.SetBytes(toolResult, "tool_use_id", callID)
toolResult, _ = sjson.SetBytes(toolResult, "content", outputStr)
usr := `{"role":"user","content":[]}`
usr, _ = sjson.SetRaw(usr, "content.-1", toolResult)
out, _ = sjson.SetRaw(out, "messages.-1", usr)
usr := []byte(`{"role":"user","content":[]}`)
usr, _ = sjson.SetRawBytes(usr, "content.-1", toolResult)
out, _ = sjson.SetRawBytes(out, "messages.-1", usr)
}
return true
})
@@ -341,27 +341,27 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
// tools mapping: parameters -> input_schema
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
toolsJSON := "[]"
toolsJSON := []byte("[]")
tools.ForEach(func(_, tool gjson.Result) bool {
tJSON := `{"name":"","description":"","input_schema":{}}`
tJSON := []byte(`{"name":"","description":"","input_schema":{}}`)
if n := tool.Get("name"); n.Exists() {
tJSON, _ = sjson.Set(tJSON, "name", n.String())
tJSON, _ = sjson.SetBytes(tJSON, "name", n.String())
}
if d := tool.Get("description"); d.Exists() {
tJSON, _ = sjson.Set(tJSON, "description", d.String())
tJSON, _ = sjson.SetBytes(tJSON, "description", d.String())
}
if params := tool.Get("parameters"); params.Exists() {
tJSON, _ = sjson.SetRaw(tJSON, "input_schema", params.Raw)
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
} else if params = tool.Get("parametersJsonSchema"); params.Exists() {
tJSON, _ = sjson.SetRaw(tJSON, "input_schema", params.Raw)
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
}
toolsJSON, _ = sjson.SetRaw(toolsJSON, "-1", tJSON)
toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "-1", tJSON)
return true
})
if gjson.Parse(toolsJSON).IsArray() && len(gjson.Parse(toolsJSON).Array()) > 0 {
out, _ = sjson.SetRaw(out, "tools", toolsJSON)
if parsedTools := gjson.ParseBytes(toolsJSON); parsedTools.IsArray() && len(parsedTools.Array()) > 0 {
out, _ = sjson.SetRawBytes(out, "tools", toolsJSON)
}
}
@@ -371,23 +371,23 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
case gjson.String:
switch toolChoice.String() {
case "auto":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"auto"}`))
case "none":
// Leave unset; implies no tools
case "required":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
}
case gjson.JSON:
if toolChoice.Get("type").String() == "function" {
fn := toolChoice.Get("function.name").String()
toolChoiceJSON := `{"name":"","type":"tool"}`
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", fn)
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
toolChoiceJSON := []byte(`{"name":"","type":"tool"}`)
toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", fn)
out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
}
default:
}
}
return []byte(out)
return out
}

View File

@@ -8,6 +8,7 @@ import (
"strings"
"time"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -50,12 +51,12 @@ func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte {
return nil
}
func emitEvent(event string, payload string) string {
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
func emitEvent(event string, payload []byte) []byte {
return translatorcommon.SSEEventData(event, payload)
}
// ConvertClaudeResponseToOpenAIResponses converts Claude SSE to OpenAI Responses SSE events.
func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &claudeToResponsesState{FuncArgsBuf: make(map[int]*strings.Builder), FuncNames: make(map[int]string), FuncCallIDs: make(map[int]string)}
}
@@ -63,12 +64,12 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
// Expect `data: {..}` from Claude clients
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
return [][]byte{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
root := gjson.ParseBytes(rawJSON)
ev := root.Get("type").String()
var out []string
var out [][]byte
nextSeq := func() int { st.Seq++; return st.Seq }
@@ -105,16 +106,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
}
}
// response.created
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
created, _ = sjson.Set(created, "sequence_number", nextSeq())
created, _ = sjson.Set(created, "response.id", st.ResponseID)
created, _ = sjson.Set(created, "response.created_at", st.CreatedAt)
created := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
created, _ = sjson.SetBytes(created, "sequence_number", nextSeq())
created, _ = sjson.SetBytes(created, "response.id", st.ResponseID)
created, _ = sjson.SetBytes(created, "response.created_at", st.CreatedAt)
out = append(out, emitEvent("response.created", created))
// response.in_progress
inprog := `{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}`
inprog, _ = sjson.Set(inprog, "sequence_number", nextSeq())
inprog, _ = sjson.Set(inprog, "response.id", st.ResponseID)
inprog, _ = sjson.Set(inprog, "response.created_at", st.CreatedAt)
inprog := []byte(`{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}`)
inprog, _ = sjson.SetBytes(inprog, "sequence_number", nextSeq())
inprog, _ = sjson.SetBytes(inprog, "response.id", st.ResponseID)
inprog, _ = sjson.SetBytes(inprog, "response.created_at", st.CreatedAt)
out = append(out, emitEvent("response.in_progress", inprog))
}
case "content_block_start":
@@ -128,25 +129,25 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
// open message item + content part
st.InTextBlock = true
st.CurrentMsgID = fmt.Sprintf("msg_%s_0", st.ResponseID)
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`
item, _ = sjson.Set(item, "sequence_number", nextSeq())
item, _ = sjson.Set(item, "item.id", st.CurrentMsgID)
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
item, _ = sjson.SetBytes(item, "item.id", st.CurrentMsgID)
out = append(out, emitEvent("response.output_item.added", item))
part := `{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
part, _ = sjson.Set(part, "sequence_number", nextSeq())
part, _ = sjson.Set(part, "item_id", st.CurrentMsgID)
part := []byte(`{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
part, _ = sjson.SetBytes(part, "item_id", st.CurrentMsgID)
out = append(out, emitEvent("response.content_part.added", part))
} else if typ == "tool_use" {
st.InFuncBlock = true
st.CurrentFCID = cb.Get("id").String()
name := cb.Get("name").String()
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`
item, _ = sjson.Set(item, "sequence_number", nextSeq())
item, _ = sjson.Set(item, "output_index", idx)
item, _ = sjson.Set(item, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
item, _ = sjson.Set(item, "item.call_id", st.CurrentFCID)
item, _ = sjson.Set(item, "item.name", name)
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
item, _ = sjson.SetBytes(item, "output_index", idx)
item, _ = sjson.SetBytes(item, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
item, _ = sjson.SetBytes(item, "item.call_id", st.CurrentFCID)
item, _ = sjson.SetBytes(item, "item.name", name)
out = append(out, emitEvent("response.output_item.added", item))
if st.FuncArgsBuf[idx] == nil {
st.FuncArgsBuf[idx] = &strings.Builder{}
@@ -160,16 +161,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
st.ReasoningIndex = idx
st.ReasoningBuf.Reset()
st.ReasoningItemID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`
item, _ = sjson.Set(item, "sequence_number", nextSeq())
item, _ = sjson.Set(item, "output_index", idx)
item, _ = sjson.Set(item, "item.id", st.ReasoningItemID)
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`)
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
item, _ = sjson.SetBytes(item, "output_index", idx)
item, _ = sjson.SetBytes(item, "item.id", st.ReasoningItemID)
out = append(out, emitEvent("response.output_item.added", item))
// add a summary part placeholder
part := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
part, _ = sjson.Set(part, "sequence_number", nextSeq())
part, _ = sjson.Set(part, "item_id", st.ReasoningItemID)
part, _ = sjson.Set(part, "output_index", idx)
part := []byte(`{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
part, _ = sjson.SetBytes(part, "item_id", st.ReasoningItemID)
part, _ = sjson.SetBytes(part, "output_index", idx)
out = append(out, emitEvent("response.reasoning_summary_part.added", part))
st.ReasoningPartAdded = true
}
@@ -181,10 +182,10 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
dt := d.Get("type").String()
if dt == "text_delta" {
if t := d.Get("text"); t.Exists() {
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", st.CurrentMsgID)
msg, _ = sjson.Set(msg, "delta", t.String())
msg := []byte(`{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`)
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
msg, _ = sjson.SetBytes(msg, "item_id", st.CurrentMsgID)
msg, _ = sjson.SetBytes(msg, "delta", t.String())
out = append(out, emitEvent("response.output_text.delta", msg))
// aggregate text for response.output
st.TextBuf.WriteString(t.String())
@@ -196,22 +197,22 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
st.FuncArgsBuf[idx] = &strings.Builder{}
}
st.FuncArgsBuf[idx].WriteString(pj.String())
msg := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
msg, _ = sjson.Set(msg, "output_index", idx)
msg, _ = sjson.Set(msg, "delta", pj.String())
msg := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
msg, _ = sjson.SetBytes(msg, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
msg, _ = sjson.SetBytes(msg, "output_index", idx)
msg, _ = sjson.SetBytes(msg, "delta", pj.String())
out = append(out, emitEvent("response.function_call_arguments.delta", msg))
}
} else if dt == "thinking_delta" {
if st.ReasoningActive {
if t := d.Get("thinking"); t.Exists() {
st.ReasoningBuf.WriteString(t.String())
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID)
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
msg, _ = sjson.Set(msg, "delta", t.String())
msg := []byte(`{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`)
msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
msg, _ = sjson.SetBytes(msg, "item_id", st.ReasoningItemID)
msg, _ = sjson.SetBytes(msg, "output_index", st.ReasoningIndex)
msg, _ = sjson.SetBytes(msg, "delta", t.String())
out = append(out, emitEvent("response.reasoning_summary_text.delta", msg))
}
}
@@ -219,17 +220,17 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
case "content_block_stop":
idx := int(root.Get("index").Int())
if st.InTextBlock {
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
done, _ = sjson.Set(done, "sequence_number", nextSeq())
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
done, _ = sjson.SetBytes(done, "item_id", st.CurrentMsgID)
out = append(out, emitEvent("response.output_text.done", done))
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.SetBytes(partDone, "item_id", st.CurrentMsgID)
out = append(out, emitEvent("response.content_part.done", partDone))
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
final, _ = sjson.Set(final, "sequence_number", nextSeq())
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
final := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`)
final, _ = sjson.SetBytes(final, "sequence_number", nextSeq())
final, _ = sjson.SetBytes(final, "item.id", st.CurrentMsgID)
out = append(out, emitEvent("response.output_item.done", final))
st.InTextBlock = false
} else if st.InFuncBlock {
@@ -239,34 +240,34 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
args = buf.String()
}
}
fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`
fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq())
fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
fcDone, _ = sjson.Set(fcDone, "output_index", idx)
fcDone, _ = sjson.Set(fcDone, "arguments", args)
fcDone := []byte(`{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`)
fcDone, _ = sjson.SetBytes(fcDone, "sequence_number", nextSeq())
fcDone, _ = sjson.SetBytes(fcDone, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
fcDone, _ = sjson.SetBytes(fcDone, "output_index", idx)
fcDone, _ = sjson.SetBytes(fcDone, "arguments", args)
out = append(out, emitEvent("response.function_call_arguments.done", fcDone))
itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.Set(itemDone, "output_index", idx)
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
itemDone, _ = sjson.Set(itemDone, "item.arguments", args)
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.CurrentFCID)
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`)
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.SetBytes(itemDone, "output_index", idx)
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
itemDone, _ = sjson.SetBytes(itemDone, "item.arguments", args)
itemDone, _ = sjson.SetBytes(itemDone, "item.call_id", st.CurrentFCID)
itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[idx])
out = append(out, emitEvent("response.output_item.done", itemDone))
st.InFuncBlock = false
} else if st.ReasoningActive {
full := st.ReasoningBuf.String()
textDone := `{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
textDone, _ = sjson.Set(textDone, "sequence_number", nextSeq())
textDone, _ = sjson.Set(textDone, "item_id", st.ReasoningItemID)
textDone, _ = sjson.Set(textDone, "output_index", st.ReasoningIndex)
textDone, _ = sjson.Set(textDone, "text", full)
textDone := []byte(`{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`)
textDone, _ = sjson.SetBytes(textDone, "sequence_number", nextSeq())
textDone, _ = sjson.SetBytes(textDone, "item_id", st.ReasoningItemID)
textDone, _ = sjson.SetBytes(textDone, "output_index", st.ReasoningIndex)
textDone, _ = sjson.SetBytes(textDone, "text", full)
out = append(out, emitEvent("response.reasoning_summary_text.done", textDone))
partDone := `{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", st.ReasoningItemID)
partDone, _ = sjson.Set(partDone, "output_index", st.ReasoningIndex)
partDone, _ = sjson.Set(partDone, "part.text", full)
partDone := []byte(`{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.SetBytes(partDone, "item_id", st.ReasoningItemID)
partDone, _ = sjson.SetBytes(partDone, "output_index", st.ReasoningIndex)
partDone, _ = sjson.SetBytes(partDone, "part.text", full)
out = append(out, emitEvent("response.reasoning_summary_part.done", partDone))
st.ReasoningActive = false
st.ReasoningPartAdded = false
@@ -284,92 +285,92 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
}
case "message_stop":
completed := `{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`
completed, _ = sjson.Set(completed, "sequence_number", nextSeq())
completed, _ = sjson.Set(completed, "response.id", st.ResponseID)
completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt)
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
completed, _ = sjson.SetBytes(completed, "response.created_at", st.CreatedAt)
// Inject original request fields into response as per docs/response.completed.json
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
if len(reqBytes) > 0 {
req := gjson.ParseBytes(reqBytes)
if v := req.Get("instructions"); v.Exists() {
completed, _ = sjson.Set(completed, "response.instructions", v.String())
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
}
if v := req.Get("max_output_tokens"); v.Exists() {
completed, _ = sjson.Set(completed, "response.max_output_tokens", v.Int())
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
}
if v := req.Get("max_tool_calls"); v.Exists() {
completed, _ = sjson.Set(completed, "response.max_tool_calls", v.Int())
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
}
if v := req.Get("model"); v.Exists() {
completed, _ = sjson.Set(completed, "response.model", v.String())
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
}
if v := req.Get("parallel_tool_calls"); v.Exists() {
completed, _ = sjson.Set(completed, "response.parallel_tool_calls", v.Bool())
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
}
if v := req.Get("previous_response_id"); v.Exists() {
completed, _ = sjson.Set(completed, "response.previous_response_id", v.String())
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
}
if v := req.Get("prompt_cache_key"); v.Exists() {
completed, _ = sjson.Set(completed, "response.prompt_cache_key", v.String())
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
}
if v := req.Get("reasoning"); v.Exists() {
completed, _ = sjson.Set(completed, "response.reasoning", v.Value())
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
}
if v := req.Get("safety_identifier"); v.Exists() {
completed, _ = sjson.Set(completed, "response.safety_identifier", v.String())
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
}
if v := req.Get("service_tier"); v.Exists() {
completed, _ = sjson.Set(completed, "response.service_tier", v.String())
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
}
if v := req.Get("store"); v.Exists() {
completed, _ = sjson.Set(completed, "response.store", v.Bool())
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
}
if v := req.Get("temperature"); v.Exists() {
completed, _ = sjson.Set(completed, "response.temperature", v.Float())
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
}
if v := req.Get("text"); v.Exists() {
completed, _ = sjson.Set(completed, "response.text", v.Value())
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
}
if v := req.Get("tool_choice"); v.Exists() {
completed, _ = sjson.Set(completed, "response.tool_choice", v.Value())
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
}
if v := req.Get("tools"); v.Exists() {
completed, _ = sjson.Set(completed, "response.tools", v.Value())
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
}
if v := req.Get("top_logprobs"); v.Exists() {
completed, _ = sjson.Set(completed, "response.top_logprobs", v.Int())
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
}
if v := req.Get("top_p"); v.Exists() {
completed, _ = sjson.Set(completed, "response.top_p", v.Float())
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
}
if v := req.Get("truncation"); v.Exists() {
completed, _ = sjson.Set(completed, "response.truncation", v.String())
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
}
if v := req.Get("user"); v.Exists() {
completed, _ = sjson.Set(completed, "response.user", v.Value())
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
}
if v := req.Get("metadata"); v.Exists() {
completed, _ = sjson.Set(completed, "response.metadata", v.Value())
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
}
}
// Build response.output from aggregated state
outputsWrapper := `{"arr":[]}`
outputsWrapper := []byte(`{"arr":[]}`)
// reasoning item (if any)
if st.ReasoningBuf.Len() > 0 || st.ReasoningPartAdded {
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
item, _ = sjson.SetBytes(item, "id", st.ReasoningItemID)
item, _ = sjson.SetBytes(item, "summary.0.text", st.ReasoningBuf.String())
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
// assistant message item (if any text)
if st.TextBuf.Len() > 0 || st.InTextBlock || st.CurrentMsgID != "" {
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
item, _ = sjson.SetBytes(item, "id", st.CurrentMsgID)
item, _ = sjson.SetBytes(item, "content.0.text", st.TextBuf.String())
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
// function_call items (in ascending index order for determinism)
if len(st.FuncArgsBuf) > 0 {
@@ -396,16 +397,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
if callID == "" && st.CurrentFCID != "" {
callID = st.CurrentFCID
}
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.Set(item, "arguments", args)
item, _ = sjson.Set(item, "call_id", callID)
item, _ = sjson.Set(item, "name", name)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.SetBytes(item, "arguments", args)
item, _ = sjson.SetBytes(item, "call_id", callID)
item, _ = sjson.SetBytes(item, "name", name)
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
}
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
}
reasoningTokens := int64(0)
@@ -414,15 +415,15 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
}
usagePresent := st.UsageSeen || reasoningTokens > 0
if usagePresent {
completed, _ = sjson.Set(completed, "response.usage.input_tokens", st.InputTokens)
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", 0)
completed, _ = sjson.Set(completed, "response.usage.output_tokens", st.OutputTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.InputTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", 0)
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.OutputTokens)
if reasoningTokens > 0 {
completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", reasoningTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", reasoningTokens)
}
total := st.InputTokens + st.OutputTokens
if total > 0 || st.UsageSeen {
completed, _ = sjson.Set(completed, "response.usage.total_tokens", total)
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
}
}
out = append(out, emitEvent("response.completed", completed))
@@ -432,7 +433,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
}
// ConvertClaudeResponseToOpenAIResponsesNonStream aggregates Claude SSE into a single OpenAI Responses JSON.
func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
// Aggregate Claude SSE lines into a single OpenAI Responses JSON (non-stream)
// We follow the same aggregation logic as the streaming variant but produce
// one final object matching docs/out.json structure.
@@ -455,7 +456,7 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
}
// Base OpenAI Responses (non-stream) object
out := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null,"output":[],"usage":{"input_tokens":0,"input_tokens_details":{"cached_tokens":0},"output_tokens":0,"output_tokens_details":{},"total_tokens":0}}`
out := []byte(`{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null,"output":[],"usage":{"input_tokens":0,"input_tokens_details":{"cached_tokens":0},"output_tokens":0,"output_tokens_details":{},"total_tokens":0}}`)
// Aggregation state
var (
@@ -557,88 +558,88 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
}
// Populate base fields
out, _ = sjson.Set(out, "id", responseID)
out, _ = sjson.Set(out, "created_at", createdAt)
out, _ = sjson.SetBytes(out, "id", responseID)
out, _ = sjson.SetBytes(out, "created_at", createdAt)
// Inject request echo fields as top-level (similar to streaming variant)
reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
if len(reqBytes) > 0 {
req := gjson.ParseBytes(reqBytes)
if v := req.Get("instructions"); v.Exists() {
out, _ = sjson.Set(out, "instructions", v.String())
out, _ = sjson.SetBytes(out, "instructions", v.String())
}
if v := req.Get("max_output_tokens"); v.Exists() {
out, _ = sjson.Set(out, "max_output_tokens", v.Int())
out, _ = sjson.SetBytes(out, "max_output_tokens", v.Int())
}
if v := req.Get("max_tool_calls"); v.Exists() {
out, _ = sjson.Set(out, "max_tool_calls", v.Int())
out, _ = sjson.SetBytes(out, "max_tool_calls", v.Int())
}
if v := req.Get("model"); v.Exists() {
out, _ = sjson.Set(out, "model", v.String())
out, _ = sjson.SetBytes(out, "model", v.String())
}
if v := req.Get("parallel_tool_calls"); v.Exists() {
out, _ = sjson.Set(out, "parallel_tool_calls", v.Bool())
out, _ = sjson.SetBytes(out, "parallel_tool_calls", v.Bool())
}
if v := req.Get("previous_response_id"); v.Exists() {
out, _ = sjson.Set(out, "previous_response_id", v.String())
out, _ = sjson.SetBytes(out, "previous_response_id", v.String())
}
if v := req.Get("prompt_cache_key"); v.Exists() {
out, _ = sjson.Set(out, "prompt_cache_key", v.String())
out, _ = sjson.SetBytes(out, "prompt_cache_key", v.String())
}
if v := req.Get("reasoning"); v.Exists() {
out, _ = sjson.Set(out, "reasoning", v.Value())
out, _ = sjson.SetBytes(out, "reasoning", v.Value())
}
if v := req.Get("safety_identifier"); v.Exists() {
out, _ = sjson.Set(out, "safety_identifier", v.String())
out, _ = sjson.SetBytes(out, "safety_identifier", v.String())
}
if v := req.Get("service_tier"); v.Exists() {
out, _ = sjson.Set(out, "service_tier", v.String())
out, _ = sjson.SetBytes(out, "service_tier", v.String())
}
if v := req.Get("store"); v.Exists() {
out, _ = sjson.Set(out, "store", v.Bool())
out, _ = sjson.SetBytes(out, "store", v.Bool())
}
if v := req.Get("temperature"); v.Exists() {
out, _ = sjson.Set(out, "temperature", v.Float())
out, _ = sjson.SetBytes(out, "temperature", v.Float())
}
if v := req.Get("text"); v.Exists() {
out, _ = sjson.Set(out, "text", v.Value())
out, _ = sjson.SetBytes(out, "text", v.Value())
}
if v := req.Get("tool_choice"); v.Exists() {
out, _ = sjson.Set(out, "tool_choice", v.Value())
out, _ = sjson.SetBytes(out, "tool_choice", v.Value())
}
if v := req.Get("tools"); v.Exists() {
out, _ = sjson.Set(out, "tools", v.Value())
out, _ = sjson.SetBytes(out, "tools", v.Value())
}
if v := req.Get("top_logprobs"); v.Exists() {
out, _ = sjson.Set(out, "top_logprobs", v.Int())
out, _ = sjson.SetBytes(out, "top_logprobs", v.Int())
}
if v := req.Get("top_p"); v.Exists() {
out, _ = sjson.Set(out, "top_p", v.Float())
out, _ = sjson.SetBytes(out, "top_p", v.Float())
}
if v := req.Get("truncation"); v.Exists() {
out, _ = sjson.Set(out, "truncation", v.String())
out, _ = sjson.SetBytes(out, "truncation", v.String())
}
if v := req.Get("user"); v.Exists() {
out, _ = sjson.Set(out, "user", v.Value())
out, _ = sjson.SetBytes(out, "user", v.Value())
}
if v := req.Get("metadata"); v.Exists() {
out, _ = sjson.Set(out, "metadata", v.Value())
out, _ = sjson.SetBytes(out, "metadata", v.Value())
}
}
// Build output array
outputsWrapper := `{"arr":[]}`
outputsWrapper := []byte(`{"arr":[]}`)
if reasoningBuf.Len() > 0 {
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
item, _ = sjson.Set(item, "id", reasoningItemID)
item, _ = sjson.Set(item, "summary.0.text", reasoningBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
item, _ = sjson.SetBytes(item, "id", reasoningItemID)
item, _ = sjson.SetBytes(item, "summary.0.text", reasoningBuf.String())
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
if currentMsgID != "" || textBuf.Len() > 0 {
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
item, _ = sjson.Set(item, "id", currentMsgID)
item, _ = sjson.Set(item, "content.0.text", textBuf.String())
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
item, _ = sjson.SetBytes(item, "id", currentMsgID)
item, _ = sjson.SetBytes(item, "content.0.text", textBuf.String())
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
if len(toolCalls) > 0 {
// Preserve index order
@@ -659,28 +660,28 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
if args == "" {
args = "{}"
}
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.id))
item, _ = sjson.Set(item, "arguments", args)
item, _ = sjson.Set(item, "call_id", st.id)
item, _ = sjson.Set(item, "name", st.name)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", st.id))
item, _ = sjson.SetBytes(item, "arguments", args)
item, _ = sjson.SetBytes(item, "call_id", st.id)
item, _ = sjson.SetBytes(item, "name", st.name)
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
}
}
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
out, _ = sjson.SetRaw(out, "output", gjson.Get(outputsWrapper, "arr").Raw)
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
out, _ = sjson.SetRawBytes(out, "output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
}
// Usage
total := inputTokens + outputTokens
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", total)
out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens)
out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens)
out, _ = sjson.SetBytes(out, "usage.total_tokens", total)
if reasoningBuf.Len() > 0 {
// Rough estimate similar to chat completions
reasoningTokens := int64(len(reasoningBuf.String()) / 4)
if reasoningTokens > 0 {
out, _ = sjson.Set(out, "usage.output_tokens_details.reasoning_tokens", reasoningTokens)
out, _ = sjson.SetBytes(out, "usage.output_tokens_details.reasoning_tokens", reasoningTokens)
}
}

View File

@@ -36,15 +36,15 @@ import (
func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON
template := `{"model":"","instructions":"","input":[]}`
template := []byte(`{"model":"","instructions":"","input":[]}`)
rootResult := gjson.ParseBytes(rawJSON)
template, _ = sjson.Set(template, "model", modelName)
template, _ = sjson.SetBytes(template, "model", modelName)
// Process system messages and convert them to input content format.
systemsResult := rootResult.Get("system")
if systemsResult.Exists() {
message := `{"type":"message","role":"developer","content":[]}`
message := []byte(`{"type":"message","role":"developer","content":[]}`)
contentIndex := 0
appendSystemText := func(text string) {
@@ -52,8 +52,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
return
}
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.text", contentIndex), text)
contentIndex++
}
@@ -70,7 +70,7 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
}
if contentIndex > 0 {
template, _ = sjson.SetRaw(template, "input.-1", message)
template, _ = sjson.SetRawBytes(template, "input.-1", message)
}
}
@@ -83,9 +83,9 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
messageResult := messageResults[i]
messageRole := messageResult.Get("role").String()
newMessage := func() string {
msg := `{"type": "message","role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", messageRole)
newMessage := func() []byte {
msg := []byte(`{"type":"message","role":"","content":[]}`)
msg, _ = sjson.SetBytes(msg, "role", messageRole)
return msg
}
@@ -95,7 +95,7 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
flushMessage := func() {
if hasContent {
template, _ = sjson.SetRaw(template, "input.-1", message)
template, _ = sjson.SetRawBytes(template, "input.-1", message)
message = newMessage()
contentIndex = 0
hasContent = false
@@ -107,15 +107,15 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
if messageRole == "assistant" {
partType = "output_text"
}
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), partType)
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), partType)
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.text", contentIndex), text)
contentIndex++
hasContent = true
}
appendImageContent := func(dataURL string) {
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_image")
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.image_url", contentIndex), dataURL)
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.type", contentIndex), "input_image")
message, _ = sjson.SetBytes(message, fmt.Sprintf("content.%d.image_url", contentIndex), dataURL)
contentIndex++
hasContent = true
}
@@ -151,8 +151,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
}
case "tool_use":
flushMessage()
functionCallMessage := `{"type":"function_call"}`
functionCallMessage, _ = sjson.Set(functionCallMessage, "call_id", messageContentResult.Get("id").String())
functionCallMessage := []byte(`{"type":"function_call"}`)
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "call_id", messageContentResult.Get("id").String())
{
name := messageContentResult.Get("name").String()
toolMap := buildReverseMapFromClaudeOriginalToShort(rawJSON)
@@ -161,19 +161,19 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else {
name = shortenNameIfNeeded(name)
}
functionCallMessage, _ = sjson.Set(functionCallMessage, "name", name)
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "name", name)
}
functionCallMessage, _ = sjson.Set(functionCallMessage, "arguments", messageContentResult.Get("input").Raw)
template, _ = sjson.SetRaw(template, "input.-1", functionCallMessage)
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "arguments", messageContentResult.Get("input").Raw)
template, _ = sjson.SetRawBytes(template, "input.-1", functionCallMessage)
case "tool_result":
flushMessage()
functionCallOutputMessage := `{"type":"function_call_output"}`
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
functionCallOutputMessage := []byte(`{"type":"function_call_output"}`)
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
contentResult := messageContentResult.Get("content")
if contentResult.IsArray() {
toolResultContentIndex := 0
toolResultContent := `[]`
toolResultContent := []byte(`[]`)
contentResults := contentResult.Array()
for k := 0; k < len(contentResults); k++ {
toolResultContentType := contentResults[k].Get("type").String()
@@ -194,27 +194,27 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
}
dataURL := fmt.Sprintf("data:%s;base64,%s", mediaType, data)
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_image")
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.image_url", toolResultContentIndex), dataURL)
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_image")
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.image_url", toolResultContentIndex), dataURL)
toolResultContentIndex++
}
}
} else if toolResultContentType == "text" {
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_text")
toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.text", toolResultContentIndex), contentResults[k].Get("text").String())
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_text")
toolResultContent, _ = sjson.SetBytes(toolResultContent, fmt.Sprintf("%d.text", toolResultContentIndex), contentResults[k].Get("text").String())
toolResultContentIndex++
}
}
if toolResultContent != `[]` {
functionCallOutputMessage, _ = sjson.SetRaw(functionCallOutputMessage, "output", toolResultContent)
if toolResultContentIndex > 0 {
functionCallOutputMessage, _ = sjson.SetRawBytes(functionCallOutputMessage, "output", toolResultContent)
} else {
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
}
} else {
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
}
template, _ = sjson.SetRaw(template, "input.-1", functionCallOutputMessage)
template, _ = sjson.SetRawBytes(template, "input.-1", functionCallOutputMessage)
}
}
flushMessage()
@@ -229,8 +229,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
// Convert tools declarations to the expected format for the Codex API.
toolsResult := rootResult.Get("tools")
if toolsResult.IsArray() {
template, _ = sjson.SetRaw(template, "tools", `[]`)
template, _ = sjson.Set(template, "tool_choice", `auto`)
template, _ = sjson.SetRawBytes(template, "tools", []byte(`[]`))
template, _ = sjson.SetBytes(template, "tool_choice", `auto`)
toolResults := toolsResult.Array()
// Build short name map from declared tools
var names []string
@@ -246,11 +246,11 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
// Special handling: map Claude web search tool to Codex web_search
if toolResult.Get("type").String() == "web_search_20250305" {
// Replace the tool content entirely with {"type":"web_search"}
template, _ = sjson.SetRaw(template, "tools.-1", `{"type":"web_search"}`)
template, _ = sjson.SetRawBytes(template, "tools.-1", []byte(`{"type":"web_search"}`))
continue
}
tool := toolResult.Raw
tool, _ = sjson.Set(tool, "type", "function")
tool := []byte(toolResult.Raw)
tool, _ = sjson.SetBytes(tool, "type", "function")
// Apply shortened name if needed
if v := toolResult.Get("name"); v.Exists() {
name := v.String()
@@ -259,20 +259,26 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else {
name = shortenNameIfNeeded(name)
}
tool, _ = sjson.Set(tool, "name", name)
tool, _ = sjson.SetBytes(tool, "name", name)
}
tool, _ = sjson.SetRaw(tool, "parameters", normalizeToolParameters(toolResult.Get("input_schema").Raw))
tool, _ = sjson.Delete(tool, "input_schema")
tool, _ = sjson.Delete(tool, "parameters.$schema")
tool, _ = sjson.Delete(tool, "cache_control")
tool, _ = sjson.Delete(tool, "defer_loading")
tool, _ = sjson.Set(tool, "strict", false)
template, _ = sjson.SetRaw(template, "tools.-1", tool)
tool, _ = sjson.SetRawBytes(tool, "parameters", []byte(normalizeToolParameters(toolResult.Get("input_schema").Raw)))
tool, _ = sjson.DeleteBytes(tool, "input_schema")
tool, _ = sjson.DeleteBytes(tool, "parameters.$schema")
tool, _ = sjson.DeleteBytes(tool, "cache_control")
tool, _ = sjson.DeleteBytes(tool, "defer_loading")
tool, _ = sjson.SetBytes(tool, "strict", false)
template, _ = sjson.SetRawBytes(template, "tools.-1", tool)
}
}
// Default to parallel tool calls unless tool_choice explicitly disables them.
parallelToolCalls := true
if disableParallelToolUse := rootResult.Get("tool_choice.disable_parallel_tool_use"); disableParallelToolUse.Exists() {
parallelToolCalls = !disableParallelToolUse.Bool()
}
// Add additional configuration parameters for the Codex API.
template, _ = sjson.Set(template, "parallel_tool_calls", true)
template, _ = sjson.SetBytes(template, "parallel_tool_calls", parallelToolCalls)
// Convert thinking.budget_tokens to reasoning.effort.
reasoningEffort := "medium"
@@ -303,13 +309,13 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
}
}
}
template, _ = sjson.Set(template, "reasoning.effort", reasoningEffort)
template, _ = sjson.Set(template, "reasoning.summary", "auto")
template, _ = sjson.Set(template, "stream", true)
template, _ = sjson.Set(template, "store", false)
template, _ = sjson.Set(template, "include", []string{"reasoning.encrypted_content"})
template, _ = sjson.SetBytes(template, "reasoning.effort", reasoningEffort)
template, _ = sjson.SetBytes(template, "reasoning.summary", "auto")
template, _ = sjson.SetBytes(template, "stream", true)
template, _ = sjson.SetBytes(template, "store", false)
template, _ = sjson.SetBytes(template, "include", []string{"reasoning.encrypted_content"})
return []byte(template)
return template
}
// shortenNameIfNeeded applies a simple shortening rule for a single name.
@@ -412,15 +418,15 @@ func normalizeToolParameters(raw string) string {
if raw == "" || raw == "null" || !gjson.Valid(raw) {
return `{"type":"object","properties":{}}`
}
schema := raw
result := gjson.Parse(raw)
schema := []byte(raw)
schemaType := result.Get("type").String()
if schemaType == "" {
schema, _ = sjson.Set(schema, "type", "object")
schema, _ = sjson.SetBytes(schema, "type", "object")
schemaType = "object"
}
if schemaType == "object" && !result.Get("properties").Exists() {
schema, _ = sjson.SetRaw(schema, "properties", `{}`)
schema, _ = sjson.SetRawBytes(schema, "properties", []byte(`{}`))
}
return schema
return string(schema)
}

View File

@@ -87,3 +87,49 @@ func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) {
})
}
}
func TestConvertClaudeRequestToCodex_ParallelToolCalls(t *testing.T) {
tests := []struct {
name string
inputJSON string
wantParallelToolCalls bool
}{
{
name: "Default to true when tool_choice.disable_parallel_tool_use is absent",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantParallelToolCalls: true,
},
{
name: "Disable parallel tool calls when client opts out",
inputJSON: `{
"model": "claude-3-opus",
"tool_choice": {"disable_parallel_tool_use": true},
"messages": [{"role": "user", "content": "hello"}]
}`,
wantParallelToolCalls: false,
},
{
name: "Keep parallel tool calls enabled when client explicitly allows them",
inputJSON: `{
"model": "claude-3-opus",
"tool_choice": {"disable_parallel_tool_use": false},
"messages": [{"role": "user", "content": "hello"}]
}`,
wantParallelToolCalls: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false)
resultJSON := gjson.ParseBytes(result)
if got := resultJSON.Get("parallel_tool_calls").Bool(); got != tt.wantParallelToolCalls {
t.Fatalf("parallel_tool_calls = %v, want %v. Output: %s", got, tt.wantParallelToolCalls, string(result))
}
})
}
}

View File

@@ -9,9 +9,9 @@ package claude
import (
"bytes"
"context"
"fmt"
"strings"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -43,8 +43,8 @@ type ConvertCodexResponseToClaudeParams struct {
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of Claude Code-compatible JSON responses
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &ConvertCodexResponseToClaudeParams{
HasToolCall: false,
@@ -54,95 +54,85 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
// log.Debugf("rawJSON: %s", string(rawJSON))
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
return [][]byte{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
output := ""
output := make([]byte, 0, 512)
rootResult := gjson.ParseBytes(rawJSON)
typeResult := rootResult.Get("type")
typeStr := typeResult.String()
template := ""
var template []byte
if typeStr == "response.created" {
template = `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`
template, _ = sjson.Set(template, "message.model", rootResult.Get("response.model").String())
template, _ = sjson.Set(template, "message.id", rootResult.Get("response.id").String())
template = []byte(`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`)
template, _ = sjson.SetBytes(template, "message.model", rootResult.Get("response.model").String())
template, _ = sjson.SetBytes(template, "message.id", rootResult.Get("response.id").String())
output = "event: message_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "message_start", template, 2)
} else if typeStr == "response.reasoning_summary_part.added" {
template = `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output = "event: content_block_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
} else if typeStr == "response.reasoning_summary_text.delta" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.thinking", rootResult.Get("delta").String())
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "delta.thinking", rootResult.Get("delta").String())
output = "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.reasoning_summary_part.done" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
} else if typeStr == "response.content_part.added" {
template = `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output = "event: content_block_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
} else if typeStr == "response.output_text.delta" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.text", rootResult.Get("delta").String())
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "delta.text", rootResult.Get("delta").String())
output = "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.content_part.done" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
} else if typeStr == "response.completed" {
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
stopReason := rootResult.Get("response.stop_reason").String()
if p {
template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
template, _ = sjson.SetBytes(template, "delta.stop_reason", "tool_use")
} else if stopReason == "max_tokens" || stopReason == "stop" {
template, _ = sjson.Set(template, "delta.stop_reason", stopReason)
template, _ = sjson.SetBytes(template, "delta.stop_reason", stopReason)
} else {
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
template, _ = sjson.SetBytes(template, "delta.stop_reason", "end_turn")
}
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(rootResult.Get("response.usage"))
template, _ = sjson.Set(template, "usage.input_tokens", inputTokens)
template, _ = sjson.Set(template, "usage.output_tokens", outputTokens)
template, _ = sjson.SetBytes(template, "usage.input_tokens", inputTokens)
template, _ = sjson.SetBytes(template, "usage.output_tokens", outputTokens)
if cachedTokens > 0 {
template, _ = sjson.Set(template, "usage.cache_read_input_tokens", cachedTokens)
template, _ = sjson.SetBytes(template, "usage.cache_read_input_tokens", cachedTokens)
}
output = "event: message_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
output += "event: message_stop\n"
output += `data: {"type":"message_stop"}`
output += "\n\n"
output = translatorcommon.AppendSSEEventBytes(output, "message_delta", template, 2)
output = translatorcommon.AppendSSEEventBytes(output, "message_stop", []byte(`{"type":"message_stop"}`), 2)
} else if typeStr == "response.output_item.added" {
itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
{
// Restore original tool name if shortened
name := itemResult.Get("name").String()
@@ -150,37 +140,33 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
if orig, ok := rev[name]; ok {
name = orig
}
template, _ = sjson.Set(template, "content_block.name", name)
template, _ = sjson.SetBytes(template, "content_block.name", name)
}
output = "event: content_block_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
}
} else if typeStr == "response.output_item.done" {
itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
}
} else if typeStr == "response.function_call_arguments.delta" {
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "delta.partial_json", rootResult.Get("delta").String())
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.function_call_arguments.done" {
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
// in a single "done" event without preceding "delta" events.
@@ -189,17 +175,16 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
// When delta events were already received, skip to avoid duplicating arguments.
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
if args := rootResult.Get("arguments").String(); args != "" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", args)
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "delta.partial_json", args)
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
}
}
}
return []string{output}
return [][]byte{output}
}
// ConvertCodexResponseToClaudeNonStream converts a non-streaming Codex response to a non-streaming Claude Code response.
@@ -214,28 +199,28 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: A Claude Code-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) string {
// - []byte: A Claude Code-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) []byte {
revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
rootResult := gjson.ParseBytes(rawJSON)
if rootResult.Get("type").String() != "response.completed" {
return ""
return []byte{}
}
responseData := rootResult.Get("response")
if !responseData.Exists() {
return ""
return []byte{}
}
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", responseData.Get("id").String())
out, _ = sjson.Set(out, "model", responseData.Get("model").String())
out := []byte(`{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`)
out, _ = sjson.SetBytes(out, "id", responseData.Get("id").String())
out, _ = sjson.SetBytes(out, "model", responseData.Get("model").String())
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(responseData.Get("usage"))
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
out, _ = sjson.SetBytes(out, "usage.input_tokens", inputTokens)
out, _ = sjson.SetBytes(out, "usage.output_tokens", outputTokens)
if cachedTokens > 0 {
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
out, _ = sjson.SetBytes(out, "usage.cache_read_input_tokens", cachedTokens)
}
hasToolCall := false
@@ -276,9 +261,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
}
}
if thinkingBuilder.Len() > 0 {
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
block := []byte(`{"type":"thinking","thinking":""}`)
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRawBytes(out, "content.-1", block)
}
case "message":
if content := item.Get("content"); content.Exists() {
@@ -287,9 +272,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
if part.Get("type").String() == "output_text" {
text := part.Get("text").String()
if text != "" {
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", text)
out, _ = sjson.SetRaw(out, "content.-1", block)
block := []byte(`{"type":"text","text":""}`)
block, _ = sjson.SetBytes(block, "text", text)
out, _ = sjson.SetRawBytes(out, "content.-1", block)
}
}
return true
@@ -297,9 +282,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
} else {
text := content.String()
if text != "" {
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", text)
out, _ = sjson.SetRaw(out, "content.-1", block)
block := []byte(`{"type":"text","text":""}`)
block, _ = sjson.SetBytes(block, "text", text)
out, _ = sjson.SetRawBytes(out, "content.-1", block)
}
}
}
@@ -310,9 +295,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
name = original
}
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
toolBlock, _ = sjson.Set(toolBlock, "name", name)
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolBlock, _ = sjson.SetBytes(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
inputRaw := "{}"
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr)
@@ -320,23 +305,23 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
inputRaw = argsJSON.Raw
}
}
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
toolBlock, _ = sjson.SetRawBytes(toolBlock, "input", []byte(inputRaw))
out, _ = sjson.SetRawBytes(out, "content.-1", toolBlock)
}
return true
})
}
if stopReason := responseData.Get("stop_reason"); stopReason.Exists() && stopReason.String() != "" {
out, _ = sjson.Set(out, "stop_reason", stopReason.String())
out, _ = sjson.SetBytes(out, "stop_reason", stopReason.String())
} else if hasToolCall {
out, _ = sjson.Set(out, "stop_reason", "tool_use")
out, _ = sjson.SetBytes(out, "stop_reason", "tool_use")
} else {
out, _ = sjson.Set(out, "stop_reason", "end_turn")
out, _ = sjson.SetBytes(out, "stop_reason", "end_turn")
}
if stopSequence := responseData.Get("stop_sequence"); stopSequence.Exists() && stopSequence.String() != "" {
out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw)
out, _ = sjson.SetRawBytes(out, "stop_sequence", []byte(stopSequence.Raw))
}
return out
@@ -386,6 +371,6 @@ func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[strin
return rev
}
func ClaudeTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"input_tokens":%d}`, count)
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.ClaudeInputTokensJSON(count)
}

View File

@@ -6,10 +6,9 @@ package geminiCLI
import (
"context"
"fmt"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini"
"github.com/tidwall/sjson"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
)
// ConvertCodexResponseToGeminiCLI converts Codex streaming response format to Gemini CLI format.
@@ -24,14 +23,12 @@ import (
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response wrapped in a response object
func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of Gemini-compatible JSON responses wrapped in a response object
func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
outputs := ConvertCodexResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
newOutputs := make([]string, 0)
newOutputs := make([][]byte, 0, len(outputs))
for i := 0; i < len(outputs); i++ {
json := `{"response": {}}`
output, _ := sjson.SetRaw(json, "response", outputs[i])
newOutputs = append(newOutputs, output)
newOutputs = append(newOutputs, translatorcommon.WrapGeminiCLIResponse(outputs[i]))
}
return newOutputs
}
@@ -47,15 +44,12 @@ func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, orig
// - param: A pointer to a parameter object for the conversion
//
// Returns:
// - string: A Gemini-compatible JSON response wrapped in a response object
func ConvertCodexResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
// log.Debug(string(rawJSON))
strJSON := ConvertCodexResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
json := `{"response": {}}`
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
return strJSON
// - []byte: A Gemini-compatible JSON response wrapped in a response object
func ConvertCodexResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
out := ConvertCodexResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
return translatorcommon.WrapGeminiCLIResponse(out)
}
func GeminiCLITokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
func GeminiCLITokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.GeminiTokenCountJSON(count)
}

View File

@@ -38,7 +38,7 @@ import (
func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON
// Base template
out := `{"model":"","instructions":"","input":[]}`
out := []byte(`{"model":"","instructions":"","input":[]}`)
root := gjson.ParseBytes(rawJSON)
@@ -82,24 +82,24 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
}
// Model
out, _ = sjson.Set(out, "model", modelName)
out, _ = sjson.SetBytes(out, "model", modelName)
// System instruction -> as a user message with input_text parts
sysParts := root.Get("system_instruction.parts")
if sysParts.IsArray() {
msg := `{"type":"message","role":"developer","content":[]}`
msg := []byte(`{"type":"message","role":"developer","content":[]}`)
arr := sysParts.Array()
for i := 0; i < len(arr); i++ {
p := arr[i]
if t := p.Get("text"); t.Exists() {
part := `{}`
part, _ = sjson.Set(part, "type", "input_text")
part, _ = sjson.Set(part, "text", t.String())
msg, _ = sjson.SetRaw(msg, "content.-1", part)
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", "input_text")
part, _ = sjson.SetBytes(part, "text", t.String())
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
}
}
if len(gjson.Get(msg, "content").Array()) > 0 {
out, _ = sjson.SetRaw(out, "input.-1", msg)
if len(gjson.GetBytes(msg, "content").Array()) > 0 {
out, _ = sjson.SetRawBytes(out, "input.-1", msg)
}
}
@@ -123,23 +123,23 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
p := parr[j]
// text part
if t := p.Get("text"); t.Exists() {
msg := `{"type":"message","role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
msg := []byte(`{"type":"message","role":"","content":[]}`)
msg, _ = sjson.SetBytes(msg, "role", role)
partType := "input_text"
if role == "assistant" {
partType = "output_text"
}
part := `{}`
part, _ = sjson.Set(part, "type", partType)
part, _ = sjson.Set(part, "text", t.String())
msg, _ = sjson.SetRaw(msg, "content.-1", part)
out, _ = sjson.SetRaw(out, "input.-1", msg)
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", partType)
part, _ = sjson.SetBytes(part, "text", t.String())
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
out, _ = sjson.SetRawBytes(out, "input.-1", msg)
continue
}
// function call from model
if fc := p.Get("functionCall"); fc.Exists() {
fn := `{"type":"function_call"}`
fn := []byte(`{"type":"function_call"}`)
if name := fc.Get("name"); name.Exists() {
n := name.String()
if short, ok := shortMap[n]; ok {
@@ -147,31 +147,31 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else {
n = shortenNameIfNeeded(n)
}
fn, _ = sjson.Set(fn, "name", n)
fn, _ = sjson.SetBytes(fn, "name", n)
}
if args := fc.Get("args"); args.Exists() {
fn, _ = sjson.Set(fn, "arguments", args.Raw)
fn, _ = sjson.SetBytes(fn, "arguments", args.Raw)
}
// generate a paired random call_id and enqueue it so the
// corresponding functionResponse can pop the earliest id
// to preserve ordering when multiple calls are present.
id := genCallID()
fn, _ = sjson.Set(fn, "call_id", id)
fn, _ = sjson.SetBytes(fn, "call_id", id)
pendingCallIDs = append(pendingCallIDs, id)
out, _ = sjson.SetRaw(out, "input.-1", fn)
out, _ = sjson.SetRawBytes(out, "input.-1", fn)
continue
}
// function response from user
if fr := p.Get("functionResponse"); fr.Exists() {
fno := `{"type":"function_call_output"}`
fno := []byte(`{"type":"function_call_output"}`)
// Prefer a string result if present; otherwise embed the raw response as a string
if res := fr.Get("response.result"); res.Exists() {
fno, _ = sjson.Set(fno, "output", res.String())
fno, _ = sjson.SetBytes(fno, "output", res.String())
} else if resp := fr.Get("response"); resp.Exists() {
fno, _ = sjson.Set(fno, "output", resp.Raw)
fno, _ = sjson.SetBytes(fno, "output", resp.Raw)
}
// fno, _ = sjson.Set(fno, "call_id", "call_W6nRJzFXyPM2LFBbfo98qAbq")
// fno, _ = sjson.SetBytes(fno, "call_id", "call_W6nRJzFXyPM2LFBbfo98qAbq")
// attach the oldest queued call_id to pair the response
// with its call. If the queue is empty, generate a new id.
var id string
@@ -182,8 +182,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else {
id = genCallID()
}
fno, _ = sjson.Set(fno, "call_id", id)
out, _ = sjson.SetRaw(out, "input.-1", fno)
fno, _ = sjson.SetBytes(fno, "call_id", id)
out, _ = sjson.SetRawBytes(out, "input.-1", fno)
continue
}
}
@@ -193,8 +193,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
// Tools mapping: Gemini functionDeclarations -> Codex tools
tools := root.Get("tools")
if tools.IsArray() {
out, _ = sjson.SetRaw(out, "tools", `[]`)
out, _ = sjson.Set(out, "tool_choice", "auto")
out, _ = sjson.SetRawBytes(out, "tools", []byte(`[]`))
out, _ = sjson.SetBytes(out, "tool_choice", "auto")
tarr := tools.Array()
for i := 0; i < len(tarr); i++ {
td := tarr[i]
@@ -205,8 +205,8 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
farr := fns.Array()
for j := 0; j < len(farr); j++ {
fn := farr[j]
tool := `{}`
tool, _ = sjson.Set(tool, "type", "function")
tool := []byte(`{}`)
tool, _ = sjson.SetBytes(tool, "type", "function")
if v := fn.Get("name"); v.Exists() {
name := v.String()
if short, ok := shortMap[name]; ok {
@@ -214,32 +214,32 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
} else {
name = shortenNameIfNeeded(name)
}
tool, _ = sjson.Set(tool, "name", name)
tool, _ = sjson.SetBytes(tool, "name", name)
}
if v := fn.Get("description"); v.Exists() {
tool, _ = sjson.Set(tool, "description", v.String())
tool, _ = sjson.SetBytes(tool, "description", v.String())
}
if prm := fn.Get("parameters"); prm.Exists() {
// Remove optional $schema field if present
cleaned := prm.Raw
cleaned, _ = sjson.Delete(cleaned, "$schema")
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
tool, _ = sjson.SetRaw(tool, "parameters", cleaned)
cleaned := []byte(prm.Raw)
cleaned, _ = sjson.DeleteBytes(cleaned, "$schema")
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
tool, _ = sjson.SetRawBytes(tool, "parameters", cleaned)
} else if prm = fn.Get("parametersJsonSchema"); prm.Exists() {
// Remove optional $schema field if present
cleaned := prm.Raw
cleaned, _ = sjson.Delete(cleaned, "$schema")
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
tool, _ = sjson.SetRaw(tool, "parameters", cleaned)
cleaned := []byte(prm.Raw)
cleaned, _ = sjson.DeleteBytes(cleaned, "$schema")
cleaned, _ = sjson.SetBytes(cleaned, "additionalProperties", false)
tool, _ = sjson.SetRawBytes(tool, "parameters", cleaned)
}
tool, _ = sjson.Set(tool, "strict", false)
out, _ = sjson.SetRaw(out, "tools.-1", tool)
tool, _ = sjson.SetBytes(tool, "strict", false)
out, _ = sjson.SetRawBytes(out, "tools.-1", tool)
}
}
}
// Fixed flags aligning with Codex expectations
out, _ = sjson.Set(out, "parallel_tool_calls", true)
out, _ = sjson.SetBytes(out, "parallel_tool_calls", true)
// Convert Gemini thinkingConfig to Codex reasoning.effort.
// Note: Google official Python SDK sends snake_case fields (thinking_level/thinking_budget).
@@ -253,7 +253,7 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
if thinkingLevel.Exists() {
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
if effort != "" {
out, _ = sjson.Set(out, "reasoning.effort", effort)
out, _ = sjson.SetBytes(out, "reasoning.effort", effort)
effortSet = true
}
} else {
@@ -263,7 +263,7 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
}
if thinkingBudget.Exists() {
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
out, _ = sjson.Set(out, "reasoning.effort", effort)
out, _ = sjson.SetBytes(out, "reasoning.effort", effort)
effortSet = true
}
}
@@ -272,22 +272,22 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
}
if !effortSet {
// No thinking config, set default effort
out, _ = sjson.Set(out, "reasoning.effort", "medium")
out, _ = sjson.SetBytes(out, "reasoning.effort", "medium")
}
out, _ = sjson.Set(out, "reasoning.summary", "auto")
out, _ = sjson.Set(out, "stream", true)
out, _ = sjson.Set(out, "store", false)
out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"})
out, _ = sjson.SetBytes(out, "reasoning.summary", "auto")
out, _ = sjson.SetBytes(out, "stream", true)
out, _ = sjson.SetBytes(out, "store", false)
out, _ = sjson.SetBytes(out, "include", []string{"reasoning.encrypted_content"})
var pathsToLower []string
toolsResult := gjson.Get(out, "tools")
toolsResult := gjson.GetBytes(out, "tools")
util.Walk(toolsResult, "", "type", &pathsToLower)
for _, p := range pathsToLower {
fullPath := fmt.Sprintf("tools.%s", p)
out, _ = sjson.Set(out, fullPath, strings.ToLower(gjson.Get(out, fullPath).String()))
out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(gjson.GetBytes(out, fullPath).String()))
}
return []byte(out)
return out
}
// shortenNameIfNeeded applies the simple shortening rule for a single name.

View File

@@ -7,9 +7,9 @@ package gemini
import (
"bytes"
"context"
"fmt"
"time"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/common"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -23,7 +23,7 @@ type ConvertCodexResponseToGeminiParams struct {
Model string
CreatedAt int64
ResponseID string
LastStorageOutput string
LastStorageOutput []byte
}
// ConvertCodexResponseToGemini converts Codex streaming response format to Gemini format.
@@ -38,19 +38,19 @@ type ConvertCodexResponseToGeminiParams struct {
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response
func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of Gemini-compatible JSON responses
func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &ConvertCodexResponseToGeminiParams{
Model: modelName,
CreatedAt: 0,
ResponseID: "",
LastStorageOutput: "",
LastStorageOutput: nil,
}
}
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
return [][]byte{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
@@ -59,17 +59,17 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
typeStr := typeResult.String()
// Base Gemini response template
template := `{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`
if (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput != "" && typeStr == "response.output_item.done" {
template = (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`)
if len((*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput) > 0 && typeStr == "response.output_item.done" {
template = append([]byte(nil), (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput...)
} else {
template, _ = sjson.Set(template, "modelVersion", (*param).(*ConvertCodexResponseToGeminiParams).Model)
template, _ = sjson.SetBytes(template, "modelVersion", (*param).(*ConvertCodexResponseToGeminiParams).Model)
createdAtResult := rootResult.Get("response.created_at")
if createdAtResult.Exists() {
(*param).(*ConvertCodexResponseToGeminiParams).CreatedAt = createdAtResult.Int()
template, _ = sjson.Set(template, "createTime", time.Unix((*param).(*ConvertCodexResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
template, _ = sjson.SetBytes(template, "createTime", time.Unix((*param).(*ConvertCodexResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
}
template, _ = sjson.Set(template, "responseId", (*param).(*ConvertCodexResponseToGeminiParams).ResponseID)
template, _ = sjson.SetBytes(template, "responseId", (*param).(*ConvertCodexResponseToGeminiParams).ResponseID)
}
// Handle function call completion
@@ -78,7 +78,7 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
// Create function call part
functionCall := `{"functionCall":{"name":"","args":{}}}`
functionCall := []byte(`{"functionCall":{"name":"","args":{}}}`)
{
// Restore original tool name if shortened
n := itemResult.Get("name").String()
@@ -86,7 +86,7 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
if orig, ok := rev[n]; ok {
n = orig
}
functionCall, _ = sjson.Set(functionCall, "functionCall.name", n)
functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", n)
}
// Parse and set arguments
@@ -94,47 +94,48 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
if argsStr != "" {
argsResult := gjson.Parse(argsStr)
if argsResult.IsObject() {
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsStr))
}
}
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", functionCall)
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", functionCall)
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput = template
(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput = append([]byte(nil), template...)
// Use this return to storage message
return []string{}
return [][]byte{}
}
}
if typeStr == "response.created" { // Handle response creation - set model and response ID
template, _ = sjson.Set(template, "modelVersion", rootResult.Get("response.model").String())
template, _ = sjson.Set(template, "responseId", rootResult.Get("response.id").String())
template, _ = sjson.SetBytes(template, "modelVersion", rootResult.Get("response.model").String())
template, _ = sjson.SetBytes(template, "responseId", rootResult.Get("response.id").String())
(*param).(*ConvertCodexResponseToGeminiParams).ResponseID = rootResult.Get("response.id").String()
} else if typeStr == "response.reasoning_summary_text.delta" { // Handle reasoning/thinking content delta
part := `{"thought":true,"text":""}`
part, _ = sjson.Set(part, "text", rootResult.Get("delta").String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
part := []byte(`{"thought":true,"text":""}`)
part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
} else if typeStr == "response.output_text.delta" { // Handle regular text content delta
part := `{"text":""}`
part, _ = sjson.Set(part, "text", rootResult.Get("delta").String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", rootResult.Get("delta").String())
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
} else if typeStr == "response.completed" { // Handle response completion with usage metadata
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int())
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int())
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int())
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int())
totalTokens := rootResult.Get("response.usage.input_tokens").Int() + rootResult.Get("response.usage.output_tokens").Int()
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", totalTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", totalTokens)
} else {
return []string{}
return [][]byte{}
}
if (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput != "" {
return []string{(*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput, template}
} else {
return []string{template}
if len((*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput) > 0 {
return [][]byte{
append([]byte(nil), (*param).(*ConvertCodexResponseToGeminiParams).LastStorageOutput...),
template,
}
}
return [][]byte{template}
}
// ConvertCodexResponseToGeminiNonStream converts a non-streaming Codex response to a non-streaming Gemini response.
@@ -149,32 +150,32 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: A Gemini-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// - []byte: A Gemini-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
rootResult := gjson.ParseBytes(rawJSON)
// Verify this is a response.completed event
if rootResult.Get("type").String() != "response.completed" {
return ""
return []byte{}
}
// Base Gemini response template for non-streaming
template := `{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
template := []byte(`{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`)
// Set model version
template, _ = sjson.Set(template, "modelVersion", modelName)
template, _ = sjson.SetBytes(template, "modelVersion", modelName)
// Set response metadata from the completed response
responseData := rootResult.Get("response")
if responseData.Exists() {
// Set response ID
if responseId := responseData.Get("id"); responseId.Exists() {
template, _ = sjson.Set(template, "responseId", responseId.String())
template, _ = sjson.SetBytes(template, "responseId", responseId.String())
}
// Set creation time
if createdAt := responseData.Get("created_at"); createdAt.Exists() {
template, _ = sjson.Set(template, "createTime", time.Unix(createdAt.Int(), 0).Format(time.RFC3339Nano))
template, _ = sjson.SetBytes(template, "createTime", time.Unix(createdAt.Int(), 0).Format(time.RFC3339Nano))
}
// Set usage metadata
@@ -183,14 +184,14 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
outputTokens := usage.Get("output_tokens").Int()
totalTokens := inputTokens + outputTokens
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens)
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens)
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", totalTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.promptTokenCount", inputTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.candidatesTokenCount", outputTokens)
template, _ = sjson.SetBytes(template, "usageMetadata.totalTokenCount", totalTokens)
}
// Process output content to build parts array
hasToolCall := false
var pendingFunctionCalls []string
var pendingFunctionCalls [][]byte
flushPendingFunctionCalls := func() {
if len(pendingFunctionCalls) == 0 {
@@ -199,7 +200,7 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
// Add all pending function calls as individual parts
// This maintains the original Gemini API format while ensuring consecutive calls are grouped together
for _, fc := range pendingFunctionCalls {
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", fc)
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", fc)
}
pendingFunctionCalls = nil
}
@@ -215,9 +216,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
// Add thinking content
if content := value.Get("content"); content.Exists() {
part := `{"text":"","thought":true}`
part, _ = sjson.Set(part, "text", content.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
part := []byte(`{"text":"","thought":true}`)
part, _ = sjson.SetBytes(part, "text", content.String())
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
}
case "message":
@@ -229,9 +230,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
content.ForEach(func(_, contentItem gjson.Result) bool {
if contentItem.Get("type").String() == "output_text" {
if text := contentItem.Get("text"); text.Exists() {
part := `{"text":""}`
part, _ = sjson.Set(part, "text", text.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", text.String())
template, _ = sjson.SetRawBytes(template, "candidates.0.content.parts.-1", part)
}
}
return true
@@ -241,21 +242,21 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
case "function_call":
// Collect function call for potential merging with consecutive ones
hasToolCall = true
functionCall := `{"functionCall":{"args":{},"name":""}}`
functionCall := []byte(`{"functionCall":{"args":{},"name":""}}`)
{
n := value.Get("name").String()
rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON)
if orig, ok := rev[n]; ok {
n = orig
}
functionCall, _ = sjson.Set(functionCall, "functionCall.name", n)
functionCall, _ = sjson.SetBytes(functionCall, "functionCall.name", n)
}
// Parse and set arguments
if argsStr := value.Get("arguments").String(); argsStr != "" {
argsResult := gjson.Parse(argsStr)
if argsResult.IsObject() {
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
functionCall, _ = sjson.SetRawBytes(functionCall, "functionCall.args", []byte(argsStr))
}
}
@@ -270,9 +271,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
// Set finish reason based on whether there were tool calls
if hasToolCall {
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
} else {
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
template, _ = sjson.SetBytes(template, "candidates.0.finishReason", "STOP")
}
}
return template
@@ -307,6 +308,6 @@ func buildReverseMapFromGeminiOriginal(original []byte) map[string]string {
return rev
}
func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
func GeminiTokenCount(ctx context.Context, count int64) []byte {
return translatorcommon.GeminiTokenCountJSON(count)
}

View File

@@ -29,42 +29,42 @@ import (
func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
rawJSON := inputRawJSON
// Start with empty JSON object
out := `{"instructions":""}`
out := []byte(`{"instructions":""}`)
// Stream must be set to true
out, _ = sjson.Set(out, "stream", stream)
out, _ = sjson.SetBytes(out, "stream", stream)
// Codex not support temperature, top_p, top_k, max_output_tokens, so comment them
// if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() {
// out, _ = sjson.Set(out, "temperature", v.Value())
// out, _ = sjson.SetBytes(out, "temperature", v.Value())
// }
// if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() {
// out, _ = sjson.Set(out, "top_p", v.Value())
// out, _ = sjson.SetBytes(out, "top_p", v.Value())
// }
// if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() {
// out, _ = sjson.Set(out, "top_k", v.Value())
// out, _ = sjson.SetBytes(out, "top_k", v.Value())
// }
// Map token limits
// if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() {
// out, _ = sjson.Set(out, "max_output_tokens", v.Value())
// out, _ = sjson.SetBytes(out, "max_output_tokens", v.Value())
// }
// if v := gjson.GetBytes(rawJSON, "max_completion_tokens"); v.Exists() {
// out, _ = sjson.Set(out, "max_output_tokens", v.Value())
// out, _ = sjson.SetBytes(out, "max_output_tokens", v.Value())
// }
// Map reasoning effort
if v := gjson.GetBytes(rawJSON, "reasoning_effort"); v.Exists() {
out, _ = sjson.Set(out, "reasoning.effort", v.Value())
out, _ = sjson.SetBytes(out, "reasoning.effort", v.Value())
} else {
out, _ = sjson.Set(out, "reasoning.effort", "medium")
out, _ = sjson.SetBytes(out, "reasoning.effort", "medium")
}
out, _ = sjson.Set(out, "parallel_tool_calls", true)
out, _ = sjson.Set(out, "reasoning.summary", "auto")
out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"})
out, _ = sjson.SetBytes(out, "parallel_tool_calls", true)
out, _ = sjson.SetBytes(out, "reasoning.summary", "auto")
out, _ = sjson.SetBytes(out, "include", []string{"reasoning.encrypted_content"})
// Model
out, _ = sjson.Set(out, "model", modelName)
out, _ = sjson.SetBytes(out, "model", modelName)
// Build tool name shortening map from original tools (if any)
originalToolNameMap := map[string]string{}
@@ -100,9 +100,9 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
// if m.Get("role").String() == "system" {
// c := m.Get("content")
// if c.Type == gjson.String {
// out, _ = sjson.Set(out, "instructions", c.String())
// out, _ = sjson.SetBytes(out, "instructions", c.String())
// } else if c.IsObject() && c.Get("type").String() == "text" {
// out, _ = sjson.Set(out, "instructions", c.Get("text").String())
// out, _ = sjson.SetBytes(out, "instructions", c.Get("text").String())
// }
// break
// }
@@ -110,7 +110,7 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
// }
// Build input from messages, handling all message types including tool calls
out, _ = sjson.SetRaw(out, "input", `[]`)
out, _ = sjson.SetRawBytes(out, "input", []byte(`[]`))
if messages.IsArray() {
arr := messages.Array()
for i := 0; i < len(arr); i++ {
@@ -124,23 +124,23 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
content := m.Get("content").String()
// Create function_call_output object
funcOutput := `{}`
funcOutput, _ = sjson.Set(funcOutput, "type", "function_call_output")
funcOutput, _ = sjson.Set(funcOutput, "call_id", toolCallID)
funcOutput, _ = sjson.Set(funcOutput, "output", content)
out, _ = sjson.SetRaw(out, "input.-1", funcOutput)
funcOutput := []byte(`{}`)
funcOutput, _ = sjson.SetBytes(funcOutput, "type", "function_call_output")
funcOutput, _ = sjson.SetBytes(funcOutput, "call_id", toolCallID)
funcOutput, _ = sjson.SetBytes(funcOutput, "output", content)
out, _ = sjson.SetRawBytes(out, "input.-1", funcOutput)
default:
// Handle regular messages
msg := `{}`
msg, _ = sjson.Set(msg, "type", "message")
msg := []byte(`{}`)
msg, _ = sjson.SetBytes(msg, "type", "message")
if role == "system" {
msg, _ = sjson.Set(msg, "role", "developer")
msg, _ = sjson.SetBytes(msg, "role", "developer")
} else {
msg, _ = sjson.Set(msg, "role", role)
msg, _ = sjson.SetBytes(msg, "role", role)
}
msg, _ = sjson.SetRaw(msg, "content", `[]`)
msg, _ = sjson.SetRawBytes(msg, "content", []byte(`[]`))
// Handle regular content
c := m.Get("content")
@@ -150,10 +150,10 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
if role == "assistant" {
partType = "output_text"
}
part := `{}`
part, _ = sjson.Set(part, "type", partType)
part, _ = sjson.Set(part, "text", c.String())
msg, _ = sjson.SetRaw(msg, "content.-1", part)
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", partType)
part, _ = sjson.SetBytes(part, "text", c.String())
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
} else if c.Exists() && c.IsArray() {
items := c.Array()
for j := 0; j < len(items); j++ {
@@ -165,39 +165,44 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
if role == "assistant" {
partType = "output_text"
}
part := `{}`
part, _ = sjson.Set(part, "type", partType)
part, _ = sjson.Set(part, "text", it.Get("text").String())
msg, _ = sjson.SetRaw(msg, "content.-1", part)
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", partType)
part, _ = sjson.SetBytes(part, "text", it.Get("text").String())
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
case "image_url":
// Map image inputs to input_image for Responses API
if role == "user" {
part := `{}`
part, _ = sjson.Set(part, "type", "input_image")
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", "input_image")
if u := it.Get("image_url.url"); u.Exists() {
part, _ = sjson.Set(part, "image_url", u.String())
part, _ = sjson.SetBytes(part, "image_url", u.String())
}
msg, _ = sjson.SetRaw(msg, "content.-1", part)
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
}
case "file":
if role == "user" {
fileData := it.Get("file.file_data").String()
filename := it.Get("file.filename").String()
if fileData != "" {
part := `{}`
part, _ = sjson.Set(part, "type", "input_file")
part, _ = sjson.Set(part, "file_data", fileData)
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", "input_file")
part, _ = sjson.SetBytes(part, "file_data", fileData)
if filename != "" {
part, _ = sjson.Set(part, "filename", filename)
part, _ = sjson.SetBytes(part, "filename", filename)
}
msg, _ = sjson.SetRaw(msg, "content.-1", part)
msg, _ = sjson.SetRawBytes(msg, "content.-1", part)
}
}
}
}
}
out, _ = sjson.SetRaw(out, "input.-1", msg)
// Don't emit empty assistant messages when only tool_calls
// are present — Responses API needs function_call items
// directly, otherwise call_id matching fails (#2132).
if role != "assistant" || len(gjson.GetBytes(msg, "content").Array()) > 0 {
out, _ = sjson.SetRawBytes(out, "input.-1", msg)
}
// Handle tool calls for assistant messages as separate top-level objects
if role == "assistant" {
@@ -208,9 +213,9 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
tc := toolCallsArr[j]
if tc.Get("type").String() == "function" {
// Create function_call as top-level object
funcCall := `{}`
funcCall, _ = sjson.Set(funcCall, "type", "function_call")
funcCall, _ = sjson.Set(funcCall, "call_id", tc.Get("id").String())
funcCall := []byte(`{}`)
funcCall, _ = sjson.SetBytes(funcCall, "type", "function_call")
funcCall, _ = sjson.SetBytes(funcCall, "call_id", tc.Get("id").String())
{
name := tc.Get("function.name").String()
if short, ok := originalToolNameMap[name]; ok {
@@ -218,10 +223,10 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
} else {
name = shortenNameIfNeeded(name)
}
funcCall, _ = sjson.Set(funcCall, "name", name)
funcCall, _ = sjson.SetBytes(funcCall, "name", name)
}
funcCall, _ = sjson.Set(funcCall, "arguments", tc.Get("function.arguments").String())
out, _ = sjson.SetRaw(out, "input.-1", funcCall)
funcCall, _ = sjson.SetBytes(funcCall, "arguments", tc.Get("function.arguments").String())
out, _ = sjson.SetRawBytes(out, "input.-1", funcCall)
}
}
}
@@ -235,26 +240,26 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
text := gjson.GetBytes(rawJSON, "text")
if rf.Exists() {
// Always create text object when response_format provided
if !gjson.Get(out, "text").Exists() {
out, _ = sjson.SetRaw(out, "text", `{}`)
if !gjson.GetBytes(out, "text").Exists() {
out, _ = sjson.SetRawBytes(out, "text", []byte(`{}`))
}
rft := rf.Get("type").String()
switch rft {
case "text":
out, _ = sjson.Set(out, "text.format.type", "text")
out, _ = sjson.SetBytes(out, "text.format.type", "text")
case "json_schema":
js := rf.Get("json_schema")
if js.Exists() {
out, _ = sjson.Set(out, "text.format.type", "json_schema")
out, _ = sjson.SetBytes(out, "text.format.type", "json_schema")
if v := js.Get("name"); v.Exists() {
out, _ = sjson.Set(out, "text.format.name", v.Value())
out, _ = sjson.SetBytes(out, "text.format.name", v.Value())
}
if v := js.Get("strict"); v.Exists() {
out, _ = sjson.Set(out, "text.format.strict", v.Value())
out, _ = sjson.SetBytes(out, "text.format.strict", v.Value())
}
if v := js.Get("schema"); v.Exists() {
out, _ = sjson.SetRaw(out, "text.format.schema", v.Raw)
out, _ = sjson.SetRawBytes(out, "text.format.schema", []byte(v.Raw))
}
}
}
@@ -262,23 +267,23 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
// Map verbosity if provided
if text.Exists() {
if v := text.Get("verbosity"); v.Exists() {
out, _ = sjson.Set(out, "text.verbosity", v.Value())
out, _ = sjson.SetBytes(out, "text.verbosity", v.Value())
}
}
} else if text.Exists() {
// If only text.verbosity present (no response_format), map verbosity
if v := text.Get("verbosity"); v.Exists() {
if !gjson.Get(out, "text").Exists() {
out, _ = sjson.SetRaw(out, "text", `{}`)
if !gjson.GetBytes(out, "text").Exists() {
out, _ = sjson.SetRawBytes(out, "text", []byte(`{}`))
}
out, _ = sjson.Set(out, "text.verbosity", v.Value())
out, _ = sjson.SetBytes(out, "text.verbosity", v.Value())
}
}
// Map tools (flatten function fields)
tools := gjson.GetBytes(rawJSON, "tools")
if tools.IsArray() && len(tools.Array()) > 0 {
out, _ = sjson.SetRaw(out, "tools", `[]`)
out, _ = sjson.SetRawBytes(out, "tools", []byte(`[]`))
arr := tools.Array()
for i := 0; i < len(arr); i++ {
t := arr[i]
@@ -286,13 +291,13 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
// Pass through built-in tools (e.g. {"type":"web_search"}) directly for the Responses API.
// Only "function" needs structural conversion because Chat Completions nests details under "function".
if toolType != "" && toolType != "function" && t.IsObject() {
out, _ = sjson.SetRaw(out, "tools.-1", t.Raw)
out, _ = sjson.SetRawBytes(out, "tools.-1", []byte(t.Raw))
continue
}
if toolType == "function" {
item := `{}`
item, _ = sjson.Set(item, "type", "function")
item := []byte(`{}`)
item, _ = sjson.SetBytes(item, "type", "function")
fn := t.Get("function")
if fn.Exists() {
if v := fn.Get("name"); v.Exists() {
@@ -302,19 +307,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
} else {
name = shortenNameIfNeeded(name)
}
item, _ = sjson.Set(item, "name", name)
item, _ = sjson.SetBytes(item, "name", name)
}
if v := fn.Get("description"); v.Exists() {
item, _ = sjson.Set(item, "description", v.Value())
item, _ = sjson.SetBytes(item, "description", v.Value())
}
if v := fn.Get("parameters"); v.Exists() {
item, _ = sjson.SetRaw(item, "parameters", v.Raw)
item, _ = sjson.SetRawBytes(item, "parameters", []byte(v.Raw))
}
if v := fn.Get("strict"); v.Exists() {
item, _ = sjson.Set(item, "strict", v.Value())
item, _ = sjson.SetBytes(item, "strict", v.Value())
}
}
out, _ = sjson.SetRaw(out, "tools.-1", item)
out, _ = sjson.SetRawBytes(out, "tools.-1", item)
}
}
}
@@ -325,7 +330,7 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
if tc := gjson.GetBytes(rawJSON, "tool_choice"); tc.Exists() {
switch {
case tc.Type == gjson.String:
out, _ = sjson.Set(out, "tool_choice", tc.String())
out, _ = sjson.SetBytes(out, "tool_choice", tc.String())
case tc.IsObject():
tcType := tc.Get("type").String()
if tcType == "function" {
@@ -337,21 +342,21 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
name = shortenNameIfNeeded(name)
}
}
choice := `{}`
choice, _ = sjson.Set(choice, "type", "function")
choice := []byte(`{}`)
choice, _ = sjson.SetBytes(choice, "type", "function")
if name != "" {
choice, _ = sjson.Set(choice, "name", name)
choice, _ = sjson.SetBytes(choice, "name", name)
}
out, _ = sjson.SetRaw(out, "tool_choice", choice)
out, _ = sjson.SetRawBytes(out, "tool_choice", choice)
} else if tcType != "" {
// Built-in tool choices (e.g. {"type":"web_search"}) are already Responses-compatible.
out, _ = sjson.SetRaw(out, "tool_choice", tc.Raw)
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(tc.Raw))
}
}
}
out, _ = sjson.Set(out, "store", false)
return []byte(out)
out, _ = sjson.SetBytes(out, "store", false)
return out
}
// shortenNameIfNeeded applies the simple shortening rule for a single name.

View File

@@ -0,0 +1,635 @@
package chat_completions
import (
"testing"
"github.com/tidwall/gjson"
)
// Basic tool-call: system + user + assistant(tool_calls, no content) + tool result.
// Expects developer msg + user msg + function_call + function_call_output.
// No empty assistant message should appear between user and function_call.
func TestToolCallSimple(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the weather in Paris?"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"Paris\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": "sunny, 22C"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather for a city",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
if len(items) != 4 {
t.Fatalf("expected 4 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
// system -> developer
if items[0].Get("type").String() != "message" {
t.Errorf("item 0: expected type 'message', got '%s'", items[0].Get("type").String())
}
if items[0].Get("role").String() != "developer" {
t.Errorf("item 0: expected role 'developer', got '%s'", items[0].Get("role").String())
}
// user
if items[1].Get("type").String() != "message" {
t.Errorf("item 1: expected type 'message', got '%s'", items[1].Get("type").String())
}
if items[1].Get("role").String() != "user" {
t.Errorf("item 1: expected role 'user', got '%s'", items[1].Get("role").String())
}
// function_call, not an empty assistant msg
if items[2].Get("type").String() != "function_call" {
t.Errorf("item 2: expected type 'function_call', got '%s'", items[2].Get("type").String())
}
if items[2].Get("call_id").String() != "call_1" {
t.Errorf("item 2: expected call_id 'call_1', got '%s'", items[2].Get("call_id").String())
}
if items[2].Get("name").String() != "get_weather" {
t.Errorf("item 2: expected name 'get_weather', got '%s'", items[2].Get("name").String())
}
if items[2].Get("arguments").String() != `{"city":"Paris"}` {
t.Errorf("item 2: unexpected arguments: %s", items[2].Get("arguments").String())
}
// function_call_output
if items[3].Get("type").String() != "function_call_output" {
t.Errorf("item 3: expected type 'function_call_output', got '%s'", items[3].Get("type").String())
}
if items[3].Get("call_id").String() != "call_1" {
t.Errorf("item 3: expected call_id 'call_1', got '%s'", items[3].Get("call_id").String())
}
if items[3].Get("output").String() != "sunny, 22C" {
t.Errorf("item 3: expected output 'sunny, 22C', got '%s'", items[3].Get("output").String())
}
}
// Assistant has both text content and tool_calls — the message should
// be emitted (non-empty content), followed by function_call items.
func TestToolCallWithContent(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "What is the weather?"},
{
"role": "assistant",
"content": "Let me check the weather for you.",
"tool_calls": [
{
"id": "call_abc",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_abc",
"content": "rainy, 15C"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// user + assistant(with content) + function_call + function_call_output
if len(items) != 4 {
t.Fatalf("expected 4 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
if items[0].Get("role").String() != "user" {
t.Errorf("item 0: expected role 'user', got '%s'", items[0].Get("role").String())
}
// assistant with content — should be kept
if items[1].Get("type").String() != "message" {
t.Errorf("item 1: expected type 'message', got '%s'", items[1].Get("type").String())
}
if items[1].Get("role").String() != "assistant" {
t.Errorf("item 1: expected role 'assistant', got '%s'", items[1].Get("role").String())
}
contentParts := items[1].Get("content").Array()
if len(contentParts) == 0 {
t.Errorf("item 1: assistant message should have content parts")
}
if items[2].Get("type").String() != "function_call" {
t.Errorf("item 2: expected type 'function_call', got '%s'", items[2].Get("type").String())
}
if items[2].Get("call_id").String() != "call_abc" {
t.Errorf("item 2: expected call_id 'call_abc', got '%s'", items[2].Get("call_id").String())
}
if items[3].Get("type").String() != "function_call_output" {
t.Errorf("item 3: expected type 'function_call_output', got '%s'", items[3].Get("type").String())
}
if items[3].Get("call_id").String() != "call_abc" {
t.Errorf("item 3: expected call_id 'call_abc', got '%s'", items[3].Get("call_id").String())
}
}
// Parallel tool calls: assistant invokes 3 tools at once, all call_ids
// and outputs must be translated and paired correctly.
func TestMultipleToolCalls(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Compare weather in Paris, London and Tokyo"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_paris",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"Paris\"}"
}
},
{
"id": "call_london",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"London\"}"
}
},
{
"id": "call_tokyo",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"Tokyo\"}"
}
}
]
},
{"role": "tool", "tool_call_id": "call_paris", "content": "sunny, 22C"},
{"role": "tool", "tool_call_id": "call_london", "content": "cloudy, 14C"},
{"role": "tool", "tool_call_id": "call_tokyo", "content": "humid, 28C"}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// user + 3 function_call + 3 function_call_output = 7
if len(items) != 7 {
t.Fatalf("expected 7 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
if items[0].Get("role").String() != "user" {
t.Errorf("item 0: expected role 'user', got '%s'", items[0].Get("role").String())
}
expectedCallIDs := []string{"call_paris", "call_london", "call_tokyo"}
for i, expectedID := range expectedCallIDs {
idx := i + 1
if items[idx].Get("type").String() != "function_call" {
t.Errorf("item %d: expected type 'function_call', got '%s'", idx, items[idx].Get("type").String())
}
if items[idx].Get("call_id").String() != expectedID {
t.Errorf("item %d: expected call_id '%s', got '%s'", idx, expectedID, items[idx].Get("call_id").String())
}
}
expectedOutputs := []string{"sunny, 22C", "cloudy, 14C", "humid, 28C"}
for i, expectedOutput := range expectedOutputs {
idx := i + 4
if items[idx].Get("type").String() != "function_call_output" {
t.Errorf("item %d: expected type 'function_call_output', got '%s'", idx, items[idx].Get("type").String())
}
if items[idx].Get("call_id").String() != expectedCallIDs[i] {
t.Errorf("item %d: expected call_id '%s', got '%s'", idx, expectedCallIDs[i], items[idx].Get("call_id").String())
}
if items[idx].Get("output").String() != expectedOutput {
t.Errorf("item %d: expected output '%s', got '%s'", idx, expectedOutput, items[idx].Get("output").String())
}
}
}
// Regression test for #2132: tool-call-only assistant messages (content:null)
// must not produce an empty message item in the translated output.
func TestNoSpuriousEmptyAssistantMessage(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Call a tool"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_x",
"type": "function",
"function": {"name": "do_thing", "arguments": "{}"}
}
]
},
{"role": "tool", "tool_call_id": "call_x", "content": "done"}
],
"tools": [
{
"type": "function",
"function": {
"name": "do_thing",
"description": "Do a thing",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
for i, item := range items {
typ := item.Get("type").String()
role := item.Get("role").String()
if typ == "message" && role == "assistant" {
contentArr := item.Get("content").Array()
if len(contentArr) == 0 {
t.Errorf("item %d: empty assistant message breaks call_id matching. item: %s", i, item.Raw)
}
}
}
// should be exactly: user + function_call + function_call_output
if len(items) != 3 {
t.Fatalf("expected 3 input items (user + function_call + function_call_output), got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
if items[0].Get("type").String() != "message" || items[0].Get("role").String() != "user" {
t.Errorf("item 0: expected user message")
}
if items[1].Get("type").String() != "function_call" {
t.Errorf("item 1: expected function_call, got %s", items[1].Get("type").String())
}
if items[2].Get("type").String() != "function_call_output" {
t.Errorf("item 2: expected function_call_output, got %s", items[2].Get("type").String())
}
}
// Two rounds of tool calling in one conversation, with a text reply in between.
func TestMultiTurnToolCalling(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Weather in Paris?"},
{
"role": "assistant",
"content": null,
"tool_calls": [{"id": "call_r1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}}]
},
{"role": "tool", "tool_call_id": "call_r1", "content": "sunny"},
{"role": "assistant", "content": "It is sunny in Paris."},
{"role": "user", "content": "And London?"},
{
"role": "assistant",
"content": null,
"tool_calls": [{"id": "call_r2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"London\"}"}}]
},
{"role": "tool", "tool_call_id": "call_r2", "content": "rainy"}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// user, func_call(r1), func_output(r1), assistant text, user, func_call(r2), func_output(r2)
if len(items) != 7 {
t.Fatalf("expected 7 input items, got %d: %s", len(items), gjson.Get(result, "input").Raw)
}
for i, item := range items {
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
if len(item.Get("content").Array()) == 0 {
t.Errorf("item %d: unexpected empty assistant message", i)
}
}
}
// round 1
if items[1].Get("type").String() != "function_call" {
t.Errorf("item 1: expected function_call, got %s", items[1].Get("type").String())
}
if items[1].Get("call_id").String() != "call_r1" {
t.Errorf("item 1: expected call_id 'call_r1', got '%s'", items[1].Get("call_id").String())
}
if items[2].Get("type").String() != "function_call_output" {
t.Errorf("item 2: expected function_call_output, got %s", items[2].Get("type").String())
}
// text reply between rounds
if items[3].Get("type").String() != "message" || items[3].Get("role").String() != "assistant" {
t.Errorf("item 3: expected assistant message, got type=%s role=%s", items[3].Get("type").String(), items[3].Get("role").String())
}
// round 2
if items[5].Get("type").String() != "function_call" {
t.Errorf("item 5: expected function_call, got %s", items[5].Get("type").String())
}
if items[5].Get("call_id").String() != "call_r2" {
t.Errorf("item 5: expected call_id 'call_r2', got '%s'", items[5].Get("call_id").String())
}
if items[6].Get("type").String() != "function_call_output" {
t.Errorf("item 6: expected function_call_output, got %s", items[6].Get("type").String())
}
}
// Tool names over 64 chars get shortened, call_id stays the same.
func TestToolNameShortening(t *testing.T) {
longName := "a_very_long_tool_name_that_exceeds_sixty_four_characters_limit_here_test"
if len(longName) <= 64 {
t.Fatalf("test setup error: name must be > 64 chars, got %d", len(longName))
}
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Do it"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_long",
"type": "function",
"function": {
"name": "` + longName + `",
"arguments": "{}"
}
}
]
},
{"role": "tool", "tool_call_id": "call_long", "content": "ok"}
],
"tools": [
{
"type": "function",
"function": {
"name": "` + longName + `",
"description": "A tool with a very long name",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// find function_call
var funcCallItem gjson.Result
for _, item := range items {
if item.Get("type").String() == "function_call" {
funcCallItem = item
break
}
}
if !funcCallItem.Exists() {
t.Fatal("no function_call item found in output")
}
// call_id unchanged
if funcCallItem.Get("call_id").String() != "call_long" {
t.Errorf("call_id changed: expected 'call_long', got '%s'", funcCallItem.Get("call_id").String())
}
// name must be truncated
translatedName := funcCallItem.Get("name").String()
if translatedName == longName {
t.Errorf("tool name was NOT shortened: still '%s'", translatedName)
}
if len(translatedName) > 64 {
t.Errorf("shortened name still > 64 chars: len=%d name='%s'", len(translatedName), translatedName)
}
}
// content:"" (empty string, not null) should be treated the same as null.
func TestEmptyStringContent(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Do something"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_empty",
"type": "function",
"function": {"name": "action", "arguments": "{}"}
}
]
},
{"role": "tool", "tool_call_id": "call_empty", "content": "result"}
],
"tools": [
{
"type": "function",
"function": {
"name": "action",
"description": "An action",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
for i, item := range items {
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
if len(item.Get("content").Array()) == 0 {
t.Errorf("item %d: empty assistant message from content:\"\"", i)
}
}
}
// user + function_call + function_call_output
if len(items) != 3 {
t.Errorf("expected 3 input items, got %d", len(items))
}
}
// Every function_call_output must have a matching function_call by call_id.
func TestCallIDsMatchBetweenCallAndOutput(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Multi-tool"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{"id": "id_a", "type": "function", "function": {"name": "tool_a", "arguments": "{}"}},
{"id": "id_b", "type": "function", "function": {"name": "tool_b", "arguments": "{}"}}
]
},
{"role": "tool", "tool_call_id": "id_a", "content": "res_a"},
{"role": "tool", "tool_call_id": "id_b", "content": "res_b"}
],
"tools": [
{"type": "function", "function": {"name": "tool_a", "description": "A", "parameters": {"type": "object", "properties": {}}}},
{"type": "function", "function": {"name": "tool_b", "description": "B", "parameters": {"type": "object", "properties": {}}}}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// collect call_ids from function_call items
callIDs := make(map[string]bool)
for _, item := range items {
if item.Get("type").String() == "function_call" {
callIDs[item.Get("call_id").String()] = true
}
}
for i, item := range items {
if item.Get("type").String() == "function_call_output" {
outID := item.Get("call_id").String()
if !callIDs[outID] {
t.Errorf("item %d: function_call_output has call_id '%s' with no matching function_call", i, outID)
}
}
}
// 2 calls, 2 outputs
funcCallCount := 0
funcOutputCount := 0
for _, item := range items {
switch item.Get("type").String() {
case "function_call":
funcCallCount++
case "function_call_output":
funcOutputCount++
}
}
if funcCallCount != 2 {
t.Errorf("expected 2 function_calls, got %d", funcCallCount)
}
if funcOutputCount != 2 {
t.Errorf("expected 2 function_call_outputs, got %d", funcOutputCount)
}
}
// Tools array should carry over to the Responses format output.
func TestToolsDefinitionTranslated(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Hi"}
],
"tools": [
{
"type": "function",
"function": {
"name": "search",
"description": "Search the web",
"parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
tools := gjson.Get(result, "tools").Array()
if len(tools) == 0 {
t.Fatal("no tools found in output")
}
found := false
for _, tool := range tools {
if tool.Get("name").String() == "search" {
found = true
break
}
}
if !found {
t.Errorf("tool 'search' not found in output tools: %s", gjson.Get(result, "tools").Raw)
}
}

View File

@@ -41,8 +41,8 @@ type ConvertCliToOpenAIParams struct {
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
// - [][]byte: A slice of OpenAI-compatible JSON responses
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &ConvertCliToOpenAIParams{
Model: modelName,
@@ -55,12 +55,12 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
}
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
return [][]byte{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
// Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
template := []byte(`{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{},"finish_reason":null,"native_finish_reason":null}]}`)
rootResult := gjson.ParseBytes(rawJSON)
@@ -70,67 +70,67 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
(*param).(*ConvertCliToOpenAIParams).ResponseID = rootResult.Get("response.id").String()
(*param).(*ConvertCliToOpenAIParams).CreatedAt = rootResult.Get("response.created_at").Int()
(*param).(*ConvertCliToOpenAIParams).Model = rootResult.Get("response.model").String()
return []string{}
return [][]byte{}
}
// Extract and set the model version.
cachedModel := (*param).(*ConvertCliToOpenAIParams).Model
if modelResult := gjson.GetBytes(rawJSON, "model"); modelResult.Exists() {
template, _ = sjson.Set(template, "model", modelResult.String())
template, _ = sjson.SetBytes(template, "model", modelResult.String())
} else if cachedModel != "" {
template, _ = sjson.Set(template, "model", cachedModel)
template, _ = sjson.SetBytes(template, "model", cachedModel)
} else if modelName != "" {
template, _ = sjson.Set(template, "model", modelName)
template, _ = sjson.SetBytes(template, "model", modelName)
}
template, _ = sjson.Set(template, "created", (*param).(*ConvertCliToOpenAIParams).CreatedAt)
template, _ = sjson.SetBytes(template, "created", (*param).(*ConvertCliToOpenAIParams).CreatedAt)
// Extract and set the response ID.
template, _ = sjson.Set(template, "id", (*param).(*ConvertCliToOpenAIParams).ResponseID)
template, _ = sjson.SetBytes(template, "id", (*param).(*ConvertCliToOpenAIParams).ResponseID)
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "response.usage"); usageResult.Exists() {
if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokensResult.Int())
}
if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokensResult.Int())
}
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokensResult.Int())
}
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
}
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
}
}
if dataType == "response.reasoning_summary_text.delta" {
if deltaResult := rootResult.Get("delta"); deltaResult.Exists() {
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", deltaResult.String())
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", deltaResult.String())
}
} else if dataType == "response.reasoning_summary_text.done" {
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", "\n\n")
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.delta.reasoning_content", "\n\n")
} else if dataType == "response.output_text.delta" {
if deltaResult := rootResult.Get("delta"); deltaResult.Exists() {
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.Set(template, "choices.0.delta.content", deltaResult.String())
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.delta.content", deltaResult.String())
}
} else if dataType == "response.completed" {
finishReason := "stop"
if (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex != -1 {
finishReason = "tool_calls"
}
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", finishReason)
} else if dataType == "response.output_item.added" {
itemResult := rootResult.Get("item")
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
return []string{}
return [][]byte{}
}
// Increment index for this new function call item.
@@ -138,9 +138,9 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
functionCallItemTemplate := []byte(`{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened.
name := itemResult.Get("name").String()
@@ -148,59 +148,59 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", "")
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", "")
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.function_call_arguments.delta" {
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true
deltaValue := rootResult.Get("delta").String()
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", deltaValue)
functionCallItemTemplate := []byte(`{"index":0,"function":{"arguments":""}}`)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", deltaValue)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.function_call_arguments.done" {
if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta {
// Arguments were already streamed via delta events; nothing to emit.
return []string{}
return [][]byte{}
}
// Fallback: no delta events were received, emit the full arguments as a single chunk.
fullArgs := rootResult.Get("arguments").String()
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fullArgs)
functionCallItemTemplate := []byte(`{"index":0,"function":{"arguments":""}}`)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", fullArgs)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.output_item.done" {
itemResult := rootResult.Get("item")
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
return []string{}
return [][]byte{}
}
if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced {
// Tool call was already announced via output_item.added; skip emission.
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false
return []string{}
return [][]byte{}
}
// Fallback path: model skipped output_item.added, so emit complete tool call now.
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate := []byte(`{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls", []byte(`[]`))
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened.
name := itemResult.Get("name").String()
@@ -208,17 +208,17 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
functionCallItemTemplate, _ = sjson.SetBytes(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
template, _ = sjson.SetBytes(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRawBytes(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else {
return []string{}
return [][]byte{}
}
return []string{template}
return [][]byte{template}
}
// ConvertCodexResponseToOpenAINonStream converts a non-streaming Codex response to a non-streaming OpenAI response.
@@ -233,53 +233,53 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// - []byte: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
rootResult := gjson.ParseBytes(rawJSON)
// Verify this is a response.completed event
if rootResult.Get("type").String() != "response.completed" {
return ""
return []byte{}
}
unixTimestamp := time.Now().Unix()
responseResult := rootResult.Get("response")
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
template := []byte(`{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`)
// Extract and set the model version.
if modelResult := responseResult.Get("model"); modelResult.Exists() {
template, _ = sjson.Set(template, "model", modelResult.String())
template, _ = sjson.SetBytes(template, "model", modelResult.String())
}
// Extract and set the creation timestamp.
if createdAtResult := responseResult.Get("created_at"); createdAtResult.Exists() {
template, _ = sjson.Set(template, "created", createdAtResult.Int())
template, _ = sjson.SetBytes(template, "created", createdAtResult.Int())
} else {
template, _ = sjson.Set(template, "created", unixTimestamp)
template, _ = sjson.SetBytes(template, "created", unixTimestamp)
}
// Extract and set the response ID.
if idResult := responseResult.Get("id"); idResult.Exists() {
template, _ = sjson.Set(template, "id", idResult.String())
template, _ = sjson.SetBytes(template, "id", idResult.String())
}
// Extract and set usage metadata (token counts).
if usageResult := responseResult.Get("usage"); usageResult.Exists() {
if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.completion_tokens", outputTokensResult.Int())
}
if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokensResult.Int())
}
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", inputTokensResult.Int())
}
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
}
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
template, _ = sjson.SetBytes(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
}
}
@@ -289,7 +289,7 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
outputArray := outputResult.Array()
var contentText string
var reasoningText string
var toolCalls []string
var toolCalls [][]byte
for _, outputItem := range outputArray {
outputType := outputItem.Get("type").String()
@@ -319,10 +319,10 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
}
case "function_call":
// Handle function call content
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
functionCallTemplate := []byte(`{"id":"","type":"function","function":{"name":"","arguments":""}}`)
if callIdResult := outputItem.Get("call_id"); callIdResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", callIdResult.String())
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "id", callIdResult.String())
}
if nameResult := outputItem.Get("name"); nameResult.Exists() {
@@ -331,11 +331,11 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
if orig, ok := rev[n]; ok {
n = orig
}
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", n)
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.name", n)
}
if argsResult := outputItem.Get("arguments"); argsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", argsResult.String())
functionCallTemplate, _ = sjson.SetBytes(functionCallTemplate, "function.arguments", argsResult.String())
}
toolCalls = append(toolCalls, functionCallTemplate)
@@ -344,22 +344,22 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
// Set content and reasoning content if found
if contentText != "" {
template, _ = sjson.Set(template, "choices.0.message.content", contentText)
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.message.content", contentText)
template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
}
if reasoningText != "" {
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningText)
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.message.reasoning_content", reasoningText)
template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
}
// Add tool calls if any
if len(toolCalls) > 0 {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
template, _ = sjson.SetRawBytes(template, "choices.0.message.tool_calls", []byte(`[]`))
for _, toolCall := range toolCalls {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", toolCall)
template, _ = sjson.SetRawBytes(template, "choices.0.message.tool_calls.-1", toolCall)
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetBytes(template, "choices.0.message.role", "assistant")
}
}
@@ -367,8 +367,8 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
if statusResult := responseResult.Get("status"); statusResult.Exists() {
status := statusResult.String()
if status == "completed" {
template, _ = sjson.Set(template, "choices.0.finish_reason", "stop")
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "stop")
template, _ = sjson.SetBytes(template, "choices.0.finish_reason", "stop")
template, _ = sjson.SetBytes(template, "choices.0.native_finish_reason", "stop")
}
}

View File

@@ -23,7 +23,7 @@ func TestConvertCodexResponseToOpenAI_StreamSetsModelFromResponseCreated(t *test
t.Fatalf("expected 1 chunk, got %d", len(out))
}
gotModel := gjson.Get(out[0], "model").String()
gotModel := gjson.GetBytes(out[0], "model").String()
if gotModel != modelName {
t.Fatalf("expected model %q, got %q", modelName, gotModel)
}
@@ -40,8 +40,53 @@ func TestConvertCodexResponseToOpenAI_FirstChunkUsesRequestModelName(t *testing.
t.Fatalf("expected 1 chunk, got %d", len(out))
}
gotModel := gjson.Get(out[0], "model").String()
gotModel := gjson.GetBytes(out[0], "model").String()
if gotModel != modelName {
t.Fatalf("expected model %q, got %q", modelName, gotModel)
}
}
func TestConvertCodexResponseToOpenAI_ToolCallChunkOmitsNullContentFields(t *testing.T) {
ctx := context.Background()
var param any
out := ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_123","name":"websearch"}}`), &param)
if len(out) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(out))
}
if gjson.GetBytes(out[0], "choices.0.delta.content").Exists() {
t.Fatalf("expected content to be omitted, got %s", string(out[0]))
}
if gjson.GetBytes(out[0], "choices.0.delta.reasoning_content").Exists() {
t.Fatalf("expected reasoning_content to be omitted, got %s", string(out[0]))
}
if !gjson.GetBytes(out[0], "choices.0.delta.tool_calls").Exists() {
t.Fatalf("expected tool_calls to exist, got %s", string(out[0]))
}
}
func TestConvertCodexResponseToOpenAI_ToolCallArgumentsDeltaOmitsNullContentFields(t *testing.T) {
ctx := context.Background()
var param any
out := ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"call_123","name":"websearch"}}`), &param)
if len(out) != 1 {
t.Fatalf("expected tool call announcement chunk, got %d", len(out))
}
out = ConvertCodexResponseToOpenAI(ctx, "gpt-5.4", nil, nil, []byte(`data: {"type":"response.function_call_arguments.delta","delta":"{\"query\":\"OpenAI\"}"}`), &param)
if len(out) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(out))
}
if gjson.GetBytes(out[0], "choices.0.delta.content").Exists() {
t.Fatalf("expected content to be omitted, got %s", string(out[0]))
}
if gjson.GetBytes(out[0], "choices.0.delta.reasoning_content").Exists() {
t.Fatalf("expected reasoning_content to be omitted, got %s", string(out[0]))
}
if !gjson.GetBytes(out[0], "choices.0.delta.tool_calls.0.function.arguments").Exists() {
t.Fatalf("expected tool call arguments delta to exist, got %s", string(out[0]))
}
}

View File

@@ -3,6 +3,7 @@ package responses
import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -12,8 +13,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
inputResult := gjson.GetBytes(rawJSON, "input")
if inputResult.Type == gjson.String {
input, _ := sjson.Set(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`, "0.content.0.text", inputResult.String())
rawJSON, _ = sjson.SetRawBytes(rawJSON, "input", []byte(input))
input, _ := sjson.SetBytes([]byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`), "0.content.0.text", inputResult.String())
rawJSON, _ = sjson.SetRawBytes(rawJSON, "input", input)
}
rawJSON, _ = sjson.SetBytes(rawJSON, "stream", true)
@@ -39,6 +40,7 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
// Convert role "system" to "developer" in input array to comply with Codex API requirements.
rawJSON = convertSystemRoleToDeveloper(rawJSON)
rawJSON = normalizeCodexBuiltinTools(rawJSON)
return rawJSON
}
@@ -82,3 +84,59 @@ func convertSystemRoleToDeveloper(rawJSON []byte) []byte {
return result
}
// normalizeCodexBuiltinTools rewrites legacy/preview built-in tool variants to the
// stable names expected by the current Codex upstream.
func normalizeCodexBuiltinTools(rawJSON []byte) []byte {
result := rawJSON
tools := gjson.GetBytes(result, "tools")
if tools.IsArray() {
toolArray := tools.Array()
for i := 0; i < len(toolArray); i++ {
typePath := fmt.Sprintf("tools.%d.type", i)
result = normalizeCodexBuiltinToolAtPath(result, typePath)
}
}
result = normalizeCodexBuiltinToolAtPath(result, "tool_choice.type")
toolChoiceTools := gjson.GetBytes(result, "tool_choice.tools")
if toolChoiceTools.IsArray() {
toolArray := toolChoiceTools.Array()
for i := 0; i < len(toolArray); i++ {
typePath := fmt.Sprintf("tool_choice.tools.%d.type", i)
result = normalizeCodexBuiltinToolAtPath(result, typePath)
}
}
return result
}
func normalizeCodexBuiltinToolAtPath(rawJSON []byte, path string) []byte {
currentType := gjson.GetBytes(rawJSON, path).String()
normalizedType := normalizeCodexBuiltinToolType(currentType)
if normalizedType == "" {
return rawJSON
}
updated, err := sjson.SetBytes(rawJSON, path, normalizedType)
if err != nil {
return rawJSON
}
log.Debugf("codex responses: normalized builtin tool type at %s from %q to %q", path, currentType, normalizedType)
return updated
}
// normalizeCodexBuiltinToolType centralizes the current known Codex Responses
// built-in tool alias compatibility. If Codex introduces more legacy aliases,
// extend this helper instead of adding path-specific rewrite logic elsewhere.
func normalizeCodexBuiltinToolType(toolType string) string {
switch toolType {
case "web_search_preview", "web_search_preview_2025_03_11":
return "web_search"
default:
return ""
}
}

View File

@@ -264,6 +264,52 @@ func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) {
}
}
func TestConvertOpenAIResponsesRequestToCodex_NormalizesWebSearchPreview(t *testing.T) {
inputJSON := []byte(`{
"model": "gpt-5.4-mini",
"input": "find latest OpenAI model news",
"tools": [
{"type": "web_search_preview_2025_03_11"}
],
"tool_choice": {
"type": "allowed_tools",
"tools": [
{"type": "web_search_preview"},
{"type": "web_search_preview_2025_03_11"}
]
}
}`)
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.4-mini", inputJSON, false)
if got := gjson.GetBytes(output, "tools.0.type").String(); got != "web_search" {
t.Fatalf("tools.0.type = %q, want %q: %s", got, "web_search", string(output))
}
if got := gjson.GetBytes(output, "tool_choice.type").String(); got != "allowed_tools" {
t.Fatalf("tool_choice.type = %q, want %q: %s", got, "allowed_tools", string(output))
}
if got := gjson.GetBytes(output, "tool_choice.tools.0.type").String(); got != "web_search" {
t.Fatalf("tool_choice.tools.0.type = %q, want %q: %s", got, "web_search", string(output))
}
if got := gjson.GetBytes(output, "tool_choice.tools.1.type").String(); got != "web_search" {
t.Fatalf("tool_choice.tools.1.type = %q, want %q: %s", got, "web_search", string(output))
}
}
func TestConvertOpenAIResponsesRequestToCodex_NormalizesTopLevelToolChoicePreviewAlias(t *testing.T) {
inputJSON := []byte(`{
"model": "gpt-5.4-mini",
"input": "find latest OpenAI model news",
"tool_choice": {"type": "web_search_preview_2025_03_11"}
}`)
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.4-mini", inputJSON, false)
if got := gjson.GetBytes(output, "tool_choice.type").String(); got != "web_search" {
t.Fatalf("tool_choice.type = %q, want %q: %s", got, "web_search", string(output))
}
}
func TestUserFieldDeletion(t *testing.T) {
inputJSON := []byte(`{
"model": "gpt-5.2",

View File

@@ -3,7 +3,6 @@ package responses
import (
"bytes"
"context"
"fmt"
"github.com/tidwall/gjson"
)
@@ -11,23 +10,25 @@ import (
// ConvertCodexResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
// to OpenAI Responses SSE events (response.*).
func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []string {
func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) [][]byte {
if bytes.HasPrefix(rawJSON, []byte("data:")) {
rawJSON = bytes.TrimSpace(rawJSON[5:])
out := fmt.Sprintf("data: %s", string(rawJSON))
return []string{out}
out := make([]byte, 0, len(rawJSON)+len("data: "))
out = append(out, []byte("data: ")...)
out = append(out, rawJSON...)
return [][]byte{out}
}
return []string{string(rawJSON)}
return [][]byte{rawJSON}
}
// ConvertCodexResponseToOpenAIResponsesNonStream builds a single Responses JSON
// from a non-streaming OpenAI Chat Completions response.
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) string {
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []byte {
rootResult := gjson.ParseBytes(rawJSON)
// Verify this is a response.completed event
if rootResult.Get("type").String() != "response.completed" {
return ""
return []byte{}
}
responseResult := rootResult.Get("response")
return responseResult.Raw
return []byte(responseResult.Raw)
}

View File

@@ -0,0 +1,67 @@
package common
import (
"strconv"
"github.com/tidwall/sjson"
)
func WrapGeminiCLIResponse(response []byte) []byte {
out, err := sjson.SetRawBytes([]byte(`{"response":{}}`), "response", response)
if err != nil {
return response
}
return out
}
func GeminiTokenCountJSON(count int64) []byte {
out := make([]byte, 0, 96)
out = append(out, `{"totalTokens":`...)
out = strconv.AppendInt(out, count, 10)
out = append(out, `,"promptTokensDetails":[{"modality":"TEXT","tokenCount":`...)
out = strconv.AppendInt(out, count, 10)
out = append(out, `}]}`...)
return out
}
func ClaudeInputTokensJSON(count int64) []byte {
out := make([]byte, 0, 32)
out = append(out, `{"input_tokens":`...)
out = strconv.AppendInt(out, count, 10)
out = append(out, '}')
return out
}
func SSEEventData(event string, payload []byte) []byte {
out := make([]byte, 0, len(event)+len(payload)+14)
out = append(out, "event: "...)
out = append(out, event...)
out = append(out, '\n')
out = append(out, "data: "...)
out = append(out, payload...)
return out
}
func AppendSSEEventString(out []byte, event, payload string, trailingNewlines int) []byte {
out = append(out, "event: "...)
out = append(out, event...)
out = append(out, '\n')
out = append(out, "data: "...)
out = append(out, payload...)
for i := 0; i < trailingNewlines; i++ {
out = append(out, '\n')
}
return out
}
func AppendSSEEventBytes(out []byte, event string, payload []byte, trailingNewlines int) []byte {
out = append(out, "event: "...)
out = append(out, event...)
out = append(out, '\n')
out = append(out, "data: "...)
out = append(out, payload...)
for i := 0; i < trailingNewlines; i++ {
out = append(out, '\n')
}
return out
}

Some files were not shown because too many files have changed in this diff Show More