Compare commits

...

166 Commits

Author SHA1 Message Date
Luis Pater
19232c6388 Merge branch 'router-for-me:main' into main 2025-12-31 11:52:49 +08:00
Luis Pater
e998b1229a feat(updater): add fallback URL and logic for missing management asset 2025-12-31 11:51:20 +08:00
Luis Pater
bbed134bd1 feat(api): add GetAuthStatus method to ManagementTokenRequester interface 2025-12-31 09:40:48 +08:00
Luis Pater
06075c0e93 Merge pull request #75 from router-for-me/plus
v6.6.71
2025-12-30 23:34:44 +08:00
Luis Pater
e85c9d9322 Merge branch 'main' into plus 2025-12-30 23:33:35 +08:00
Chén Mù
cb56cb250e Merge pull request #800 from router-for-me/modelmappings
feat(watcher): add model mappings change detection
2025-12-30 06:50:42 -08:00
hkfires
e0381a6ae0 refactor(watcher): extract model summary functions to dedicated file 2025-12-30 22:39:12 +08:00
hkfires
2c01b2ef64 feat(watcher): add Gemini models and OAuth model mappings change detection 2025-12-30 22:39:12 +08:00
Chén Mù
e947266743 Merge pull request #795 from router-for-me/modelmappings
refactor(executor): resolve upstream model at conductor level before execution
2025-12-30 05:31:19 -08:00
Luis Pater
c6b0e85b54 Fixed: #790
fix(gemini): include full text in response output events
2025-12-30 20:44:13 +08:00
hkfires
26efbed05c refactor(executor): remove redundant upstream model parameter from translateRequest 2025-12-30 20:20:42 +08:00
hkfires
96340bf136 refactor(executor): resolve upstream model at conductor level before execution 2025-12-30 19:31:54 +08:00
hkfires
b055e00c1a fix(executor): use upstream model for thinking config and payload translation 2025-12-30 17:49:44 +08:00
Chén Mù
857c880f99 Merge pull request #785 from router-for-me/gemini
feat(gemini): add per-key model alias support for Gemini provider
2025-12-29 23:32:40 -08:00
hkfires
ce7474d953 feat(cliproxy): propagate thinking support metadata to aliased models 2025-12-30 15:16:54 +08:00
hkfires
70fdd70b84 refactor(cliproxy): extract generic buildConfigModels function for model info generation 2025-12-30 13:35:22 +08:00
hkfires
08ab6a7d77 feat(gemini): add per-key model alias support for Gemini provider 2025-12-30 13:27:57 +08:00
Luis Pater
5c95129884 Merge branch 'router-for-me:main' into main 2025-12-30 11:48:29 +08:00
Luis Pater
9fa2a7e9df Merge pull request #782 from router-for-me/modelmappings
refactor(config): rename model-name-mappings to oauth-model-mappings
2025-12-30 11:40:12 +08:00
hkfires
d443c86620 refactor(config): rename model mapping fields from from/to to name/alias 2025-12-30 11:07:59 +08:00
hkfires
7be3f1c36c refactor(config): rename model-name-mappings to oauth-model-mappings 2025-12-30 11:07:58 +08:00
Luis Pater
f6ab6d97b9 fix(logging): add isDirWritable utility to enhance log dir validation in ConfigureLogOutput 2025-12-30 10:48:25 +08:00
Luis Pater
bc866bac49 fix(logging): refactor ConfigureLogOutput to accept config object and adjust log directory handling 2025-12-30 10:28:25 +08:00
Luis Pater
50e6d845f4 feat(cliproxy): introduce global model name mappings for improved aliasing and routing 2025-12-30 08:13:06 +08:00
Luis Pater
a8cb01819d Merge pull request #772 from soffchen/main
fix: Implement fallback log directory for file logging on read-only system
2025-12-30 02:24:49 +08:00
Luis Pater
1a99cfded4 Merge branch 'router-for-me:main' into main 2025-12-30 00:38:41 +08:00
Luis Pater
530273906b Merge pull request #776 from router-for-me/fix-ag-claude
fix(antigravity): inject required placeholder when properties exist w…
2025-12-30 00:37:01 +08:00
Supra4E8C
06ddf575d9 fix(antigravity): inject required placeholder when properties exist without required 2025-12-29 23:55:59 +08:00
Luis Pater
cf369d4684 Merge branch 'router-for-me:main' into main 2025-12-29 22:41:44 +08:00
hkfires
3099114cbb refactor(api): simplify codex id token claims extraction 2025-12-29 19:48:02 +08:00
Soff
44b63f0767 fix: Return an error if the user home directory cannot be determined for the fallback log path. 2025-12-29 18:46:15 +08:00
Soff Chen
6705d20194 fix: Implement fallback log directory for file logging on read-only systems. 2025-12-29 18:35:48 +08:00
Chén Mù
a38a9c0b0f Merge pull request #770 from router-for-me/api
feat(api): add id token claims extraction for codex auth entries
2025-12-29 00:44:41 -08:00
hkfires
8286caa366 feat(api): add id token claims extraction for codex auth entries 2025-12-29 16:34:16 +08:00
Chén Mù
bd1ec8424d Merge pull request #767 from router-for-me/amp
feat(amp): add per-client upstream API key mapping support
2025-12-28 22:10:11 -08:00
hkfires
225e2c6797 feat(amp): add per-client upstream API key mapping support 2025-12-29 12:26:25 +08:00
Luis Pater
d8fc485513 fix(translators): correct key path for system_instruction.parts in Claude request logic 2025-12-29 11:54:26 +08:00
hkfires
f137eb0ac4 chore: add codex, agents, and opencode dirs to ignore files 2025-12-29 08:42:29 +08:00
Chén Mù
f39a460487 Merge pull request #761 from router-for-me/log
fix(logging): improve request/response capture
2025-12-28 16:13:10 -08:00
Luis Pater
ee171bc563 feat(api): add ManagementTokenRequester interface for management token request endpoints 2025-12-29 02:42:29 +08:00
hkfires
a95428f204 fix(handlers): preserve upstream response logs before duplicate detection 2025-12-28 22:35:36 +08:00
hkfires
3ca5fb1046 fix(handlers): match raw error text before JSON body for duplicate detection 2025-12-28 19:35:36 +08:00
hkfires
a091d12f4e fix(logging): improve request/response capture 2025-12-28 19:04:31 +08:00
Luis Pater
e3d8d726e6 Merge branch 'router-for-me:main' into main 2025-12-28 15:09:33 +08:00
Luis Pater
457924828a Merge pull request #757 from ben-vargas/fix-thinking-toolchoice-conflict
Fix: disable thinking when tool_choice forces tool use
2025-12-28 14:04:30 +08:00
Ben Vargas
aca2ef6359 Fix: disable thinking when tool_choice forces tool use
Anthropic API does not allow extended thinking when tool_choice is set
to "any" or a specific tool. This was causing 400 errors when using
features like Amp's /handoff command which forces tool_choice.

Added disableThinkingIfToolChoiceForced() that removes thinking config
when incompatible tool_choice is detected, applied to both streaming
and non-streaming paths.

Fixes router-for-me/CLIProxyAPI#630
2025-12-27 16:31:37 -07:00
Luis Pater
ade7194792 feat(management): add generic API call handler to management endpoints 2025-12-28 04:40:32 +08:00
Luis Pater
0f51e73baa Merge branch 'router-for-me:main' into main 2025-12-28 03:07:58 +08:00
Luis Pater
3a436e116a feat(cliproxy): implement model aliasing and hashing for Codex configurations, enhance request routing logic, and normalize Codex model entries 2025-12-28 03:06:51 +08:00
Luis Pater
d06e2dc83c Merge branch 'router-for-me:main' into main 2025-12-28 02:10:16 +08:00
Luis Pater
336867853b Merge pull request #756 from leaph/check-ai-thinking-settings
feat(iflow): add model-specific thinking configs for GLM-4.7 and Mini…
2025-12-28 02:08:27 +08:00
leaph
6403ff4ec4 feat(iflow): add model-specific thinking configs for GLM-4.7 and MiniMax-M2.1
- GLM-4.7: Uses extra_body={"thinking": {"type": "enabled"}, "clear_thinking": false}
- MiniMax-M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
- Added preserveReasoningContentInMessages() to support re-injection of reasoning
  content in assistant message history for multi-turn conversations
- Added ThinkingSupport to MiniMax-M2.1 model definition
2025-12-27 18:39:15 +01:00
Luis Pater
d222469b44 Update issue templates 2025-12-28 01:22:42 +08:00
Luis Pater
790a17ce98 Merge pull request #70 from router-for-me/plus
v6.6.60
2025-12-28 00:57:14 +08:00
Luis Pater
d473c952fb Merge branch 'main' into plus 2025-12-28 00:56:04 +08:00
Luis Pater
7646a2b877 Fixed: #749
fix(translators): ensure `gjson.String` content is non-empty before setting `parts` in OpenAI request logic
2025-12-28 00:54:26 +08:00
Luis Pater
62090f2568 Merge pull request #750 from router-for-me/config
fix(config): preserve original config structure and avoid default value pollution
2025-12-27 22:10:01 +08:00
Luis Pater
d35152bbef Merge branch 'router-for-me:main' into main 2025-12-27 22:03:50 +08:00
Luis Pater
c281f4cbaf Fixed: #747
fix(translators): rename and integrate `usageMetadata` as `cpaUsageMetadata` in Claude processing logic
2025-12-27 22:02:11 +08:00
hkfires
09455f9e85 fix(config): make streaming keepalive and retries ints 2025-12-27 20:56:47 +08:00
hkfires
c8e72ba0dc fix(config): smart merge writes non-default new keys only 2025-12-27 20:28:54 +08:00
hkfires
375ef252ab docs(config): clarify merge mapping behavior 2025-12-27 19:30:21 +08:00
hkfires
ee552f8720 chore(config): update ignore patterns 2025-12-27 19:13:14 +08:00
hkfires
2e88c4858e fix(config): avoid adding new keys when merging 2025-12-27 19:00:47 +08:00
Luis Pater
3f50da85c1 Merge pull request #745 from router-for-me/auth
fix(auth): make provider rotation atomic
2025-12-27 13:01:22 +08:00
hkfires
8be06255f7 fix(auth): make provider rotation atomic 2025-12-27 12:56:48 +08:00
Luis Pater
60936b5185 Merge branch 'router-for-me:main' into main 2025-12-27 03:57:03 +08:00
Luis Pater
72274099aa Fixed: #738
fix(translators): refine prompt token calculation by incorporating cached tokens in Claude response handling
2025-12-27 03:56:11 +08:00
Luis Pater
b7f7b3a1d8 Merge branch 'router-for-me:main' into main 2025-12-27 01:26:33 +08:00
Luis Pater
dcae098e23 Fixed: #736
fix(translators): handle gjson string types in Claude request processing to ensure consistent content parsing
2025-12-27 01:25:47 +08:00
Luis Pater
618606966f Merge pull request #68 from router-for-me/plus
v6.6.56
2025-12-26 12:14:42 +08:00
Luis Pater
05f249d77f Merge branch 'main' into plus 2025-12-26 12:14:35 +08:00
Luis Pater
2eb05ec640 Merge pull request #727 from nguyenphutrong/main
docs: add Quotio to community projects
2025-12-26 11:53:09 +08:00
Luis Pater
3ce0d76aa4 feat(usage): add import/export functionality for usage statistics and enhance deduplication logic 2025-12-26 11:49:51 +08:00
Trong Nguyen
a00b79d9be docs(readme): add Quotio to community projects section 2025-12-26 10:46:05 +07:00
Luis Pater
9fe6a215e6 Merge branch 'router-for-me:main' into main 2025-12-26 05:10:19 +08:00
Luis Pater
33e53a2a56 fix(translators): ensure correct handling and output of multimodal assistant content across request handlers 2025-12-26 05:08:04 +08:00
Luis Pater
cd5b80785f Merge pull request #722 from hungthai1401/bugfix/remove-extra-args
Fixed incorrect function signature call to `NewBaseAPIHandlers`
2025-12-26 02:56:56 +08:00
Thai Nguyen Hung
54f71aa273 fix(test): remove extra argument from ExecuteStreamWithAuthManager call 2025-12-25 21:55:35 +07:00
Luis Pater
3f949b7f84 Merge pull request #704 from tinyc0der/add-index
fix(openai): add index field to image response for LiteLLM compatibility
2025-12-25 21:35:12 +08:00
Luis Pater
cf8b2dcc85 Merge pull request #67 from router-for-me/plus
v6.6.54
2025-12-25 21:07:45 +08:00
Luis Pater
8e24d9dc34 Merge branch 'main' into plus 2025-12-25 21:07:37 +08:00
Luis Pater
443c4538bb feat(config): add commercial-mode to optimize HTTP middleware for lower memory usage 2025-12-25 21:05:01 +08:00
TinyCoder
a7fc2ee4cf refactor(image): avoid using json.Marshal 2025-12-25 14:21:01 +07:00
Luis Pater
8e749ac22d docs(readme): update GLM model version from 4.6 to 4.7 in README and README_CN 2025-12-24 23:59:48 +08:00
Luis Pater
69e09d9bc7 docs(readme): update GLM model version from 4.6 to 4.7 in README and README_CN 2025-12-24 23:46:27 +08:00
Luis Pater
ed57d82bc1 Merge branch 'router-for-me:main' into main 2025-12-24 23:31:09 +08:00
Luis Pater
06ad527e8c Fixed: #696
fix(translators): adjust prompt token calculation by subtracting cached tokens across Gemini, OpenAI, and Claude handlers
2025-12-24 23:29:18 +08:00
Luis Pater
7af5a90a0b Merge pull request #65 from router-for-me/plus
v6.6.52
2025-12-24 22:28:45 +08:00
Luis Pater
7551faff79 Merge branch 'main' into plus 2025-12-24 22:27:12 +08:00
Luis Pater
b7409dd2de Merge pull request #706 from router-for-me/log
Log
2025-12-24 22:24:39 +08:00
hkfires
5ba325a8fc refactor(logging): standardize request id formatting and layout 2025-12-24 22:03:07 +08:00
Luis Pater
d502840f91 Merge pull request #695 from NguyenSiTrung/main
feat: add cached token parsing for Gemini , Antigravity API responses
2025-12-24 21:58:55 +08:00
hkfires
99238a4b59 fix(logging): normalize warning level to warn 2025-12-24 21:11:37 +08:00
hkfires
6d43a2ff9a refactor(logging): inline request id in log output 2025-12-24 21:07:18 +08:00
Luis Pater
cdb9c2e6e8 Merge remote-tracking branch 'origin/main' into router-for-me/main 2025-12-24 20:18:53 +08:00
Luis Pater
3faa1ca9af Merge pull request #700 from router-for-me/log
refactor(sdk/auth): rename manager.go to conductor.go
2025-12-24 19:36:24 +08:00
Luis Pater
9d975e0375 feat(models): add support for GLM-4.7 and MiniMax-M2.1 2025-12-24 19:30:57 +08:00
hkfires
2a6d8b78d4 feat(api): add endpoint to retrieve request logs by ID 2025-12-24 19:24:51 +08:00
TinyCoder
671558a822 fix(openai): add index field to image response for LiteLLM compatibility
LiteLLM's Pydantic model requires an index field in each image object.
Without it, responses fail validation with "images.0.index Field required".
2025-12-24 17:43:31 +07:00
Luis Pater
6b80ec79a0 Merge pull request #61 from tinyc0der/fix/kiro-tool-results
fix(kiro): Handle tool results correctly in OpenAI format translation
2025-12-24 17:23:41 +08:00
Luis Pater
d3f4783a24 Merge pull request #57 from PancakeZik/my-idc-changes
feat: add AWS Identity Center (IDC) authentication support
2025-12-24 17:20:01 +08:00
Luis Pater
1cb6bdbc87 Merge pull request #62 from router-for-me/plus
v6.6.50
2025-12-24 17:14:05 +08:00
Luis Pater
96ddfc1f24 Merge branch 'main' into plus 2025-12-24 17:13:51 +08:00
TinyCoder
c169b32570 refactor(kiro): Remove unused variables in OpenAI translator
Remove dead code that was never used:
- toolCallIDToName map: built but never read from
- seenToolCallIDs: declared but never populated, only suppressed with _
2025-12-24 15:10:35 +07:00
TinyCoder
36a512fdf2 fix(kiro): Handle tool results correctly in OpenAI format translation
Fix three issues in Kiro OpenAI translator that caused "Improperly formed request"
errors when processing LiteLLM-translated requests with tool_use/tool_result:

1. Skip merging tool role messages in MergeAdjacentMessages() to preserve
   individual tool_call_id fields

2. Track pendingToolResults and attach to the next user message instead of
   only the last message. Create synthetic user message when conversation
   ends with tool results.

3. Insert synthetic user message with tool results before assistant messages
   to maintain proper alternating user/assistant structure. This fixes the case
   where LiteLLM translates Anthropic user messages containing only tool_result
   blocks into tool role messages followed by assistant.

Adds unit tests covering all tool result handling scenarios.
2025-12-24 15:10:35 +07:00
hkfires
26fbb77901 refactor(sdk/auth): rename manager.go to conductor.go 2025-12-24 15:21:03 +08:00
NguyenSiTrung
a277302262 Merge remote-tracking branch 'upstream/main' 2025-12-24 10:54:09 +07:00
NguyenSiTrung
969c1a5b72 refactor: extract parseGeminiFamilyUsageDetail helper to reduce duplication 2025-12-24 10:22:31 +07:00
NguyenSiTrung
872339bceb feat: add cached token parsing for Gemini API responses 2025-12-24 10:20:11 +07:00
Luis Pater
5dc0dbc7aa Merge pull request #697 from Cubence-com/main
docs(readme): add Cubence sponsor and fix PackyCode link
2025-12-24 11:19:32 +08:00
Luis Pater
ee6fc4e8a1 Merge pull request #60 from router-for-me/pr-59-resolve-conflicts
v6.6.50(解决 #59 冲突)
2025-12-24 11:12:09 +08:00
Luis Pater
8fee16aecd Merge PR #59: v6.6.50 (resolve conflicts) 2025-12-24 11:06:10 +08:00
Luis Pater
2b7ba54a2f Merge pull request #688 from router-for-me/feature/request-id-tracking
feat(logging): implement request ID tracking and propagation
2025-12-24 10:54:13 +08:00
hkfires
007c3304f2 feat(logging): scope request ID tracking to AI API endpoints 2025-12-24 09:17:09 +08:00
hkfires
e76ba0ede9 feat(logging): implement request ID tracking and propagation 2025-12-24 08:32:17 +08:00
Luis Pater
c06ac07e23 Merge pull request #686 from ajkdrag/main
feat: regex support for model-mappings
2025-12-24 04:37:44 +08:00
Luis Pater
e592a57458 Merge branch 'router-for-me:main' into main 2025-12-24 04:25:06 +08:00
Luis Pater
66769ec657 fix(translators): update role from tool to user in Gemini and Gemini-CLI requests 2025-12-24 04:24:07 +08:00
Luis Pater
f413feec61 refactor(handlers): streamline error and data channel handling in streaming logic
Improved consistency across OpenAI, Claude, and Gemini handlers by replacing initial `select` statement with a `for` loop for better readability and error-handling robustness.
2025-12-24 04:07:24 +08:00
Luis Pater
2e538e3486 Merge pull request #661 from jroth1111/fix/streaming-bootstrap-forwarder
fix: improve streaming bootstrap and forwarding
2025-12-24 03:51:40 +08:00
Luis Pater
9617a7b0d6 Merge pull request #621 from dacsang97/fix/antigravity-prompt-caching
Antigravity Prompt Caching Fix
2025-12-24 03:50:25 +08:00
Luis Pater
7569320770 Merge branch 'dev' into fix/antigravity-prompt-caching 2025-12-24 03:49:46 +08:00
Fetters
8d25cf0d75 fix(readme): update PackyCode sponsorship link and remove redundant tbody 2025-12-23 23:44:40 +08:00
Fetters
64e85e7019 docs(readme): add Cubence sponsor 2025-12-23 23:30:57 +08:00
Luis Pater
a862984dca Merge pull request #58 from router-for-me/plus
v6.6.48
2025-12-23 22:34:15 +08:00
Luis Pater
f0365f0465 Merge branch 'main' into plus 2025-12-23 22:34:08 +08:00
Luis Pater
6d1e20e940 fix(claude_executor): update header logic for API key handling
Refined header assignment to use `x-api-key` for Anthropic API requests, ensuring correct authorization behavior based on request attributes and URL validation.
2025-12-23 22:30:25 +08:00
altamash
0c0aae1eac Robust change detection: replaced string concat with struct-based compare in hasModelMappingsChanged; removed boolTo01.
•  Performance: pre-allocate map and regex slice capacities in UpdateMappings.
   •  Verified with amp module tests (all passing)
2025-12-23 18:52:28 +05:30
altamash
5dcf7cb846 feat: regex support for model-mappings 2025-12-23 18:41:58 +05:30
Joao
349b2ba3af refactor: improve error handling and code quality
- Handle errors in promptInput instead of ignoring them
- Improve promptSelect to provide feedback on invalid input and re-prompt
- Use sentinel errors (ErrAuthorizationPending, ErrSlowDown) instead of
  string-based error checking with strings.Contains
- Move hardcoded x-amz-user-agent header to idcAmzUserAgent constant

Addresses code review feedback from Gemini Code Assist.
2025-12-23 10:20:14 +00:00
Joao
98db5aabd0 feat: persist refreshed IDC tokens to auth file
Add persistRefreshedAuth function to write refreshed tokens back to the
auth file after inline token refresh. This prevents repeated token
refreshes on every request when the token expires.

Changes:
- Add persistRefreshedAuth() to kiro_executor.go
- Call persist after all token refresh paths (401, 403, pre-request)
- Remove unused log import from sdk/auth/kiro.go
2025-12-23 10:00:14 +00:00
Luis Pater
e52b542e22 Merge pull request #684 from packyme/main
docs(readme): add PackyCode sponsor
2025-12-23 17:19:25 +08:00
SmallL-U
8f6abb8a86 fix(readme): correct closing tbody tag 2025-12-23 17:17:57 +08:00
SmallL-U
ed8eaae964 docs(readme): add PackyCode sponsor 2025-12-23 17:11:34 +08:00
Joao
7fd98f3556 feat: add IDC auth support with Kiro IDE headers 2025-12-23 08:18:10 +00:00
Luis Pater
e8de87ee90 Merge branch 'router-for-me:main' into main 2025-12-23 08:48:29 +08:00
Luis Pater
4e572ec8b9 fix(translators): handle string system instructions in Claude translators
Updated Antigravity, Gemini, and Gemini-CLI translators to process `systemResult` of type `string` for system instructions. Ensures properly formatted JSON with dynamic content assignment.
2025-12-23 08:44:36 +08:00
Luis Pater
6c7f18c448 Merge branch 'router-for-me:main' into main 2025-12-23 03:50:35 +08:00
Luis Pater
24bc9cba67 Fixed: #639
fix(antigravity): validate function arguments before serialization

Ensure `function.arguments` is a valid JSON before setting raw bytes, fallback to setting as parameterized content if invalid.
2025-12-23 03:49:45 +08:00
Luis Pater
97356b1a04 Merge branch 'router-for-me:main' into main 2025-12-23 03:17:44 +08:00
Luis Pater
1084b53fba Fixed: #655
refactor(antigravity): clean up tool key filtering and improve signature caching logic
2025-12-23 03:16:51 +08:00
Luis Pater
b1aecc2bf1 Merge branch 'router-for-me:main' into main 2025-12-23 02:49:37 +08:00
Luis Pater
83b90e106f refactor(antigravity): add sandbox URL constant and update base URLs routine 2025-12-23 02:47:56 +08:00
Luis Pater
f52114dab2 Merge branch 'router-for-me:main' into main 2025-12-23 02:26:33 +08:00
Luis Pater
5106caf641 Fixed: #654
feat: handle array input for system instructions in translators

Enhanced Gemini, Gemini-CLI, and Antigravity translators to process array content for system instructions. Adds support for assigning roles and handling multiple content parts dynamically.
2025-12-23 02:24:26 +08:00
Luis Pater
12370ee84e Merge branch 'router-for-me:main' into main 2025-12-22 22:53:29 +08:00
Luis Pater
b84ccc6e7a feat: add unit tests for routing strategies and implement dynamic selector updates
Added comprehensive tests for `FillFirstSelector` and `RoundRobinSelector` to ensure proper behavior, including deterministic, cyclical, and concurrent scenarios. Introduced dynamic routing strategy updates in `service.go`, normalizing strategies and seamlessly switching between `fill-first` and `round-robin`. Updated `Manager` to support selector changes via the new `SetSelector` method.
2025-12-22 22:52:23 +08:00
Luis Pater
e19ddb53e7 Merge pull request #663 from jroth1111/feat/fill-first-selector
feat: add fill-first routing strategy
2025-12-22 22:26:32 +08:00
gwizz
5bf89dd757 fix: keep streaming defaults legacy-safe 2025-12-23 00:53:18 +11:00
gwizz
2a0100b2d6 docs: add routing strategy example 2025-12-23 00:39:18 +11:00
gwizz
4442574e53 fix: stop streaming loop on context cancel 2025-12-23 00:37:55 +11:00
gwizz
c020fa60d0 fix: keep round-robin as default routing 2025-12-22 23:39:41 +11:00
gwizz
b078be4613 feat: add fill-first routing strategy 2025-12-22 23:38:10 +11:00
gwizz
71a6dffbb6 fix: improve streaming bootstrap and forwarding 2025-12-22 23:34:23 +11:00
Luis Pater
5f65dd5bb4 Merge branch 'router-for-me:main' into main 2025-12-22 16:58:26 +08:00
Luis Pater
27b43ed63f Merge pull request #658 from moxi000/fix-responses-convert
Fix responses-format handling for chat completions(Support Cursor)
2025-12-22 16:47:46 +08:00
moxi
f6a3a1d0ba Remove compat test under translator per review 2025-12-22 16:44:50 +08:00
moxi
830fd8eac2 Fix responses-format handling for chat completions 2025-12-22 13:54:02 +08:00
Luis Pater
a86d501dc2 refactor: replace json.Marshal and json.Unmarshal with sjson and gjson
Optimized the handling of JSON serialization and deserialization by replacing redundant `json.Marshal` and `json.Unmarshal` calls with `sjson` and `gjson`. Introduced a `marshalJSONValue` utility for compact JSON encoding, improving performance and code simplicity. Removed unused `encoding/json` imports.
2025-12-22 11:44:06 +08:00
Evan Nguyen
24e8e20b59 Merge branch 'main' into fix/antigravity-prompt-caching 2025-12-21 19:43:24 +07:00
Evan Nguyen
a87f09bad2 feat(antigravity): add session ID generation and mutex for random source 2025-12-21 17:50:41 +07:00
evann
bc6c4cdbfc feat(antigravity): add logging for cached token setting errors in responses 2025-12-19 16:49:50 +07:00
evann
404546ce93 refactor(antigravity): regarding production endpoint caching 2025-12-19 16:36:54 +07:00
evann
6dd1cf1dd6 Merge branch 'main' into fix/antigravity-prompt-caching 2025-12-19 16:34:28 +07:00
evann
9058d406a3 feat(antigravity): enhance prompt caching support and update agent version 2025-12-19 16:33:41 +07:00
111 changed files with 7083 additions and 2621 deletions

View File

@@ -13,8 +13,6 @@ Dockerfile
docs/* docs/*
README.md README.md
README_CN.md README_CN.md
MANAGEMENT_API.md
MANAGEMENT_API_CN.md
LICENSE LICENSE
# Runtime data folders (should be mounted as volumes) # Runtime data folders (should be mounted as volumes)
@@ -25,10 +23,14 @@ config.yaml
# Development/editor # Development/editor
bin/* bin/*
.claude/*
.vscode/* .vscode/*
.claude/*
.codex/*
.gemini/* .gemini/*
.serena/* .serena/*
.agent/* .agent/*
.agents/*
.opencode/*
.bmad/* .bmad/*
_bmad/* _bmad/*
_bmad-output/*

View File

@@ -7,6 +7,13 @@ assignees: ''
--- ---
**Is it a request payload issue?**
[ ] Yes, this is a request payload issue. I am using a client/cURL to send a request payload, but I received an unexpected error.
[ ] No, it's another issue.
**If it's a request payload issue, you MUST know**
Our team doesn't have any GODs or ORACLEs or MIND READERs. Please make sure to attach the request log or curl payload.
**Describe the bug** **Describe the bug**
A clear and concise description of what the bug is. A clear and concise description of what the bug is.

11
.gitignore vendored
View File

@@ -12,11 +12,15 @@ bin/*
logs/* logs/*
conv/* conv/*
temp/* temp/*
refs/*
# Storage backends
pgstore/* pgstore/*
gitstore/* gitstore/*
objectstore/* objectstore/*
# Static assets
static/* static/*
refs/*
# Authentication data # Authentication data
auths/* auths/*
@@ -30,12 +34,17 @@ GEMINI.md
# Tooling metadata # Tooling metadata
.vscode/* .vscode/*
.codex/*
.claude/* .claude/*
.gemini/* .gemini/*
.serena/* .serena/*
.agent/* .agent/*
.agents/*
.agents/*
.opencode/*
.bmad/* .bmad/*
_bmad/* _bmad/*
_bmad-output/*
.mcp/cache/ .mcp/cache/
# macOS # macOS

BIN
assets/cubence.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

BIN
assets/packycode.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.1 KiB

View File

@@ -434,7 +434,7 @@ func main() {
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled) usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling) coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil { if err = logging.ConfigureLogOutput(cfg); err != nil {
log.Errorf("failed to configure log output: %v", err) log.Errorf("failed to configure log output: %v", err)
return return
} }

View File

@@ -35,10 +35,14 @@ auth-dir: "~/.cli-proxy-api"
api-keys: api-keys:
- "your-api-key-1" - "your-api-key-1"
- "your-api-key-2" - "your-api-key-2"
- "your-api-key-3"
# Enable debug logging # Enable debug logging
debug: false debug: false
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
commercial-mode: false
# Open OAuth URLs in incognito/private browser mode. # Open OAuth URLs in incognito/private browser mode.
# Useful when you want to login with a different account without logging out from your current session. # Useful when you want to login with a different account without logging out from your current session.
# Default: false (but Kiro auth defaults to true for multi-account support) # Default: false (but Kiro auth defaults to true for multi-account support)
@@ -71,9 +75,18 @@ quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded switch-project: true # Whether to automatically switch to another project when a quota is exceeded
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
# Routing strategy for selecting credentials when multiple match.
routing:
strategy: "round-robin" # round-robin (default), fill-first
# When true, enable authentication for the WebSocket API (/v1/ws). # When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false ws-auth: false
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
# streaming:
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
# Gemini API keys # Gemini API keys
# gemini-api-key: # gemini-api-key:
# - api-key: "AIzaSy...01" # - api-key: "AIzaSy...01"
@@ -82,6 +95,9 @@ ws-auth: false
# headers: # headers:
# X-Custom-Header: "custom-value" # X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # proxy-url: "socks5://proxy.example.com:1080"
# models:
# - name: "gemini-2.5-flash" # upstream model name
# alias: "gemini-flash" # client alias mapped to the upstream model
# excluded-models: # excluded-models:
# - "gemini-2.5-pro" # exclude specific models from this provider (exact match) # - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro) # - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
@@ -97,6 +113,9 @@ ws-auth: false
# headers: # headers:
# X-Custom-Header: "custom-value" # X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# models:
# - name: "gpt-5-codex" # upstream model name
# alias: "codex-latest" # client alias mapped to the upstream model
# excluded-models: # excluded-models:
# - "gpt-5.1" # exclude specific models (exact match) # - "gpt-5.1" # exclude specific models (exact match)
# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex) # - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
@@ -114,7 +133,7 @@ ws-auth: false
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# models: # models:
# - name: "claude-3-5-sonnet-20241022" # upstream model name # - name: "claude-3-5-sonnet-20241022" # upstream model name
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model # alias: "claude-sonnet-latest" # client alias mapped to the upstream model
# excluded-models: # excluded-models:
# - "claude-opus-4-5-20251101" # exclude specific models (exact match) # - "claude-opus-4-5-20251101" # exclude specific models (exact match)
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219) # - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
@@ -155,9 +174,9 @@ ws-auth: false
# headers: # headers:
# X-Custom-Header: "custom-value" # X-Custom-Header: "custom-value"
# models: # optional: map aliases to upstream model names # models: # optional: map aliases to upstream model names
# - name: "gemini-2.0-flash" # upstream model name # - name: "gemini-2.5-flash" # upstream model name
# alias: "vertex-flash" # client-visible alias # alias: "vertex-flash" # client-visible alias
# - name: "gemini-1.5-pro" # - name: "gemini-2.5-pro"
# alias: "vertex-pro" # alias: "vertex-pro"
# Amp Integration # Amp Integration
@@ -166,6 +185,18 @@ ws-auth: false
# upstream-url: "https://ampcode.com" # upstream-url: "https://ampcode.com"
# # Optional: Override API key for Amp upstream (otherwise uses env or file) # # Optional: Override API key for Amp upstream (otherwise uses env or file)
# upstream-api-key: "" # upstream-api-key: ""
# # Per-client upstream API key mapping
# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
# # Useful when different clients need to use different Amp accounts/quotas.
# # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
# upstream-api-keys:
# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
# api-keys: # Client keys that use this upstream key
# - "your-api-key-1"
# - "your-api-key-2"
# - upstream-api-key: "amp_key_for_team_b"
# api-keys:
# - "your-api-key-3"
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false) # # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
# restrict-management-to-localhost: false # restrict-management-to-localhost: false
# # Force model mappings to run before checking local API keys (default: false) # # Force model mappings to run before checking local API keys (default: false)
@@ -175,12 +206,42 @@ ws-auth: false
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5) # # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
# # but you have a similar model available (e.g., Claude Sonnet 4). # # but you have a similar model available (e.g., Claude Sonnet 4).
# model-mappings: # model-mappings:
# - from: "claude-opus-4.5" # Model requested by Amp CLI # - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
# to: "claude-sonnet-4" # Route to this available model instead # to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
# - from: "gpt-5" # - from: "claude-sonnet-4-5-20250929"
# to: "gemini-2.5-pro" # to: "gemini-claude-sonnet-4-5-thinking"
# - from: "claude-3-opus-20240229" # - from: "claude-haiku-4-5-20251001"
# to: "claude-3-5-sonnet-20241022" # to: "gemini-2.5-flash"
# Global OAuth model name mappings (per channel)
# These mappings rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
# NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# oauth-model-mappings:
# gemini-cli:
# - name: "gemini-2.5-pro" # original model name under this channel
# alias: "g2.5p" # client-visible alias
# vertex:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# aistudio:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# antigravity:
# - name: "gemini-3-pro-preview"
# alias: "g3p"
# claude:
# - name: "claude-sonnet-4-5-20250929"
# alias: "cs4.5"
# codex:
# - name: "gpt-5"
# alias: "g5"
# qwen:
# - name: "qwen3-coder-plus"
# alias: "qwen-plus"
# iflow:
# - name: "glm-4.7"
# alias: "glm-god"
# OAuth provider excluded models # OAuth provider excluded models
# oauth-excluded-models: # oauth-excluded-models:

View File

@@ -0,0 +1,538 @@
package management
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const defaultAPICallTimeout = 60 * time.Second
const (
geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOAuthClientSecret = "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"
)
var geminiOAuthScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
type apiCallRequest struct {
AuthIndexSnake *string `json:"auth_index"`
AuthIndexCamel *string `json:"authIndex"`
AuthIndexPascal *string `json:"AuthIndex"`
Method string `json:"method"`
URL string `json:"url"`
Header map[string]string `json:"header"`
Data string `json:"data"`
}
type apiCallResponse struct {
StatusCode int `json:"status_code"`
Header map[string][]string `json:"header"`
Body string `json:"body"`
}
// APICall makes a generic HTTP request on behalf of the management API caller.
// It is protected by the management middleware.
//
// Endpoint:
//
// POST /v0/management/api-call
//
// Authentication:
//
// Same as other management APIs (requires a management key and remote-management rules).
// You can provide the key via:
// - Authorization: Bearer <key>
// - X-Management-Key: <key>
//
// Request JSON:
// - auth_index / authIndex / AuthIndex (optional):
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
// If omitted or not found, credential-specific proxy/token substitution is skipped.
// - method (required): HTTP method, e.g. GET, POST, PUT, PATCH, DELETE.
// - url (required): Absolute URL including scheme and host, e.g. "https://api.example.com/v1/ping".
// - header (optional): Request headers map.
// Supports magic variable "$TOKEN$" which is replaced using the selected credential:
// 1) metadata.access_token
// 2) attributes.api_key
// 3) metadata.token / metadata.id_token / metadata.cookie
// Example: {"Authorization":"Bearer $TOKEN$"}.
// Note: if you need to override the HTTP Host header, set header["Host"].
// - data (optional): Raw request body as string (useful for POST/PUT/PATCH).
//
// Proxy selection (highest priority first):
// 1. Selected credential proxy_url
// 2. Global config proxy-url
// 3. Direct connect (environment proxies are not used)
//
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
// - status_code: Upstream HTTP status code.
// - header: Upstream response headers.
// - body: Upstream response body as string.
//
// Example:
//
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
// -H "Authorization: Bearer <MANAGEMENT_KEY>" \
// -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"GET","url":"https://api.example.com/v1/ping","header":{"Authorization":"Bearer $TOKEN$"}}'
//
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
// -H "Authorization: Bearer 831227" \
// -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
func (h *Handler) APICall(c *gin.Context) {
var body apiCallRequest
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
method := strings.ToUpper(strings.TrimSpace(body.Method))
if method == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing method"})
return
}
urlStr := strings.TrimSpace(body.URL)
if urlStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"})
return
}
parsedURL, errParseURL := url.Parse(urlStr)
if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
return
}
authIndex := firstNonEmptyString(body.AuthIndexSnake, body.AuthIndexCamel, body.AuthIndexPascal)
auth := h.authByIndex(authIndex)
reqHeaders := body.Header
if reqHeaders == nil {
reqHeaders = map[string]string{}
}
var hostOverride string
var token string
var tokenResolved bool
var tokenErr error
for key, value := range reqHeaders {
if !strings.Contains(value, "$TOKEN$") {
continue
}
if !tokenResolved {
token, tokenErr = h.resolveTokenForAuth(c.Request.Context(), auth)
tokenResolved = true
}
if auth != nil && token == "" {
if tokenErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token refresh failed"})
return
}
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token not found"})
return
}
if token == "" {
continue
}
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
}
var requestBody io.Reader
if body.Data != "" {
requestBody = strings.NewReader(body.Data)
}
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
if errNewRequest != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"})
return
}
for key, value := range reqHeaders {
if strings.EqualFold(key, "host") {
hostOverride = strings.TrimSpace(value)
continue
}
req.Header.Set(key, value)
}
if hostOverride != "" {
req.Host = hostOverride
}
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
}
httpClient.Transport = h.apiCallTransport(auth)
resp, errDo := httpClient.Do(req)
if errDo != nil {
log.WithError(errDo).Debug("management APICall request failed")
c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
}()
respBody, errReadAll := io.ReadAll(resp.Body)
if errReadAll != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
return
}
c.JSON(http.StatusOK, apiCallResponse{
StatusCode: resp.StatusCode,
Header: resp.Header,
Body: string(respBody),
})
}
func firstNonEmptyString(values ...*string) string {
for _, v := range values {
if v == nil {
continue
}
if out := strings.TrimSpace(*v); out != "" {
return out
}
}
return ""
}
func tokenValueForAuth(auth *coreauth.Auth) string {
if auth == nil {
return ""
}
if v := tokenValueFromMetadata(auth.Metadata); v != "" {
return v
}
if auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
return v
}
}
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
if v := tokenValueFromMetadata(shared.MetadataSnapshot()); v != "" {
return v
}
}
return ""
}
func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) (string, error) {
if auth == nil {
return "", nil
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if provider == "gemini-cli" {
token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
return token, errToken
}
return tokenValueForAuth(auth), nil
}
func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
if ctx == nil {
ctx = context.Background()
}
if auth == nil {
return "", nil
}
metadata, updater := geminiOAuthMetadata(auth)
if len(metadata) == 0 {
return "", fmt.Errorf("gemini oauth metadata missing")
}
base := make(map[string]any)
if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
base = cloneMap(tokenRaw)
}
var token oauth2.Token
if len(base) > 0 {
if raw, errMarshal := json.Marshal(base); errMarshal == nil {
_ = json.Unmarshal(raw, &token)
}
}
if token.AccessToken == "" {
token.AccessToken = stringValue(metadata, "access_token")
}
if token.RefreshToken == "" {
token.RefreshToken = stringValue(metadata, "refresh_token")
}
if token.TokenType == "" {
token.TokenType = stringValue(metadata, "token_type")
}
if token.Expiry.IsZero() {
if expiry := stringValue(metadata, "expiry"); expiry != "" {
if ts, errParseTime := time.Parse(time.RFC3339, expiry); errParseTime == nil {
token.Expiry = ts
}
}
}
conf := &oauth2.Config{
ClientID: geminiOAuthClientID,
ClientSecret: geminiOAuthClientSecret,
Scopes: geminiOAuthScopes,
Endpoint: google.Endpoint,
}
ctxToken := ctx
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
Transport: h.apiCallTransport(auth),
}
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
src := conf.TokenSource(ctxToken, &token)
currentToken, errToken := src.Token()
if errToken != nil {
return "", errToken
}
merged := buildOAuthTokenMap(base, currentToken)
fields := buildOAuthTokenFields(currentToken, merged)
if updater != nil {
updater(fields)
}
return strings.TrimSpace(currentToken.AccessToken), nil
}
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
if auth == nil {
return nil, nil
}
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
snapshot := shared.MetadataSnapshot()
return snapshot, func(fields map[string]any) { shared.MergeMetadata(fields) }
}
return auth.Metadata, func(fields map[string]any) {
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
for k, v := range fields {
auth.Metadata[k] = v
}
}
}
func stringValue(metadata map[string]any, key string) string {
if len(metadata) == 0 || key == "" {
return ""
}
if v, ok := metadata[key].(string); ok {
return strings.TrimSpace(v)
}
return ""
}
func cloneMap(in map[string]any) map[string]any {
if len(in) == 0 {
return nil
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func buildOAuthTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
merged := cloneMap(base)
if merged == nil {
merged = make(map[string]any)
}
if tok == nil {
return merged
}
if raw, errMarshal := json.Marshal(tok); errMarshal == nil {
var tokenMap map[string]any
if errUnmarshal := json.Unmarshal(raw, &tokenMap); errUnmarshal == nil {
for k, v := range tokenMap {
merged[k] = v
}
}
}
return merged
}
func buildOAuthTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
fields := make(map[string]any, 5)
if tok != nil && tok.AccessToken != "" {
fields["access_token"] = tok.AccessToken
}
if tok != nil && tok.TokenType != "" {
fields["token_type"] = tok.TokenType
}
if tok != nil && tok.RefreshToken != "" {
fields["refresh_token"] = tok.RefreshToken
}
if tok != nil && !tok.Expiry.IsZero() {
fields["expiry"] = tok.Expiry.Format(time.RFC3339)
}
if len(merged) > 0 {
fields["token"] = cloneMap(merged)
}
return fields
}
func tokenValueFromMetadata(metadata map[string]any) string {
if len(metadata) == 0 {
return ""
}
if v, ok := metadata["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if tokenRaw, ok := metadata["token"]; ok && tokenRaw != nil {
switch typed := tokenRaw.(type) {
case string:
if v := strings.TrimSpace(typed); v != "" {
return v
}
case map[string]any:
if v, ok := typed["access_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := typed["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
case map[string]string:
if v := strings.TrimSpace(typed["access_token"]); v != "" {
return v
}
if v := strings.TrimSpace(typed["accessToken"]); v != "" {
return v
}
}
}
if v, ok := metadata["token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["id_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["cookie"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
return ""
}
func (h *Handler) authByIndex(authIndex string) *coreauth.Auth {
authIndex = strings.TrimSpace(authIndex)
if authIndex == "" || h == nil || h.authManager == nil {
return nil
}
auths := h.authManager.List()
for _, auth := range auths {
if auth == nil {
continue
}
auth.EnsureIndex()
if auth.Index == authIndex {
return auth
}
}
return nil
}
func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
var proxyCandidates []string
if auth != nil {
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
if h != nil && h.cfg != nil {
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
for _, proxyStr := range proxyCandidates {
if transport := buildProxyTransport(proxyStr); transport != nil {
return transport
}
}
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok || transport == nil {
return &http.Transport{Proxy: nil}
}
clone := transport.Clone()
clone.Proxy = nil
return clone
}
func buildProxyTransport(proxyStr string) *http.Transport {
proxyStr = strings.TrimSpace(proxyStr)
if proxyStr == "" {
return nil
}
proxyURL, errParse := url.Parse(proxyStr)
if errParse != nil {
log.WithError(errParse).Debug("parse proxy URL failed")
return nil
}
if proxyURL.Scheme == "" || proxyURL.Host == "" {
log.Debug("proxy URL missing scheme/host")
return nil
}
if proxyURL.Scheme == "socks5" {
var proxyAuth *proxy.Auth
if proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
return nil
}
return &http.Transport{
Proxy: nil,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
}
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil
}

View File

@@ -431,9 +431,46 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
log.WithError(err).Warnf("failed to stat auth file %s", path) log.WithError(err).Warnf("failed to stat auth file %s", path)
} }
} }
if claims := extractCodexIDTokenClaims(auth); claims != nil {
entry["id_token"] = claims
}
return entry return entry
} }
func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
if auth == nil || auth.Metadata == nil {
return nil
}
if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
return nil
}
idTokenRaw, ok := auth.Metadata["id_token"].(string)
if !ok {
return nil
}
idToken := strings.TrimSpace(idTokenRaw)
if idToken == "" {
return nil
}
claims, err := codex.ParseJWTToken(idToken)
if err != nil || claims == nil {
return nil
}
result := gin.H{}
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID); v != "" {
result["chatgpt_account_id"] = v
}
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" {
result["plan_type"] = v
}
if len(result) == 0 {
return nil
}
return result
}
func authEmail(auth *coreauth.Auth) string { func authEmail(auth *coreauth.Auth) string {
if auth == nil { if auth == nil {
return "" return ""

View File

@@ -597,11 +597,7 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
filtered := make([]config.CodexKey, 0, len(arr)) filtered := make([]config.CodexKey, 0, len(arr))
for i := range arr { for i := range arr {
entry := arr[i] entry := arr[i]
entry.APIKey = strings.TrimSpace(entry.APIKey) normalizeCodexKey(&entry)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
if entry.BaseURL == "" { if entry.BaseURL == "" {
continue continue
} }
@@ -613,12 +609,13 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
} }
func (h *Handler) PatchCodexKey(c *gin.Context) { func (h *Handler) PatchCodexKey(c *gin.Context) {
type codexKeyPatch struct { type codexKeyPatch struct {
APIKey *string `json:"api-key"` APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"` Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"` BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"` ProxyURL *string `json:"proxy-url"`
Headers *map[string]string `json:"headers"` Models *[]config.CodexModel `json:"models"`
ExcludedModels *[]string `json:"excluded-models"` Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
} }
var body struct { var body struct {
Index *int `json:"index"` Index *int `json:"index"`
@@ -667,12 +664,16 @@ func (h *Handler) PatchCodexKey(c *gin.Context) {
if body.Value.ProxyURL != nil { if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
} }
if body.Value.Models != nil {
entry.Models = append([]config.CodexModel(nil), (*body.Value.Models)...)
}
if body.Value.Headers != nil { if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers) entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
} }
if body.Value.ExcludedModels != nil { if body.Value.ExcludedModels != nil {
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
} }
normalizeCodexKey(&entry)
h.cfg.CodexKey[targetIndex] = entry h.cfg.CodexKey[targetIndex] = entry
h.cfg.SanitizeCodexKeys() h.cfg.SanitizeCodexKeys()
h.persist(c) h.persist(c)
@@ -762,6 +763,32 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
entry.Models = normalized entry.Models = normalized
} }
func normalizeCodexKey(entry *config.CodexKey) {
if entry == nil {
return
}
entry.APIKey = strings.TrimSpace(entry.APIKey)
entry.Prefix = strings.TrimSpace(entry.Prefix)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
if len(entry.Models) == 0 {
return
}
normalized := make([]config.CodexModel, 0, len(entry.Models))
for i := range entry.Models {
model := entry.Models[i]
model.Name = strings.TrimSpace(model.Name)
model.Alias = strings.TrimSpace(model.Alias)
if model.Name == "" && model.Alias == "" {
continue
}
normalized = append(normalized, model)
}
entry.Models = normalized
}
// GetAmpCode returns the complete ampcode configuration. // GetAmpCode returns the complete ampcode configuration.
func (h *Handler) GetAmpCode(c *gin.Context) { func (h *Handler) GetAmpCode(c *gin.Context) {
if h == nil || h.cfg == nil { if h == nil || h.cfg == nil {
@@ -913,3 +940,151 @@ func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) { func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v }) h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
} }
// GetAmpUpstreamAPIKeys returns the ampcode upstream API keys mapping.
func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}})
return
}
c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys})
}
// PutAmpUpstreamAPIKeys replaces all ampcode upstream API keys mappings.
func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
// Normalize entries: trim whitespace, filter empty
normalized := normalizeAmpUpstreamAPIKeyEntries(body.Value)
h.cfg.AmpCode.UpstreamAPIKeys = normalized
h.persist(c)
}
// PatchAmpUpstreamAPIKeys adds or updates upstream API keys entries.
// Matching is done by upstream-api-key value.
func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
existing := make(map[string]int)
for i, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i
}
for _, newEntry := range body.Value {
upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
normalizedEntry := config.AmpUpstreamAPIKeyEntry{
UpstreamAPIKey: upstreamKey,
APIKeys: normalizeAPIKeysList(newEntry.APIKeys),
}
if idx, ok := existing[upstreamKey]; ok {
h.cfg.AmpCode.UpstreamAPIKeys[idx] = normalizedEntry
} else {
h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, normalizedEntry)
existing[upstreamKey] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1
}
}
h.persist(c)
}
// DeleteAmpUpstreamAPIKeys removes specified upstream API keys entries.
// Body must be JSON: {"value": ["<upstream-api-key>", ...]}.
// If "value" is an empty array, clears all entries.
// If JSON is invalid or "value" is missing/null, returns 400 and does not persist any change.
func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []string `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
if body.Value == nil {
c.JSON(400, gin.H{"error": "missing value"})
return
}
// Empty array means clear all
if len(body.Value) == 0 {
h.cfg.AmpCode.UpstreamAPIKeys = nil
h.persist(c)
return
}
toRemove := make(map[string]bool)
for _, key := range body.Value {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
continue
}
toRemove[trimmed] = true
}
if len(toRemove) == 0 {
c.JSON(400, gin.H{"error": "empty value"})
return
}
newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys))
for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] {
newEntries = append(newEntries, entry)
}
}
h.cfg.AmpCode.UpstreamAPIKeys = newEntries
h.persist(c)
}
// normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries.
func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry {
if len(entries) == 0 {
return nil
}
out := make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries))
for _, entry := range entries {
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
apiKeys := normalizeAPIKeysList(entry.APIKeys)
out = append(out, config.AmpUpstreamAPIKeyEntry{
UpstreamAPIKey: upstreamKey,
APIKeys: apiKeys,
})
}
if len(out) == 0 {
return nil
}
return out
}
// normalizeAPIKeysList trims and filters empty strings from a list of API keys.
func normalizeAPIKeysList(keys []string) []string {
if len(keys) == 0 {
return nil
}
out := make([]string, 0, len(keys))
for _, k := range keys {
trimmed := strings.TrimSpace(k)
if trimmed != "" {
out = append(out, trimmed)
}
}
if len(out) == 0 {
return nil
}
return out
}

View File

@@ -59,6 +59,11 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
} }
} }
// NewHandler creates a new management handler instance.
func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manager) *Handler {
return NewHandler(cfg, "", manager)
}
// SetConfig updates the in-memory config reference when the server hot-reloads. // SetConfig updates the in-memory config reference when the server hot-reloads.
func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg } func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg }

View File

@@ -209,6 +209,94 @@ func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"files": files}) c.JSON(http.StatusOK, gin.H{"files": files})
} }
// GetRequestLogByID finds and downloads a request log file by its request ID.
// The ID is matched against the suffix of log file names (format: *-{requestID}.log).
func (h *Handler) GetRequestLogByID(c *gin.Context) {
if h == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
return
}
if h.cfg == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
return
}
dir := h.logDirectory()
if strings.TrimSpace(dir) == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
return
}
requestID := strings.TrimSpace(c.Param("id"))
if requestID == "" {
requestID = strings.TrimSpace(c.Query("id"))
}
if requestID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing request ID"})
return
}
if strings.ContainsAny(requestID, "/\\") {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request ID"})
return
}
entries, err := os.ReadDir(dir)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
return
}
suffix := "-" + requestID + ".log"
var matchedFile string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if strings.HasSuffix(name, suffix) {
matchedFile = name
break
}
}
if matchedFile == "" {
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found for the given request ID"})
return
}
dirAbs, errAbs := filepath.Abs(dir)
if errAbs != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
return
}
fullPath := filepath.Clean(filepath.Join(dirAbs, matchedFile))
prefix := dirAbs + string(os.PathSeparator)
if !strings.HasPrefix(fullPath, prefix) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
return
}
info, errStat := os.Stat(fullPath)
if errStat != nil {
if os.IsNotExist(errStat) {
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
return
}
if info.IsDir() {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
return
}
c.FileAttachment(fullPath, matchedFile)
}
// DownloadRequestErrorLog downloads a specific error request log file by name. // DownloadRequestErrorLog downloads a specific error request log file by name.
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) { func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
if h == nil { if h == nil {

View File

@@ -1,12 +1,25 @@
package management package management
import ( import (
"encoding/json"
"net/http" "net/http"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage" "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
) )
type usageExportPayload struct {
Version int `json:"version"`
ExportedAt time.Time `json:"exported_at"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
type usageImportPayload struct {
Version int `json:"version"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
// GetUsageStatistics returns the in-memory request statistics snapshot. // GetUsageStatistics returns the in-memory request statistics snapshot.
func (h *Handler) GetUsageStatistics(c *gin.Context) { func (h *Handler) GetUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot var snapshot usage.StatisticsSnapshot
@@ -18,3 +31,49 @@ func (h *Handler) GetUsageStatistics(c *gin.Context) {
"failed_requests": snapshot.FailureCount, "failed_requests": snapshot.FailureCount,
}) })
} }
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
if h != nil && h.usageStats != nil {
snapshot = h.usageStats.Snapshot()
}
c.JSON(http.StatusOK, usageExportPayload{
Version: 1,
ExportedAt: time.Now().UTC(),
Usage: snapshot,
})
}
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
if h == nil || h.usageStats == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
return
}
data, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
return
}
var payload usageImportPayload
if err := json.Unmarshal(data, &payload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
return
}
if payload.Version != 0 && payload.Version != 1 {
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
return
}
result := h.usageStats.MergeSnapshot(payload.Usage)
snapshot := h.usageStats.Snapshot()
c.JSON(http.StatusOK, gin.H{
"added": result.Added,
"skipped": result.Skipped,
"total_requests": snapshot.TotalRequests,
"failed_requests": snapshot.FailureCount,
})
}

View File

@@ -98,10 +98,11 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
} }
return &RequestInfo{ return &RequestInfo{
URL: url, URL: url,
Method: method, Method: method,
Headers: headers, Headers: headers,
Body: body, Body: body,
RequestID: logging.GetGinRequestID(c),
}, nil }, nil
} }

View File

@@ -15,10 +15,11 @@ import (
// RequestInfo holds essential details of an incoming HTTP request for logging purposes. // RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct { type RequestInfo struct {
URL string // URL is the request URL. URL string // URL is the request URL.
Method string // Method is the HTTP method (e.g., GET, POST). Method string // Method is the HTTP method (e.g., GET, POST).
Headers map[string][]string // Headers contains the request headers. Headers map[string][]string // Headers contains the request headers.
Body []byte // Body is the raw request body. Body []byte // Body is the raw request body.
RequestID string // RequestID is the unique identifier for the request.
} }
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data. // ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
@@ -149,6 +150,7 @@ func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
w.requestInfo.Method, w.requestInfo.Method,
w.requestInfo.Headers, w.requestInfo.Headers,
w.requestInfo.Body, w.requestInfo.Body,
w.requestInfo.RequestID,
) )
if err == nil { if err == nil {
w.streamWriter = streamWriter w.streamWriter = streamWriter
@@ -346,7 +348,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
} }
if loggerWithOptions, ok := w.logger.(interface { if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool) error LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string) error
}); ok { }); ok {
return loggerWithOptions.LogRequestWithOptions( return loggerWithOptions.LogRequestWithOptions(
w.requestInfo.URL, w.requestInfo.URL,
@@ -360,6 +362,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
apiResponseBody, apiResponseBody,
apiResponseErrors, apiResponseErrors,
forceLog, forceLog,
w.requestInfo.RequestID,
) )
} }
@@ -374,5 +377,6 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
apiRequestBody, apiRequestBody,
apiResponseBody, apiResponseBody,
apiResponseErrors, apiResponseErrors,
w.requestInfo.RequestID,
) )
} }

View File

@@ -227,11 +227,20 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
} }
} }
// Check API key change // Check API key change (both default and per-client mappings)
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings) apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
if apiKeyChanged { upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
if apiKeyChanged || upstreamAPIKeysChanged {
if m.secretSource != nil { if m.secretSource != nil {
if ms, ok := m.secretSource.(*MultiSourceSecret); ok { if ms, ok := m.secretSource.(*MappedSecretSource); ok {
if apiKeyChanged {
ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
ms.InvalidateCache()
}
if upstreamAPIKeysChanged {
ms.UpdateMappings(newSettings.UpstreamAPIKeys)
}
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey) ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
ms.InvalidateCache() ms.InvalidateCache()
} }
@@ -251,10 +260,22 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error { func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
if m.secretSource == nil { if m.secretSource == nil {
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */) // Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
mappedSource := NewMappedSecretSource(defaultSource)
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
m.secretSource = mappedSource
} else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
ms.InvalidateCache()
ms.UpdateMappings(settings.UpstreamAPIKeys)
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok { } else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
// Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
ms.UpdateExplicitKey(settings.UpstreamAPIKey) ms.UpdateExplicitKey(settings.UpstreamAPIKey)
ms.InvalidateCache() ms.InvalidateCache()
mappedSource := NewMappedSecretSource(ms)
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
m.secretSource = mappedSource
} }
proxy, err := createReverseProxy(upstreamURL, m.secretSource) proxy, err := createReverseProxy(upstreamURL, m.secretSource)
@@ -279,16 +300,23 @@ func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.Amp
return true return true
} }
// Build map for efficient comparison // Build map for efficient and robust comparison
oldMap := make(map[string]string, len(old.ModelMappings)) type mappingInfo struct {
to string
regex bool
}
oldMap := make(map[string]mappingInfo, len(old.ModelMappings))
for _, mapping := range old.ModelMappings { for _, mapping := range old.ModelMappings {
oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To) oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{
to: strings.TrimSpace(mapping.To),
regex: mapping.Regex,
}
} }
for _, mapping := range new.ModelMappings { for _, mapping := range new.ModelMappings {
from := strings.TrimSpace(mapping.From) from := strings.TrimSpace(mapping.From)
to := strings.TrimSpace(mapping.To) to := strings.TrimSpace(mapping.To)
if oldTo, exists := oldMap[from]; !exists || oldTo != to { if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex {
return true return true
} }
} }
@@ -306,6 +334,66 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b
return oldKey != newKey return oldKey != newKey
} }
// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings.
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
if old == nil {
return len(new.UpstreamAPIKeys) > 0
}
if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
return true
}
// Build map for comparison: upstreamKey -> set of clientKeys
type entryInfo struct {
upstreamKey string
clientKeys map[string]struct{}
}
oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys))
for i, entry := range old.UpstreamAPIKeys {
clientKeys := make(map[string]struct{}, len(entry.APIKeys))
for _, k := range entry.APIKeys {
trimmed := strings.TrimSpace(k)
if trimmed == "" {
continue
}
clientKeys[trimmed] = struct{}{}
}
oldEntries[i] = entryInfo{
upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey),
clientKeys: clientKeys,
}
}
for i, newEntry := range new.UpstreamAPIKeys {
if i >= len(oldEntries) {
return true
}
oldE := oldEntries[i]
if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey {
return true
}
newKeys := make(map[string]struct{}, len(newEntry.APIKeys))
for _, k := range newEntry.APIKeys {
trimmed := strings.TrimSpace(k)
if trimmed == "" {
continue
}
newKeys[trimmed] = struct{}{}
}
if len(newKeys) != len(oldE.clientKeys) {
return true
}
for k := range newKeys {
if _, ok := oldE.clientKeys[k]; !ok {
return true
}
}
}
return false
}
// GetModelMapper returns the model mapper instance (for testing/debugging). // GetModelMapper returns the model mapper instance (for testing/debugging).
func (m *AmpModule) GetModelMapper() *DefaultModelMapper { func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
return m.modelMapper return m.modelMapper

View File

@@ -312,3 +312,41 @@ func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
}) })
} }
} }
func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
m := &AmpModule{}
oldCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
},
}
newCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
},
}
if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
}
}
func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
m := &AmpModule{}
oldCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
},
}
newCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
},
}
if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
t.Fatal("expected no change when only whitespace/empty entries differ")
}
}

View File

@@ -3,6 +3,7 @@
package amp package amp
import ( import (
"regexp"
"strings" "strings"
"sync" "sync"
@@ -26,13 +27,15 @@ type ModelMapper interface {
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage. // DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
type DefaultModelMapper struct { type DefaultModelMapper struct {
mu sync.RWMutex mu sync.RWMutex
mappings map[string]string // from -> to (normalized lowercase keys) mappings map[string]string // exact: from -> to (normalized lowercase keys)
regexps []regexMapping // regex rules evaluated in order
} }
// NewModelMapper creates a new model mapper with the given initial mappings. // NewModelMapper creates a new model mapper with the given initial mappings.
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper { func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
m := &DefaultModelMapper{ m := &DefaultModelMapper{
mappings: make(map[string]string), mappings: make(map[string]string),
regexps: nil,
} }
m.UpdateMappings(mappings) m.UpdateMappings(mappings)
return m return m
@@ -55,7 +58,18 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
// Check for direct mapping // Check for direct mapping
targetModel, exists := m.mappings[normalizedRequest] targetModel, exists := m.mappings[normalizedRequest]
if !exists { if !exists {
return "" // Try regex mappings in order
base, _ := util.NormalizeThinkingModel(requestedModel)
for _, rm := range m.regexps {
if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) {
targetModel = rm.to
exists = true
break
}
}
if !exists {
return ""
}
} }
// Verify target model has available providers // Verify target model has available providers
@@ -78,6 +92,7 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
// Clear and rebuild mappings // Clear and rebuild mappings
m.mappings = make(map[string]string, len(mappings)) m.mappings = make(map[string]string, len(mappings))
m.regexps = make([]regexMapping, 0, len(mappings))
for _, mapping := range mappings { for _, mapping := range mappings {
from := strings.TrimSpace(mapping.From) from := strings.TrimSpace(mapping.From)
@@ -88,16 +103,30 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
continue continue
} }
// Store with normalized lowercase key for case-insensitive lookup if mapping.Regex {
normalizedFrom := strings.ToLower(from) // Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups
m.mappings[normalizedFrom] = to pattern := "(?i)" + from
re, err := regexp.Compile(pattern)
log.Debugf("amp model mapping registered: %s -> %s", from, to) if err != nil {
log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
continue
}
m.regexps = append(m.regexps, regexMapping{re: re, to: to})
log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
} else {
// Store with normalized lowercase key for case-insensitive lookup
normalizedFrom := strings.ToLower(from)
m.mappings[normalizedFrom] = to
log.Debugf("amp model mapping registered: %s -> %s", from, to)
}
} }
if len(m.mappings) > 0 { if len(m.mappings) > 0 {
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings)) log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
} }
if n := len(m.regexps); n > 0 {
log.Infof("amp model mapping: loaded %d regex mapping(s)", n)
}
} }
// GetMappings returns a copy of current mappings (for debugging/status). // GetMappings returns a copy of current mappings (for debugging/status).
@@ -111,3 +140,8 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
} }
return result return result
} }
type regexMapping struct {
re *regexp.Regexp
to string
}

View File

@@ -203,3 +203,81 @@ func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
t.Error("Original map was modified") t.Error("Original map was modified")
} }
} }
func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
defer reg.UnregisterClient("test-client-regex-1")
mappings := []config.AmpModelMapping{
{From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true},
}
mapper := NewModelMapper(mappings)
// Incoming model has reasoning suffix but should match base via regex
result := mapper.MapModel("gpt-5(high)")
if result != "gemini-2.5-pro" {
t.Errorf("Expected gemini-2.5-pro, got %s", result)
}
}
func TestModelMapper_Regex_ExactPrecedence(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
defer reg.UnregisterClient("test-client-regex-2")
defer reg.UnregisterClient("test-client-regex-3")
mappings := []config.AmpModelMapping{
{From: "gpt-5", To: "claude-sonnet-4"}, // exact
{From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex
}
mapper := NewModelMapper(mappings)
// Exact match should win over regex
result := mapper.MapModel("gpt-5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}
func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) {
// Invalid regex should be skipped and not cause panic
mappings := []config.AmpModelMapping{
{From: "(", To: "target", Regex: true},
}
mapper := NewModelMapper(mappings)
result := mapper.MapModel("anything")
if result != "" {
t.Errorf("Expected empty result due to invalid regex, got %s", result)
}
}
func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
defer reg.UnregisterClient("test-client-regex-4")
mappings := []config.AmpModelMapping{
{From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true},
}
mapper := NewModelMapper(mappings)
result := mapper.MapModel("claude-opus-4.5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}

View File

@@ -18,6 +18,33 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
func removeQueryValuesMatching(req *http.Request, key string, match string) {
if req == nil || req.URL == nil || match == "" {
return
}
q := req.URL.Query()
values, ok := q[key]
if !ok || len(values) == 0 {
return
}
kept := make([]string, 0, len(values))
for _, v := range values {
if v == match {
continue
}
kept = append(kept, v)
}
if len(kept) == 0 {
q.Del(key)
} else {
q[key] = kept
}
req.URL.RawQuery = q.Encode()
}
// readCloser wraps a reader and forwards Close to a separate closer. // readCloser wraps a reader and forwards Close to a separate closer.
// Used to restore peeked bytes while preserving upstream body Close behavior. // Used to restore peeked bytes while preserving upstream body Close behavior.
type readCloser struct { type readCloser struct {
@@ -48,6 +75,14 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
// We will set our own Authorization using the configured upstream-api-key // We will set our own Authorization using the configured upstream-api-key
req.Header.Del("Authorization") req.Header.Del("Authorization")
req.Header.Del("X-Api-Key") req.Header.Del("X-Api-Key")
req.Header.Del("X-Goog-Api-Key")
// Remove query-based credentials if they match the authenticated client API key.
// This prevents leaking client auth material to the Amp upstream while avoiding
// breaking unrelated upstream query parameters.
clientKey := getClientAPIKeyFromContext(req.Context())
removeQueryValuesMatching(req, "key", clientKey)
removeQueryValuesMatching(req, "auth_token", clientKey)
// Preserve correlation headers for debugging // Preserve correlation headers for debugging
if req.Header.Get("X-Request-ID") == "" { if req.Header.Get("X-Request-ID") == "" {

View File

@@ -3,11 +3,15 @@ package amp
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
) )
// Helper: compress data with gzip // Helper: compress data with gzip
@@ -306,6 +310,159 @@ func TestReverseProxy_EmptySecret(t *testing.T) {
} }
} }
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
type captured struct {
headers http.Header
query string
}
got := make(chan captured, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate clientAPIKeyMiddleware injection (per-request)
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Authorization", "Bearer client-key")
req.Header.Set("X-Api-Key", "client-key")
req.Header.Set("X-Goog-Api-Key", "client-key")
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
c := <-got
// These are client-provided credentials and must not reach the upstream.
if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
}
// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
t.Fatalf("Authorization should be upstream-injected, got: %q", v)
}
if v := c.headers.Get("X-Api-Key"); v != "upstream" {
t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
}
// Query-based credentials should be stripped only when they match the authenticated client key.
// Should keep unrelated values and parameters.
if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
}
if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
}
}
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate clientAPIKeyMiddleware injection (per-request)
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "u1" {
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer u1" {
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "default" {
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer default" {
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_ErrorHandler(t *testing.T) { func TestReverseProxy_ErrorHandler(t *testing.T) {
// Point proxy to a non-routable address to trigger error // Point proxy to a non-routable address to trigger error
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource("")) proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))

View File

@@ -1,6 +1,7 @@
package amp package amp
import ( import (
"context"
"errors" "errors"
"net" "net"
"net/http" "net/http"
@@ -16,6 +17,37 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// clientAPIKeyContextKey is the context key used to pass the client API key
// from gin.Context to the request context for SecretSource lookup.
type clientAPIKeyContextKey struct{}
// clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["apiKey"]
// into the request context so that SecretSource can look it up for per-client upstream routing.
func clientAPIKeyMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Extract the client API key from gin context (set by AuthMiddleware)
if apiKey, exists := c.Get("apiKey"); exists {
if keyStr, ok := apiKey.(string); ok && keyStr != "" {
// Inject into request context for SecretSource.Get(ctx) to read
ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr)
c.Request = c.Request.WithContext(ctx)
}
}
c.Next()
}
}
// getClientAPIKeyFromContext retrieves the client API key from request context.
// Returns empty string if not present.
func getClientAPIKeyFromContext(ctx context.Context) string {
if val := ctx.Value(clientAPIKeyContextKey{}); val != nil {
if keyStr, ok := val.(string); ok {
return keyStr
}
}
return ""
}
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's // localhostOnlyMiddleware returns a middleware that dynamically checks the module's
// localhost restriction setting. This allows hot-reload of the restriction without restarting. // localhost restriction setting. This allows hot-reload of the restriction without restarting.
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc { func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
@@ -129,6 +161,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings") authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
} }
// Inject client API key into request context for per-client upstream routing
ampAPI.Use(clientAPIKeyMiddleware())
// Dynamic proxy handler that uses m.getProxy() for hot-reload support // Dynamic proxy handler that uses m.getProxy() for hot-reload support
proxyHandler := func(c *gin.Context) { proxyHandler := func(c *gin.Context) {
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces // Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
@@ -175,6 +210,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
if authWithBypass != nil { if authWithBypass != nil {
rootMiddleware = append(rootMiddleware, authWithBypass) rootMiddleware = append(rootMiddleware, authWithBypass)
} }
// Add clientAPIKeyMiddleware after auth for per-client upstream routing
rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
engine.GET("/threads", append(rootMiddleware, proxyHandler)...) engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...) engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
engine.GET("/docs", append(rootMiddleware, proxyHandler)...) engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
@@ -244,6 +281,8 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
if auth != nil { if auth != nil {
ampProviders.Use(auth) ampProviders.Use(auth)
} }
// Inject client API key into request context for per-client upstream routing
ampProviders.Use(clientAPIKeyMiddleware())
provider := ampProviders.Group("/:provider") provider := ampProviders.Group("/:provider")

View File

@@ -9,6 +9,9 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
) )
// SecretSource provides Amp API keys with configurable precedence and caching // SecretSource provides Amp API keys with configurable precedence and caching
@@ -164,3 +167,82 @@ func NewStaticSecretSource(key string) *StaticSecretSource {
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) { func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
return s.key, nil return s.key, nil
} }
// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping.
// When a request context contains a client API key that matches a configured mapping,
// the corresponding upstream key is returned. Otherwise, falls back to the default source.
type MappedSecretSource struct {
defaultSource SecretSource
mu sync.RWMutex
lookup map[string]string // clientKey -> upstreamKey
}
// NewMappedSecretSource creates a MappedSecretSource wrapping the given default source.
func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
return &MappedSecretSource{
defaultSource: defaultSource,
lookup: make(map[string]string),
}
}
// Get retrieves the Amp API key, checking per-client mappings first.
// If the request context contains a client API key that matches a configured mapping,
// returns the corresponding upstream key. Otherwise, falls back to the default source.
func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
// Try to get client API key from request context
clientKey := getClientAPIKeyFromContext(ctx)
if clientKey != "" {
s.mu.RLock()
if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" {
s.mu.RUnlock()
return upstreamKey, nil
}
s.mu.RUnlock()
}
// Fall back to default source
return s.defaultSource.Get(ctx)
}
// UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries.
// If the same client key appears in multiple entries, logs a warning and uses the first one.
func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
newLookup := make(map[string]string)
for _, entry := range entries {
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
for _, clientKey := range entry.APIKeys {
trimmedKey := strings.TrimSpace(clientKey)
if trimmedKey == "" {
continue
}
if _, exists := newLookup[trimmedKey]; exists {
// Log warning for duplicate client key, first one wins
log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
continue
}
newLookup[trimmedKey] = upstreamKey
}
}
s.mu.Lock()
s.lookup = newLookup
s.mu.Unlock()
}
// UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable).
func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(key)
}
}
// InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable).
func (s *MappedSecretSource) InvalidateCache() {
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
ms.InvalidateCache()
}
}

View File

@@ -8,6 +8,10 @@ import (
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
) )
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) { func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
@@ -278,3 +282,85 @@ func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
t.Fatalf("after cache expiry, expected new-value, got %q", got3) t.Fatalf("after cache expiry, expected new-value, got %q", got3)
} }
} }
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
got, err := s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "u1" {
t.Fatalf("want u1, got %q", got)
}
ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2")
got, err = s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "default" {
t.Fatalf("want default fallback, got %q", got)
}
}
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
{
UpstreamAPIKey: "u2",
APIKeys: []string{"k1"},
},
})
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
got, err := s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "u1" {
t.Fatalf("want u1 (first wins), got %q", got)
}
}
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
hook := test.NewLocal(log.StandardLogger())
defer hook.Reset()
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
{
UpstreamAPIKey: "u2",
APIKeys: []string{"k1"},
},
})
foundWarning := false
for _, entry := range hook.AllEntries() {
if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
foundWarning = true
break
}
}
if !foundWarning {
t.Fatal("expected warning log for duplicate client key, but none was found")
}
}

View File

@@ -209,13 +209,15 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
// Resolve logs directory relative to the configuration file directory. // Resolve logs directory relative to the configuration file directory.
var requestLogger logging.RequestLogger var requestLogger logging.RequestLogger
var toggle func(bool) var toggle func(bool)
if optionState.requestLoggerFactory != nil { if !cfg.CommercialMode {
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath) if optionState.requestLoggerFactory != nil {
} requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
if requestLogger != nil { }
engine.Use(middleware.RequestLoggingMiddleware(requestLogger)) if requestLogger != nil {
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok { engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
toggle = setter.SetEnabled if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
toggle = setter.SetEnabled
}
} }
} }
@@ -494,6 +496,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware()) mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
{ {
mgmt.GET("/usage", s.mgmt.GetUsageStatistics) mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
mgmt.GET("/config", s.mgmt.GetConfig) mgmt.GET("/config", s.mgmt.GetConfig)
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML) mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML) mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
@@ -516,6 +520,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL) mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL) mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
mgmt.POST("/api-call", s.mgmt.APICall)
mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject) mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject) mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject) mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
@@ -538,6 +544,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.DELETE("/logs", s.mgmt.DeleteLogs) mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs) mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog) mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
mgmt.GET("/request-log", s.mgmt.GetRequestLog) mgmt.GET("/request-log", s.mgmt.GetRequestLog)
mgmt.PUT("/request-log", s.mgmt.PutRequestLog) mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog) mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
@@ -564,6 +571,10 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings) mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
@@ -866,7 +877,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
} }
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB { if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil { if err := logging.ConfigureLogOutput(cfg); err != nil {
log.Errorf("failed to reconfigure log output: %v", err) log.Errorf("failed to reconfigure log output: %v", err)
} else { } else {
if oldCfg == nil { if oldCfg == nil {

View File

@@ -40,6 +40,10 @@ type KiroTokenData struct {
ClientSecret string `json:"clientSecret,omitempty"` ClientSecret string `json:"clientSecret,omitempty"`
// Email is the user's email address (used for file naming) // Email is the user's email address (used for file naming)
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
StartURL string `json:"startUrl,omitempty"`
// Region is the AWS region for IDC authentication (only for IDC auth method)
Region string `json:"region,omitempty"`
} }
// KiroAuthBundle aggregates authentication data after OAuth flow completion // KiroAuthBundle aggregates authentication data after OAuth flow completion

View File

@@ -2,16 +2,19 @@
package kiro package kiro
import ( import (
"bufio"
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"html" "html"
"io" "io"
"net" "net"
"net/http" "net/http"
"os"
"strings" "strings"
"time" "time"
@@ -24,19 +27,31 @@ import (
const ( const (
// AWS SSO OIDC endpoints // AWS SSO OIDC endpoints
ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com" ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com"
// Kiro's start URL for Builder ID // Kiro's start URL for Builder ID
builderIDStartURL = "https://view.awsapps.com/start" builderIDStartURL = "https://view.awsapps.com/start"
// Default region for IDC
defaultIDCRegion = "us-east-1"
// Polling interval // Polling interval
pollInterval = 5 * time.Second pollInterval = 5 * time.Second
// Authorization code flow callback // Authorization code flow callback
authCodeCallbackPath = "/oauth/callback" authCodeCallbackPath = "/oauth/callback"
authCodeCallbackPort = 19877 authCodeCallbackPort = 19877
// User-Agent to match official Kiro IDE // User-Agent to match official Kiro IDE
kiroUserAgent = "KiroIDE" kiroUserAgent = "KiroIDE"
// IDC token refresh headers (matching Kiro IDE behavior)
idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE"
)
// Sentinel errors for OIDC token polling
var (
ErrAuthorizationPending = errors.New("authorization_pending")
ErrSlowDown = errors.New("slow_down")
) )
// SSOOIDCClient handles AWS SSO OIDC authentication. // SSOOIDCClient handles AWS SSO OIDC authentication.
@@ -83,6 +98,440 @@ type CreateTokenResponse struct {
RefreshToken string `json:"refreshToken"` RefreshToken string `json:"refreshToken"`
} }
// getOIDCEndpoint returns the OIDC endpoint for the given region.
func getOIDCEndpoint(region string) string {
if region == "" {
region = defaultIDCRegion
}
return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
}
// promptInput prompts the user for input with an optional default value.
func promptInput(prompt, defaultValue string) string {
reader := bufio.NewReader(os.Stdin)
if defaultValue != "" {
fmt.Printf("%s [%s]: ", prompt, defaultValue)
} else {
fmt.Printf("%s: ", prompt)
}
input, err := reader.ReadString('\n')
if err != nil {
log.Warnf("Error reading input: %v", err)
return defaultValue
}
input = strings.TrimSpace(input)
if input == "" {
return defaultValue
}
return input
}
// promptSelect prompts the user to select from options using number input.
func promptSelect(prompt string, options []string) int {
reader := bufio.NewReader(os.Stdin)
for {
fmt.Println(prompt)
for i, opt := range options {
fmt.Printf(" %d) %s\n", i+1, opt)
}
fmt.Printf("Enter selection (1-%d): ", len(options))
input, err := reader.ReadString('\n')
if err != nil {
log.Warnf("Error reading input: %v", err)
return 0 // Default to first option on error
}
input = strings.TrimSpace(input)
// Parse the selection
var selection int
if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) {
fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options))
continue
}
return selection - 1
}
}
// RegisterClientWithRegion registers a new OIDC client with AWS using a specific region.
func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region string) (*RegisterClientResponse, error) {
endpoint := getOIDCEndpoint(region)
payload := map[string]interface{}{
"clientName": "Kiro IDE",
"clientType": "public",
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
"grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"},
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
}
var result RegisterClientResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
return &result, nil
}
// StartDeviceAuthorizationWithIDC starts the device authorization flow for IDC.
func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, clientID, clientSecret, startURL, region string) (*StartDeviceAuthResponse, error) {
endpoint := getOIDCEndpoint(region)
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"startUrl": startURL,
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/device_authorization", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode)
}
var result StartDeviceAuthResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
return &result, nil
}
// CreateTokenWithRegion polls for the access token after user authorization using a specific region.
func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, clientSecret, deviceCode, region string) (*CreateTokenResponse, error) {
endpoint := getOIDCEndpoint(region)
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"deviceCode": deviceCode,
"grantType": "urn:ietf:params:oauth:grant-type:device_code",
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// Check for pending authorization
if resp.StatusCode == http.StatusBadRequest {
var errResp struct {
Error string `json:"error"`
}
if json.Unmarshal(respBody, &errResp) == nil {
if errResp.Error == "authorization_pending" {
return nil, ErrAuthorizationPending
}
if errResp.Error == "slow_down" {
return nil, ErrSlowDown
}
}
log.Debugf("create token failed: %s", string(respBody))
return nil, fmt.Errorf("create token failed")
}
if resp.StatusCode != http.StatusOK {
log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
}
var result CreateTokenResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
return &result, nil
}
// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region.
func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) {
endpoint := getOIDCEndpoint(region)
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"refreshToken": refreshToken,
"grantType": "refresh_token",
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
// Set headers matching kiro2api's IDC token refresh
// These headers are required for successful IDC token refresh
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region))
req.Header.Set("Connection", "keep-alive")
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
req.Header.Set("Accept", "*/*")
req.Header.Set("Accept-Language", "*")
req.Header.Set("sec-fetch-mode", "cors")
req.Header.Set("User-Agent", "node")
req.Header.Set("Accept-Encoding", "br, gzip, deflate")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
log.Warnf("IDC token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
}
var result CreateTokenResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second)
return &KiroTokenData{
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "idc",
Provider: "AWS",
ClientID: clientID,
ClientSecret: clientSecret,
StartURL: startURL,
Region: region,
}, nil
}
// LoginWithIDC performs the full device code flow for AWS Identity Center (IDC).
func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
fmt.Println("║ Kiro Authentication (AWS Identity Center) ║")
fmt.Println("╚══════════════════════════════════════════════════════════╝")
// Step 1: Register client with the specified region
fmt.Println("\nRegistering client...")
regResp, err := c.RegisterClientWithRegion(ctx, region)
if err != nil {
return nil, fmt.Errorf("failed to register client: %w", err)
}
log.Debugf("Client registered: %s", regResp.ClientID)
// Step 2: Start device authorization with IDC start URL
fmt.Println("Starting device authorization...")
authResp, err := c.StartDeviceAuthorizationWithIDC(ctx, regResp.ClientID, regResp.ClientSecret, startURL, region)
if err != nil {
return nil, fmt.Errorf("failed to start device auth: %w", err)
}
// Step 3: Show user the verification URL
fmt.Printf("\n")
fmt.Println("════════════════════════════════════════════════════════════")
fmt.Printf(" Confirm the following code in the browser:\n")
fmt.Printf(" Code: %s\n", authResp.UserCode)
fmt.Println("════════════════════════════════════════════════════════════")
fmt.Printf("\n Open this URL: %s\n\n", authResp.VerificationURIComplete)
// Set incognito mode based on config
if c.cfg != nil {
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
if !c.cfg.IncognitoBrowser {
log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
} else {
log.Debug("kiro: using incognito mode for multi-account support")
}
} else {
browser.SetIncognitoMode(true)
log.Debug("kiro: using incognito mode for multi-account support (default)")
}
// Open browser
if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil {
log.Warnf("Could not open browser automatically: %v", err)
fmt.Println(" Please open the URL manually in your browser.")
} else {
fmt.Println(" (Browser opened automatically)")
}
// Step 4: Poll for token
fmt.Println("Waiting for authorization...")
interval := pollInterval
if authResp.Interval > 0 {
interval = time.Duration(authResp.Interval) * time.Second
}
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
for time.Now().Before(deadline) {
select {
case <-ctx.Done():
browser.CloseBrowser()
return nil, ctx.Err()
case <-time.After(interval):
tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region)
if err != nil {
if errors.Is(err, ErrAuthorizationPending) {
fmt.Print(".")
continue
}
if errors.Is(err, ErrSlowDown) {
interval += 5 * time.Second
continue
}
browser.CloseBrowser()
return nil, fmt.Errorf("token creation failed: %w", err)
}
fmt.Println("\n\n✓ Authorization successful!")
// Close the browser window
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser: %v", err)
}
// Step 5: Get profile ARN from CodeWhisperer API
fmt.Println("Fetching profile information...")
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
// Fetch user email
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
if email != "" {
fmt.Printf(" Logged in as: %s\n", email)
}
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: profileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "idc",
Provider: "AWS",
ClientID: regResp.ClientID,
ClientSecret: regResp.ClientSecret,
Email: email,
StartURL: startURL,
Region: region,
}, nil
}
}
// Close browser on timeout
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser on timeout: %v", err)
}
return nil, fmt.Errorf("authorization timed out")
}
// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login.
func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) {
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
fmt.Println("║ Kiro Authentication (AWS) ║")
fmt.Println("╚══════════════════════════════════════════════════════════╝")
// Prompt for login method
options := []string{
"Use with Builder ID (personal AWS account)",
"Use with IDC Account (organization SSO)",
}
selection := promptSelect("\n? Select login method:", options)
if selection == 0 {
// Builder ID flow - use existing implementation
return c.LoginWithBuilderID(ctx)
}
// IDC flow - prompt for start URL and region
fmt.Println()
startURL := promptInput("? Enter Start URL", "")
if startURL == "" {
return nil, fmt.Errorf("start URL is required for IDC login")
}
region := promptInput("? Enter Region", defaultIDCRegion)
return c.LoginWithIDC(ctx, startURL, region)
}
// RegisterClient registers a new OIDC client with AWS. // RegisterClient registers a new OIDC client with AWS.
func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) { func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) {
payload := map[string]interface{}{ payload := map[string]interface{}{
@@ -211,10 +660,10 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
} }
if json.Unmarshal(respBody, &errResp) == nil { if json.Unmarshal(respBody, &errResp) == nil {
if errResp.Error == "authorization_pending" { if errResp.Error == "authorization_pending" {
return nil, fmt.Errorf("authorization_pending") return nil, ErrAuthorizationPending
} }
if errResp.Error == "slow_down" { if errResp.Error == "slow_down" {
return nil, fmt.Errorf("slow_down") return nil, ErrSlowDown
} }
} }
log.Debugf("create token failed: %s", string(respBody)) log.Debugf("create token failed: %s", string(respBody))
@@ -359,12 +808,11 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
case <-time.After(interval): case <-time.After(interval):
tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
if err != nil { if err != nil {
errStr := err.Error() if errors.Is(err, ErrAuthorizationPending) {
if strings.Contains(errStr, "authorization_pending") {
fmt.Print(".") fmt.Print(".")
continue continue
} }
if strings.Contains(errStr, "slow_down") { if errors.Is(err, ErrSlowDown) {
interval += 5 * time.Second interval += 5 * time.Second
continue continue
} }

View File

@@ -39,6 +39,9 @@ type Config struct {
// Debug enables or disables debug-level logging and other debug features. // Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug" json:"debug"` Debug bool `yaml:"debug" json:"debug"`
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
// LoggingToFile controls whether application logs are written to rotating files or stdout. // LoggingToFile controls whether application logs are written to rotating files or stdout.
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"` LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
@@ -60,6 +63,9 @@ type Config struct {
// QuotaExceeded defines the behavior when a quota is exceeded. // QuotaExceeded defines the behavior when a quota is exceeded.
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"` QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
// Routing controls credential selection behavior.
Routing RoutingConfig `yaml:"routing" json:"routing"`
// WebsocketAuth enables or disables authentication for the WebSocket API. // WebsocketAuth enables or disables authentication for the WebSocket API.
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"` WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
@@ -92,6 +98,14 @@ type Config struct {
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries. // OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"` OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
// OAuthModelMappings defines global model name mappings for OAuth/file-backed auth channels.
// These mappings affect both model listing and model routing for supported channels:
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
//
// NOTE: This does not apply to existing per-credential model alias features under:
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
OAuthModelMappings map[string][]ModelNameMapping `yaml:"oauth-model-mappings,omitempty" json:"oauth-model-mappings,omitempty"`
// Payload defines default and override rules for provider payload parameters. // Payload defines default and override rules for provider payload parameters.
Payload PayloadConfig `yaml:"payload" json:"payload"` Payload PayloadConfig `yaml:"payload" json:"payload"`
@@ -136,6 +150,20 @@ type QuotaExceeded struct {
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"` SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
} }
// RoutingConfig configures how credentials are selected for requests.
type RoutingConfig struct {
// Strategy selects the credential selection strategy.
// Supported values: "round-robin" (default), "fill-first".
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
}
// ModelNameMapping defines a model ID rename mapping for a specific channel.
// It maps the original model name (Name) to the client-visible alias (Alias).
type ModelNameMapping struct {
Name string `yaml:"name" json:"name"`
Alias string `yaml:"alias" json:"alias"`
}
// AmpModelMapping defines a model name mapping for Amp CLI requests. // AmpModelMapping defines a model name mapping for Amp CLI requests.
// When Amp requests a model that isn't available locally, this mapping // When Amp requests a model that isn't available locally, this mapping
// allows routing to an alternative model that IS available. // allows routing to an alternative model that IS available.
@@ -146,6 +174,11 @@ type AmpModelMapping struct {
// To is the target model name to route to (e.g., "claude-sonnet-4"). // To is the target model name to route to (e.g., "claude-sonnet-4").
// The target model must have available providers in the registry. // The target model must have available providers in the registry.
To string `yaml:"to" json:"to"` To string `yaml:"to" json:"to"`
// Regex indicates whether the 'from' field should be interpreted as a regular
// expression for matching model names. When true, this mapping is evaluated
// after exact matches and in the order provided. Defaults to false (exact match).
Regex bool `yaml:"regex,omitempty" json:"regex,omitempty"`
} }
// AmpCode groups Amp CLI integration settings including upstream routing, // AmpCode groups Amp CLI integration settings including upstream routing,
@@ -157,6 +190,11 @@ type AmpCode struct {
// UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls. // UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls.
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"` UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
// When a client authenticates with a key that matches an entry, that upstream key is used.
// If no match is found, falls back to UpstreamAPIKey (default behavior).
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.) // RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by // to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient). // browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
@@ -172,6 +210,17 @@ type AmpCode struct {
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"` ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
} }
// AmpUpstreamAPIKeyEntry maps a set of client API keys to a specific upstream API key.
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
// is used for the upstream Amp request.
type AmpUpstreamAPIKeyEntry struct {
// UpstreamAPIKey is the API key to use when proxying to the Amp upstream.
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// APIKeys are the client API keys (from top-level api-keys) that map to this upstream key.
APIKeys []string `yaml:"api-keys" json:"api-keys"`
}
// PayloadConfig defines default and override parameter rules applied to provider payloads. // PayloadConfig defines default and override parameter rules applied to provider payloads.
type PayloadConfig struct { type PayloadConfig struct {
// Default defines rules that only set parameters when they are missing in the payload. // Default defines rules that only set parameters when they are missing in the payload.
@@ -231,6 +280,9 @@ type ClaudeModel struct {
Alias string `yaml:"alias" json:"alias"` Alias string `yaml:"alias" json:"alias"`
} }
func (m ClaudeModel) GetName() string { return m.Name }
func (m ClaudeModel) GetAlias() string { return m.Alias }
// CodexKey represents the configuration for a Codex API key, // CodexKey represents the configuration for a Codex API key,
// including the API key itself and an optional base URL for the API endpoint. // including the API key itself and an optional base URL for the API endpoint.
type CodexKey struct { type CodexKey struct {
@@ -247,6 +299,9 @@ type CodexKey struct {
// ProxyURL overrides the global proxy setting for this API key if provided. // ProxyURL overrides the global proxy setting for this API key if provided.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"` ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
// Models defines upstream model names and aliases for request routing.
Models []CodexModel `yaml:"models" json:"models"`
// Headers optionally adds extra HTTP headers for requests sent with this key. // Headers optionally adds extra HTTP headers for requests sent with this key.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
@@ -254,6 +309,18 @@ type CodexKey struct {
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
} }
// CodexModel describes a mapping between an alias and the actual upstream model name.
type CodexModel struct {
// Name is the upstream model identifier used when issuing requests.
Name string `yaml:"name" json:"name"`
// Alias is the client-facing model name that maps to Name.
Alias string `yaml:"alias" json:"alias"`
}
func (m CodexModel) GetName() string { return m.Name }
func (m CodexModel) GetAlias() string { return m.Alias }
// GeminiKey represents the configuration for a Gemini API key, // GeminiKey represents the configuration for a Gemini API key,
// including optional overrides for upstream base URL, proxy routing, and headers. // including optional overrides for upstream base URL, proxy routing, and headers.
type GeminiKey struct { type GeminiKey struct {
@@ -269,6 +336,9 @@ type GeminiKey struct {
// ProxyURL optionally overrides the global proxy for this API key. // ProxyURL optionally overrides the global proxy for this API key.
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
// Models defines upstream model names and aliases for request routing.
Models []GeminiModel `yaml:"models,omitempty" json:"models,omitempty"`
// Headers optionally adds extra HTTP headers for requests sent with this key. // Headers optionally adds extra HTTP headers for requests sent with this key.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
@@ -276,6 +346,18 @@ type GeminiKey struct {
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
} }
// GeminiModel describes a mapping between an alias and the actual upstream model name.
type GeminiModel struct {
// Name is the upstream model identifier used when issuing requests.
Name string `yaml:"name" json:"name"`
// Alias is the client-facing model name that maps to Name.
Alias string `yaml:"alias" json:"alias"`
}
func (m GeminiModel) GetName() string { return m.Name }
func (m GeminiModel) GetAlias() string { return m.Alias }
// KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication. // KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication.
type KiroKey struct { type KiroKey struct {
// TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json) // TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json)
@@ -391,7 +473,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.DisableCooling = false cfg.DisableCooling = false
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force) cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
if err = yaml.Unmarshal(data, &cfg); err != nil { if err = yaml.Unmarshal(data, &cfg); err != nil {
if optional { if optional {
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error. // In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
@@ -460,6 +542,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Normalize OAuth provider model exclusion map. // Normalize OAuth provider model exclusion map.
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels) cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
// Normalize global OAuth model name mappings.
cfg.SanitizeOAuthModelMappings()
if cfg.legacyMigrationPending { if cfg.legacyMigrationPending {
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...") fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
if !optional && configFile != "" { if !optional && configFile != "" {
@@ -476,6 +561,50 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
return &cfg, nil return &cfg, nil
} }
// SanitizeOAuthModelMappings normalizes and deduplicates global OAuth model name mappings.
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
// and ensures (From, To) pairs are unique within each channel.
func (cfg *Config) SanitizeOAuthModelMappings() {
if cfg == nil || len(cfg.OAuthModelMappings) == 0 {
return
}
out := make(map[string][]ModelNameMapping, len(cfg.OAuthModelMappings))
for rawChannel, mappings := range cfg.OAuthModelMappings {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(mappings) == 0 {
continue
}
seenName := make(map[string]struct{}, len(mappings))
seenAlias := make(map[string]struct{}, len(mappings))
clean := make([]ModelNameMapping, 0, len(mappings))
for _, mapping := range mappings {
name := strings.TrimSpace(mapping.Name)
alias := strings.TrimSpace(mapping.Alias)
if name == "" || alias == "" {
continue
}
if strings.EqualFold(name, alias) {
continue
}
nameKey := strings.ToLower(name)
aliasKey := strings.ToLower(alias)
if _, ok := seenName[nameKey]; ok {
continue
}
if _, ok := seenAlias[aliasKey]; ok {
continue
}
seenName[nameKey] = struct{}{}
seenAlias[aliasKey] = struct{}{}
clean = append(clean, ModelNameMapping{Name: name, Alias: alias})
}
if len(clean) > 0 {
out[channel] = clean
}
}
cfg.OAuthModelMappings = out
}
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are // SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are
// not actionable, specifically those missing a BaseURL. It trims whitespace before // not actionable, specifically those missing a BaseURL. It trims whitespace before
// evaluation and preserves the relative order of remaining entries. // evaluation and preserves the relative order of remaining entries.
@@ -861,8 +990,8 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
} }
// mergeMappingPreserve merges keys from src into dst mapping node while preserving // mergeMappingPreserve merges keys from src into dst mapping node while preserving
// key order and comments of existing keys in dst. Unknown keys from src are appended // key order and comments of existing keys in dst. New keys are only added if their
// to dst at the end, copying their node structure from src. // value is non-zero to avoid polluting the config with defaults.
func mergeMappingPreserve(dst, src *yaml.Node) { func mergeMappingPreserve(dst, src *yaml.Node) {
if dst == nil || src == nil { if dst == nil || src == nil {
return return
@@ -873,20 +1002,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
copyNodeShallow(dst, src) copyNodeShallow(dst, src)
return return
} }
// Build a lookup of existing keys in dst
for i := 0; i+1 < len(src.Content); i += 2 { for i := 0; i+1 < len(src.Content); i += 2 {
sk := src.Content[i] sk := src.Content[i]
sv := src.Content[i+1] sv := src.Content[i+1]
idx := findMapKeyIndex(dst, sk.Value) idx := findMapKeyIndex(dst, sk.Value)
if idx >= 0 { if idx >= 0 {
// Merge into existing value node // Merge into existing value node (always update, even to zero values)
dv := dst.Content[idx+1] dv := dst.Content[idx+1]
mergeNodePreserve(dv, sv) mergeNodePreserve(dv, sv)
} else { } else {
if shouldSkipEmptyCollectionOnPersist(sk.Value, sv) { // New key: only add if value is non-zero to avoid polluting config with defaults
if isZeroValueNode(sv) {
continue continue
} }
// Append new key/value pair by deep-copying from src
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv)) dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
} }
} }
@@ -969,32 +1097,49 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
return -1 return -1
} }
func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool { // isZeroValueNode returns true if the YAML node represents a zero/default value
switch key { // that should not be written as a new key to preserve config cleanliness.
case "generative-language-api-key", // For mappings and sequences, recursively checks if all children are zero values.
"gemini-api-key", func isZeroValueNode(node *yaml.Node) bool {
"vertex-api-key",
"claude-api-key",
"codex-api-key",
"openai-compatibility":
return isEmptyCollectionNode(node)
default:
return false
}
}
func isEmptyCollectionNode(node *yaml.Node) bool {
if node == nil { if node == nil {
return true return true
} }
switch node.Kind { switch node.Kind {
case yaml.SequenceNode:
return len(node.Content) == 0
case yaml.ScalarNode: case yaml.ScalarNode:
return node.Tag == "!!null" switch node.Tag {
default: case "!!bool":
return false return node.Value == "false"
case "!!int", "!!float":
return node.Value == "0" || node.Value == "0.0"
case "!!str":
return node.Value == ""
case "!!null":
return true
}
case yaml.SequenceNode:
if len(node.Content) == 0 {
return true
}
// Check if all elements are zero values
for _, child := range node.Content {
if !isZeroValueNode(child) {
return false
}
}
return true
case yaml.MappingNode:
if len(node.Content) == 0 {
return true
}
// Check if all values are zero values (values are at odd indices)
for i := 1; i < len(node.Content); i += 2 {
if !isZeroValueNode(node.Content[i]) {
return false
}
}
return true
} }
return false
} }
// deepCopyNode creates a deep copy of a yaml.Node graph. // deepCopyNode creates a deep copy of a yaml.Node graph.

View File

@@ -22,6 +22,21 @@ type SDKConfig struct {
// Access holds request authentication provider configuration. // Access holds request authentication provider configuration.
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"` Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
}
// StreamingConfig holds server streaming behavior configuration.
type StreamingConfig struct {
// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
// <= 0 disables keep-alives. Default is 0.
KeepAliveSeconds int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
// to allow auth rotation / transient recovery.
// <= 0 disables bootstrap retries. Default is 0.
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
} }
// AccessConfig groups request authentication providers. // AccessConfig groups request authentication providers.

View File

@@ -42,6 +42,9 @@ type VertexCompatModel struct {
Alias string `yaml:"alias" json:"alias"` Alias string `yaml:"alias" json:"alias"`
} }
func (m VertexCompatModel) GetName() string { return m.Name }
func (m VertexCompatModel) GetAlias() string { return m.Alias }
// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials. // SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials.
func (cfg *Config) SanitizeVertexCompatKeys() { func (cfg *Config) SanitizeVertexCompatKeys() {
if cfg == nil { if cfg == nil {

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"runtime/debug" "runtime/debug"
"strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -14,11 +15,24 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// aiAPIPrefixes defines path prefixes for AI API requests that should have request ID tracking.
var aiAPIPrefixes = []string{
"/v1/chat/completions",
"/v1/completions",
"/v1/messages",
"/v1/responses",
"/v1beta/models/",
"/api/provider/",
}
const skipGinLogKey = "__gin_skip_request_logging__" const skipGinLogKey = "__gin_skip_request_logging__"
// GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses // GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses
// using logrus. It captures request details including method, path, status code, latency, // using logrus. It captures request details including method, path, status code, latency,
// client IP, and any error messages, formatting them in a Gin-style log format. // client IP, and any error messages. Request ID is only added for AI API requests.
//
// Output format (AI API): [2025-12-23 20:14:10] [info ] | a1b2c3d4 | 200 | 23.559s | ...
// Output format (others): [2025-12-23 20:14:10] [info ] | -------- | 200 | 23.559s | ...
// //
// Returns: // Returns:
// - gin.HandlerFunc: A middleware handler for request logging // - gin.HandlerFunc: A middleware handler for request logging
@@ -28,6 +42,15 @@ func GinLogrusLogger() gin.HandlerFunc {
path := c.Request.URL.Path path := c.Request.URL.Path
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery) raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
// Only generate request ID for AI API paths
var requestID string
if isAIAPIPath(path) {
requestID = GenerateRequestID()
SetGinRequestID(c, requestID)
ctx := WithRequestID(c.Request.Context(), requestID)
c.Request = c.Request.WithContext(ctx)
}
c.Next() c.Next()
if shouldSkipGinRequestLogging(c) { if shouldSkipGinRequestLogging(c) {
@@ -49,23 +72,38 @@ func GinLogrusLogger() gin.HandlerFunc {
clientIP := c.ClientIP() clientIP := c.ClientIP()
method := c.Request.Method method := c.Request.Method
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String() errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
timestamp := time.Now().Format("2006/01/02 - 15:04:05")
logLine := fmt.Sprintf("[GIN] %s | %3d | %13v | %15s | %-7s \"%s\"", timestamp, statusCode, latency, clientIP, method, path) if requestID == "" {
requestID = "--------"
}
logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path)
if errorMessage != "" { if errorMessage != "" {
logLine = logLine + " | " + errorMessage logLine = logLine + " | " + errorMessage
} }
entry := log.WithField("request_id", requestID)
switch { switch {
case statusCode >= http.StatusInternalServerError: case statusCode >= http.StatusInternalServerError:
log.Error(logLine) entry.Error(logLine)
case statusCode >= http.StatusBadRequest: case statusCode >= http.StatusBadRequest:
log.Warn(logLine) entry.Warn(logLine)
default: default:
log.Info(logLine) entry.Info(logLine)
} }
} }
} }
// isAIAPIPath checks if the given path is an AI API endpoint that should have request ID tracking.
func isAIAPIPath(path string) bool {
for _, prefix := range aiAPIPrefixes {
if strings.HasPrefix(path, prefix) {
return true
}
}
return false
}
// GinLogrusRecovery returns a Gin middleware handler that recovers from panics and logs // GinLogrusRecovery returns a Gin middleware handler that recovers from panics and logs
// them using logrus. When a panic occurs, it captures the panic value, stack trace, // them using logrus. When a panic occurs, it captures the panic value, stack trace,
// and request path, then returns a 500 Internal Server Error response to the client. // and request path, then returns a 500 Internal Server Error response to the client.

View File

@@ -10,6 +10,7 @@ import (
"sync" "sync"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
@@ -24,7 +25,8 @@ var (
) )
// LogFormatter defines a custom log format for logrus. // LogFormatter defines a custom log format for logrus.
// This formatter adds timestamp, level, and source location to each log entry. // This formatter adds timestamp, level, request ID, and source location to each log entry.
// Format: [2025-12-23 20:14:04] [debug] [manager.go:524] | a1b2c3d4 | Use API key sk-9...0RHO for model gpt-5.2
type LogFormatter struct{} type LogFormatter struct{}
// Format renders a single log entry with custom formatting. // Format renders a single log entry with custom formatting.
@@ -38,16 +40,24 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
timestamp := entry.Time.Format("2006-01-02 15:04:05") timestamp := entry.Time.Format("2006-01-02 15:04:05")
message := strings.TrimRight(entry.Message, "\r\n") message := strings.TrimRight(entry.Message, "\r\n")
// Handle nil Caller (can happen with some log entries) reqID := "--------"
callerFile := "unknown" if id, ok := entry.Data["request_id"].(string); ok && id != "" {
callerLine := 0 reqID = id
if entry.Caller != nil { }
callerFile = filepath.Base(entry.Caller.File)
callerLine = entry.Caller.Line level := entry.Level.String()
if level == "warning" {
level = "warn"
}
levelStr := fmt.Sprintf("%-5s", level)
var formatted string
if entry.Caller != nil {
formatted = fmt.Sprintf("[%s] [%s] [%s] [%s:%d] %s\n", timestamp, reqID, levelStr, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
} else {
formatted = fmt.Sprintf("[%s] [%s] [%s] %s\n", timestamp, reqID, levelStr, message)
} }
formatted := fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, callerFile, callerLine, message)
buffer.WriteString(formatted) buffer.WriteString(formatted)
return buffer.Bytes(), nil return buffer.Bytes(), nil
@@ -75,10 +85,30 @@ func SetupBaseLogger() {
}) })
} }
// isDirWritable checks if the specified directory exists and is writable by attempting to create and remove a test file.
func isDirWritable(dir string) bool {
info, err := os.Stat(dir)
if err != nil || !info.IsDir() {
return false
}
testFile := filepath.Join(dir, ".perm_test")
f, err := os.Create(testFile)
if err != nil {
return false
}
defer func() {
_ = f.Close()
_ = os.Remove(testFile)
}()
return true
}
// ConfigureLogOutput switches the global log destination between rotating files and stdout. // ConfigureLogOutput switches the global log destination between rotating files and stdout.
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory // When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
// until the total size is within the limit. // until the total size is within the limit.
func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error { func ConfigureLogOutput(cfg *config.Config) error {
SetupBaseLogger() SetupBaseLogger()
writerMu.Lock() writerMu.Lock()
@@ -87,10 +117,12 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
logDir := "logs" logDir := "logs"
if base := util.WritablePath(); base != "" { if base := util.WritablePath(); base != "" {
logDir = filepath.Join(base, "logs") logDir = filepath.Join(base, "logs")
} else if !isDirWritable(logDir) {
logDir = filepath.Join(cfg.AuthDir, "logs")
} }
protectedPath := "" protectedPath := ""
if loggingToFile { if cfg.LoggingToFile {
if err := os.MkdirAll(logDir, 0o755); err != nil { if err := os.MkdirAll(logDir, 0o755); err != nil {
return fmt.Errorf("logging: failed to create log directory: %w", err) return fmt.Errorf("logging: failed to create log directory: %w", err)
} }
@@ -114,7 +146,7 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
log.SetOutput(os.Stdout) log.SetOutput(os.Stdout)
} }
configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath) configureLogDirCleanerLocked(logDir, cfg.LogsMaxTotalSizeMB, protectedPath)
return nil return nil
} }

View File

@@ -43,10 +43,11 @@ type RequestLogger interface {
// - response: The raw response data // - response: The raw response data
// - apiRequest: The API request data // - apiRequest: The API request data
// - apiResponse: The API response data // - apiResponse: The API response data
// - requestID: Optional request ID for log file naming
// //
// Returns: // Returns:
// - error: An error if logging fails, nil otherwise // - error: An error if logging fails, nil otherwise
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks. // LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
// //
@@ -55,11 +56,12 @@ type RequestLogger interface {
// - method: The HTTP method // - method: The HTTP method
// - headers: The request headers // - headers: The request headers
// - body: The request body // - body: The request body
// - requestID: Optional request ID for log file naming
// //
// Returns: // Returns:
// - StreamingLogWriter: A writer for streaming response chunks // - StreamingLogWriter: A writer for streaming response chunks
// - error: An error if logging initialization fails, nil otherwise // - error: An error if logging initialization fails, nil otherwise
LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error) LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error)
// IsEnabled returns whether request logging is currently enabled. // IsEnabled returns whether request logging is currently enabled.
// //
@@ -177,20 +179,21 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
// - response: The raw response data // - response: The raw response data
// - apiRequest: The API request data // - apiRequest: The API request data
// - apiResponse: The API response data // - apiResponse: The API response data
// - requestID: Optional request ID for log file naming
// //
// Returns: // Returns:
// - error: An error if logging fails, nil otherwise // - error: An error if logging fails, nil otherwise
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error { func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false) return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID)
} }
// LogRequestWithOptions logs a request with optional forced logging behavior. // LogRequestWithOptions logs a request with optional forced logging behavior.
// The force flag allows writing error logs even when regular request logging is disabled. // The force flag allows writing error logs even when regular request logging is disabled.
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error { func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force) return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID)
} }
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error { func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
if !l.enabled && !force { if !l.enabled && !force {
return nil return nil
} }
@@ -200,10 +203,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
return fmt.Errorf("failed to create logs directory: %w", errEnsure) return fmt.Errorf("failed to create logs directory: %w", errEnsure)
} }
// Generate filename // Generate filename with request ID
filename := l.generateFilename(url) filename := l.generateFilename(url, requestID)
if force && !l.enabled { if force && !l.enabled {
filename = l.generateErrorFilename(url) filename = l.generateErrorFilename(url, requestID)
} }
filePath := filepath.Join(l.logsDir, filename) filePath := filepath.Join(l.logsDir, filename)
@@ -271,11 +274,12 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
// - method: The HTTP method // - method: The HTTP method
// - headers: The request headers // - headers: The request headers
// - body: The request body // - body: The request body
// - requestID: Optional request ID for log file naming
// //
// Returns: // Returns:
// - StreamingLogWriter: A writer for streaming response chunks // - StreamingLogWriter: A writer for streaming response chunks
// - error: An error if logging initialization fails, nil otherwise // - error: An error if logging initialization fails, nil otherwise
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error) { func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error) {
if !l.enabled { if !l.enabled {
return &NoOpStreamingLogWriter{}, nil return &NoOpStreamingLogWriter{}, nil
} }
@@ -285,8 +289,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
return nil, fmt.Errorf("failed to create logs directory: %w", err) return nil, fmt.Errorf("failed to create logs directory: %w", err)
} }
// Generate filename // Generate filename with request ID
filename := l.generateFilename(url) filename := l.generateFilename(url, requestID)
filePath := filepath.Join(l.logsDir, filename) filePath := filepath.Join(l.logsDir, filename)
requestHeaders := make(map[string][]string, len(headers)) requestHeaders := make(map[string][]string, len(headers))
@@ -330,8 +334,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
} }
// generateErrorFilename creates a filename with an error prefix to differentiate forced error logs. // generateErrorFilename creates a filename with an error prefix to differentiate forced error logs.
func (l *FileRequestLogger) generateErrorFilename(url string) string { func (l *FileRequestLogger) generateErrorFilename(url string, requestID ...string) string {
return fmt.Sprintf("error-%s", l.generateFilename(url)) return fmt.Sprintf("error-%s", l.generateFilename(url, requestID...))
} }
// ensureLogsDir creates the logs directory if it doesn't exist. // ensureLogsDir creates the logs directory if it doesn't exist.
@@ -346,13 +350,15 @@ func (l *FileRequestLogger) ensureLogsDir() error {
} }
// generateFilename creates a sanitized filename from the URL path and current timestamp. // generateFilename creates a sanitized filename from the URL path and current timestamp.
// Format: v1-responses-2025-12-23T195811-a1b2c3d4.log
// //
// Parameters: // Parameters:
// - url: The request URL // - url: The request URL
// - requestID: Optional request ID to include in filename
// //
// Returns: // Returns:
// - string: A sanitized filename for the log file // - string: A sanitized filename for the log file
func (l *FileRequestLogger) generateFilename(url string) string { func (l *FileRequestLogger) generateFilename(url string, requestID ...string) string {
// Extract path from URL // Extract path from URL
path := url path := url
if strings.Contains(url, "?") { if strings.Contains(url, "?") {
@@ -368,12 +374,18 @@ func (l *FileRequestLogger) generateFilename(url string) string {
sanitized := l.sanitizeForFilename(path) sanitized := l.sanitizeForFilename(path)
// Add timestamp // Add timestamp
timestamp := time.Now().Format("2006-01-02T150405-.000000000") timestamp := time.Now().Format("2006-01-02T150405")
timestamp = strings.Replace(timestamp, ".", "", -1)
id := requestLogID.Add(1) // Use request ID if provided, otherwise use sequential ID
var idPart string
if len(requestID) > 0 && requestID[0] != "" {
idPart = requestID[0]
} else {
id := requestLogID.Add(1)
idPart = fmt.Sprintf("%d", id)
}
return fmt.Sprintf("%s-%s-%d.log", sanitized, timestamp, id) return fmt.Sprintf("%s-%s-%s.log", sanitized, timestamp, idPart)
} }
// sanitizeForFilename replaces characters that are not safe for filenames. // sanitizeForFilename replaces characters that are not safe for filenames.

View File

@@ -0,0 +1,61 @@
package logging
import (
"context"
"crypto/rand"
"encoding/hex"
"github.com/gin-gonic/gin"
)
// requestIDKey is the context key for storing/retrieving request IDs.
type requestIDKey struct{}
// ginRequestIDKey is the Gin context key for request IDs.
const ginRequestIDKey = "__request_id__"
// GenerateRequestID creates a new 8-character hex request ID.
func GenerateRequestID() string {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
return "00000000"
}
return hex.EncodeToString(b)
}
// WithRequestID returns a new context with the request ID attached.
func WithRequestID(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDKey{}, requestID)
}
// GetRequestID retrieves the request ID from the context.
// Returns empty string if not found.
func GetRequestID(ctx context.Context) string {
if ctx == nil {
return ""
}
if id, ok := ctx.Value(requestIDKey{}).(string); ok {
return id
}
return ""
}
// SetGinRequestID stores the request ID in the Gin context.
func SetGinRequestID(c *gin.Context, requestID string) {
if c != nil {
c.Set(ginRequestIDKey, requestID)
}
}
// GetGinRequestID retrieves the request ID from the Gin context.
func GetGinRequestID(c *gin.Context) string {
if c == nil {
return ""
}
if id, exists := c.Get(ginRequestIDKey); exists {
if s, ok := id.(string); ok {
return s
}
}
return ""
}

View File

@@ -24,10 +24,11 @@ import (
) )
const ( const (
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest" defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
managementAssetName = "management.html" defaultManagementFallbackURL = "https://cpamc.router-for.me/"
httpUserAgent = "CLIProxyAPI-management-updater" managementAssetName = "management.html"
updateCheckInterval = 3 * time.Hour httpUserAgent = "CLIProxyAPI-management-updater"
updateCheckInterval = 3 * time.Hour
) )
// ManagementFileName exposes the control panel asset filename. // ManagementFileName exposes the control panel asset filename.
@@ -198,6 +199,16 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
return return
} }
localPath := filepath.Join(staticDir, managementAssetName)
localFileMissing := false
if _, errStat := os.Stat(localPath); errStat != nil {
if errors.Is(errStat, os.ErrNotExist) {
localFileMissing = true
} else {
log.WithError(errStat).Debug("failed to stat local management asset")
}
}
// Rate limiting: check only once every 3 hours // Rate limiting: check only once every 3 hours
lastUpdateCheckMu.Lock() lastUpdateCheckMu.Lock()
now := time.Now() now := time.Now()
@@ -210,15 +221,14 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
lastUpdateCheckTime = now lastUpdateCheckTime = now
lastUpdateCheckMu.Unlock() lastUpdateCheckMu.Unlock()
if err := os.MkdirAll(staticDir, 0o755); err != nil { if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
log.WithError(err).Warn("failed to prepare static directory for management asset") log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
return return
} }
releaseURL := resolveReleaseURL(panelRepository) releaseURL := resolveReleaseURL(panelRepository)
client := newHTTPClient(proxyURL) client := newHTTPClient(proxyURL)
localPath := filepath.Join(staticDir, managementAssetName)
localHash, err := fileSHA256(localPath) localHash, err := fileSHA256(localPath)
if err != nil { if err != nil {
if !errors.Is(err, os.ErrNotExist) { if !errors.Is(err, os.ErrNotExist) {
@@ -229,6 +239,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL) asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
if err != nil { if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return
}
return
}
log.WithError(err).Warn("failed to fetch latest management release information") log.WithError(err).Warn("failed to fetch latest management release information")
return return
} }
@@ -240,6 +257,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL) data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
if err != nil { if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to download management asset, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return
}
return
}
log.WithError(err).Warn("failed to download management asset") log.WithError(err).Warn("failed to download management asset")
return return
} }
@@ -256,6 +280,22 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
log.Infof("management asset updated successfully (hash=%s)", downloadedHash) log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
} }
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
data, downloadedHash, err := downloadAsset(ctx, client, defaultManagementFallbackURL)
if err != nil {
log.WithError(err).Warn("failed to download fallback management control panel page")
return false
}
if err = atomicWriteFile(localPath, data); err != nil {
log.WithError(err).Warn("failed to persist fallback management control panel page")
return false
}
log.Infof("management asset updated from fallback page successfully (hash=%s)", downloadedHash)
return true
}
func resolveReleaseURL(repo string) string { func resolveReleaseURL(repo string) string {
repo = strings.TrimSpace(repo) repo = strings.TrimSpace(repo)
if repo == "" { if repo == "" {

View File

@@ -727,6 +727,7 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400}, {ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400}, {ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport}, {ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000}, {ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200}, {ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000}, {ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
@@ -740,6 +741,7 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600}, {ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600}, {ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000}, {ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
} }
models := make([]*ModelInfo, 0, len(entries)) models := make([]*ModelInfo, 0, len(entries))
for _, entry := range entries { for _, entry := range entries {
@@ -780,6 +782,32 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
} }
} }
// LookupStaticModelInfo searches all static model definitions for a model by ID.
// Returns nil if no matching model is found.
func LookupStaticModelInfo(modelID string) *ModelInfo {
if modelID == "" {
return nil
}
allModels := [][]*ModelInfo{
GetClaudeModels(),
GetGeminiModels(),
GetGeminiVertexModels(),
GetGeminiCLIModels(),
GetAIStudioModels(),
GetOpenAIModels(),
GetQwenModels(),
GetIFlowModels(),
}
for _, models := range allModels {
for _, m := range models {
if m != nil && m.ID == modelID {
return m
}
}
}
return nil
}
// GetGitHubCopilotModels returns the available models for GitHub Copilot. // GetGitHubCopilotModels returns the available models for GitHub Copilot.
// These models are available through the GitHub Copilot API at api.githubcopilot.com. // These models are available through the GitHub Copilot API at api.githubcopilot.com.
func GetGitHubCopilotModels() []*ModelInfo { func GetGitHubCopilotModels() []*ModelInfo {

View File

@@ -59,6 +59,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
if err != nil { if err != nil {
return resp, err return resp, err
} }
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{ wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost, Method: http.MethodPost,
@@ -113,6 +114,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
if err != nil { if err != nil {
return nil, err return nil, err
} }
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{ wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost, Method: http.MethodPost,

View File

@@ -33,18 +33,18 @@ import (
) )
const ( const (
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com" antigravityBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com"
// antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com" antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com" antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
antigravityCountTokensPath = "/v1internal:countTokens" antigravityCountTokensPath = "/v1internal:countTokens"
antigravityStreamPath = "/v1internal:streamGenerateContent" antigravityStreamPath = "/v1internal:streamGenerateContent"
antigravityGeneratePath = "/v1internal:generateContent" antigravityGeneratePath = "/v1internal:generateContent"
antigravityModelsPath = "/v1internal:fetchAvailableModels" antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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" antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64" defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64"
antigravityAuthType = "antigravity" antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second refreshSkew = 3000 * time.Second
) )
var ( var (
@@ -76,7 +76,8 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au
// Execute performs a non-streaming request to the Antigravity API. // Execute performs a non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if strings.Contains(req.Model, "claude") { isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
if isClaude {
return e.executeClaudeNonStream(ctx, auth, req, opts) return e.executeClaudeNonStream(ctx, auth, req, opts)
} }
@@ -98,7 +99,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
@@ -193,7 +194,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated, true)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
@@ -520,6 +521,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
@@ -527,7 +530,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
@@ -676,6 +679,8 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
respCtx := context.WithValue(ctx, "alt", opts.Alt) respCtx := context.WithValue(ctx, "alt", opts.Alt)
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -694,7 +699,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model) payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload) payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
payload = normalizeAntigravityThinking(req.Model, payload) payload = normalizeAntigravityThinking(req.Model, payload, isClaude)
payload = deleteJSONField(payload, "project") payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model") payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings") payload = deleteJSONField(payload, "request.safetySettings")
@@ -1156,7 +1161,7 @@ func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
} }
return []string{ return []string{
antigravityBaseURLDaily, antigravityBaseURLDaily,
// antigravityBaseURLAutopush, antigravitySandboxBaseURLDaily,
antigravityBaseURLProd, antigravityBaseURLProd,
} }
} }
@@ -1308,7 +1313,7 @@ func alias2ModelName(modelName string) string {
// normalizeAntigravityThinking clamps or removes thinking config based on model support. // normalizeAntigravityThinking clamps or removes thinking config based on model support.
// For Claude models, it additionally ensures thinking budget < max_tokens. // For Claude models, it additionally ensures thinking budget < max_tokens.
func normalizeAntigravityThinking(model string, payload []byte) []byte { func normalizeAntigravityThinking(model string, payload []byte, isClaude bool) []byte {
payload = util.StripThinkingConfigIfUnsupported(model, payload) payload = util.StripThinkingConfigIfUnsupported(model, payload)
if !util.ModelSupportsThinking(model) { if !util.ModelSupportsThinking(model) {
return payload return payload
@@ -1320,7 +1325,6 @@ func normalizeAntigravityThinking(model string, payload []byte) []byte {
raw := int(budget.Int()) raw := int(budget.Int())
normalized := util.NormalizeThinkingBudget(model, raw) normalized := util.NormalizeThinkingBudget(model, raw)
isClaude := strings.Contains(strings.ToLower(model), "claude")
if isClaude { if isClaude {
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload) effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
if effectiveMax > 0 && normalized >= effectiveMax { if effectiveMax > 0 && normalized >= effectiveMax {

View File

@@ -49,33 +49,29 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} }
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("claude") to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude. // Use streaming translation to preserve function calling, except for claude.
stream := from != to stream := from != to
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) body, _ = sjson.SetBytes(body, "model", model)
if upstreamModel == "" {
upstreamModel = req.Model
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
// Inject thinking config based on model metadata for thinking variants // Inject thinking config based on model metadata for thinking variants
body = e.injectThinkingConfig(req.Model, req.Metadata, body) body = e.injectThinkingConfig(model, req.Metadata, body)
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") { if !strings.HasPrefix(model, "claude-3-5-haiku") {
body = checkSystemInstructions(body) body = checkSystemInstructions(body)
} }
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled // Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(req.Model, body) body = ensureMaxTokensForThinking(model, body)
// Extract betas from body and convert to header // Extract betas from body and convert to header
var extraBetas []string var extraBetas []string
@@ -167,26 +163,22 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("claude") to := sdktranslator.FromString("claude")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) model := req.Model
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
if upstreamModel == "" { model = override
upstreamModel = req.Model
} }
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
upstreamModel = modelOverride body, _ = sjson.SetBytes(body, "model", model)
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
// Inject thinking config based on model metadata for thinking variants // Inject thinking config based on model metadata for thinking variants
body = e.injectThinkingConfig(req.Model, req.Metadata, body) body = e.injectThinkingConfig(model, req.Metadata, body)
body = checkSystemInstructions(body) body = checkSystemInstructions(body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled // Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(req.Model, body) body = ensureMaxTokensForThinking(model, body)
// Extract betas from body and convert to header // Extract betas from body and convert to header
var extraBetas []string var extraBetas []string
@@ -310,21 +302,14 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
to := sdktranslator.FromString("claude") to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude. // Use streaming translation to preserve function calling, except for claude.
stream := from != to stream := from != to
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) model := req.Model
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
if upstreamModel == "" { model = override
upstreamModel = req.Model
} }
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
upstreamModel = modelOverride body, _ = sjson.SetBytes(body, "model", model)
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") { if !strings.HasPrefix(model, "claude-3-5-haiku") {
body = checkSystemInstructions(body) body = checkSystemInstructions(body)
} }
@@ -461,6 +446,19 @@ func (e *ClaudeExecutor) injectThinkingConfig(modelName string, metadata map[str
return util.ApplyClaudeThinkingConfig(body, budget) return util.ApplyClaudeThinkingConfig(body, budget)
} }
// disableThinkingIfToolChoiceForced checks if tool_choice forces tool use and disables thinking.
// Anthropic API does not allow thinking when tool_choice is set to "any" or a specific tool.
// See: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations
func disableThinkingIfToolChoiceForced(body []byte) []byte {
toolChoiceType := gjson.GetBytes(body, "tool_choice.type").String()
// "auto" is allowed with thinking, but "any" or "tool" (specific tool) are not
if toolChoiceType == "any" || toolChoiceType == "tool" {
// Remove thinking configuration entirely to avoid API error
body, _ = sjson.DeleteBytes(body, "thinking")
}
return body
}
// ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled. // ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled.
// Anthropic API requires this constraint; violating it returns a 400 error. // Anthropic API requires this constraint; violating it returns a 400 error.
// This function should be called after all thinking configuration is finalized. // This function should be called after all thinking configuration is finalized.
@@ -662,7 +660,14 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
} }
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) { func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
r.Header.Set("Authorization", "Bearer "+apiKey) useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
if isAnthropicBase && useAPIKey {
r.Header.Del("Authorization")
r.Header.Set("x-api-key", apiKey)
} else {
r.Header.Set("Authorization", "Bearer "+apiKey)
}
r.Header.Set("Content-Type", "application/json") r.Header.Set("Content-Type", "application/json")
var ginHeaders http.Header var ginHeaders http.Header

View File

@@ -49,18 +49,21 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, upstreamModel, false) body = NormalizeThinkingConfig(body, model, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
return resp, errValidate return resp, errValidate
} }
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
body, _ = sjson.SetBytes(body, "stream", true) body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.DeleteBytes(body, "previous_response_id")
@@ -146,20 +149,23 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, upstreamModel, false) body = NormalizeThinkingConfig(body, model, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
return nil, errValidate return nil, errValidate
} }
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
url := strings.TrimSuffix(baseURL, "/") + "/responses" url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body) httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -246,20 +252,21 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
} }
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
modelForCounting := req.Model body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body, _ = sjson.SetBytes(body, "model", model)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.SetBytes(body, "stream", false) body, _ = sjson.SetBytes(body, "stream", false)
enc, err := tokenizerForCodexModel(modelForCounting) enc, err := tokenizerForCodexModel(model)
if err != nil { if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err) return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
} }
@@ -520,3 +527,87 @@ func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
} }
return return
} }
func (e *CodexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveCodexConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
func (e *CodexExecutor) resolveCodexConfig(auth *cliproxyauth.Auth) *config.CodexKey {
if auth == nil || e.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range e.cfg.CodexKey {
entry := &e.cfg.CodexKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range e.cfg.CodexKey {
entry := &e.cfg.CodexKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}

View File

@@ -318,7 +318,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
out := make(chan cliproxyexecutor.StreamChunk) out := make(chan cliproxyexecutor.StreamChunk)
stream = out stream = out
go func(resp *http.Response, reqBody []byte, attempt string) { go func(resp *http.Response, reqBody []byte, attemptModel string) {
defer close(out) defer close(out)
defer func() { defer func() {
if errClose := resp.Body.Close(); errClose != nil { if errClose := resp.Body.Close(); errClose != nil {
@@ -336,14 +336,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
reporter.publish(ctx, detail) reporter.publish(ctx, detail)
} }
if bytes.HasPrefix(line, dataTag) { if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), &param) segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), &param)
for i := range segments { for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
} }
} }
} }
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param) segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
for i := range segments { for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
} }
@@ -365,12 +365,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
appendAPIResponseChunk(ctx, e.cfg, data) appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiCLIUsage(data)) reporter.publish(ctx, parseGeminiCLIUsage(data))
var param any var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, data, &param) segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, data, &param)
for i := range segments { for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
} }
segments = sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param) segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
for i := range segments { for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
} }
@@ -417,6 +417,8 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
var lastStatus int var lastStatus int
var lastBody []byte var lastBody []byte
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
// Gemini CLI endpoint when iterating fallback variants.
for _, attemptModel := range models { for _, attemptModel := range models {
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false) payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model) payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
@@ -425,7 +427,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
payload = deleteJSONField(payload, "model") payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings") payload = deleteJSONField(payload, "request.safetySettings")
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload) payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiCLIImageAspectRatio(attemptModel, payload) payload = fixGeminiCLIImageAspectRatio(req.Model, payload)
tok, errTok := tokenSource.Token() tok, errTok := tokenSource.Token()
if errTok != nil { if errTok != nil {

View File

@@ -77,19 +77,22 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
// Official Gemini API via API key or OAuth bearer // Official Gemini API via API key or OAuth bearer
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
body = ApplyThinkingMetadata(body, req.Metadata, req.Model) body = ApplyThinkingMetadata(body, req.Metadata, model)
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(req.Model, body) body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
action := "generateContent" action := "generateContent"
if req.Metadata != nil { if req.Metadata != nil {
@@ -98,7 +101,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} }
} }
baseURL := resolveGeminiBaseURL(auth) baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, action) url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action)
if opts.Alt != "" && action != "countTokens" { if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt) url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
} }
@@ -173,21 +176,24 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body = ApplyThinkingMetadata(body, req.Metadata, req.Model) body = ApplyThinkingMetadata(body, req.Metadata, model)
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(req.Model, body) body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
baseURL := resolveGeminiBaseURL(auth) baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "streamGenerateContent") url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent")
if opts.Alt == "" { if opts.Alt == "" {
url = url + "?alt=sse" url = url + "?alt=sse"
} else { } else {
@@ -287,19 +293,25 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
apiKey, bearer := geminiCreds(auth) apiKey, bearer := geminiCreds(auth)
model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model) translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model)
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
respCtx := context.WithValue(ctx, "alt", opts.Alt) respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
baseURL := resolveGeminiBaseURL(auth) baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, req.Model, "countTokens") url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens")
requestBody := bytes.NewReader(translatedReq) requestBody := bytes.NewReader(translatedReq)
@@ -398,6 +410,90 @@ func resolveGeminiBaseURL(auth *cliproxyauth.Auth) string {
return base return base
} }
func (e *GeminiExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveGeminiConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
func (e *GeminiExecutor) resolveGeminiConfig(auth *cliproxyauth.Auth) *config.GeminiKey {
if auth == nil || e.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range e.cfg.GeminiKey {
entry := &e.cfg.GeminiKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range e.cfg.GeminiKey {
entry := &e.cfg.GeminiKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}
func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) { func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) {
var attrs map[string]string var attrs map[string]string
if auth != nil { if auth != nil {

View File

@@ -120,8 +120,6 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
@@ -137,7 +135,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", req.Model)
action := "generateContent" action := "generateContent"
if req.Metadata != nil { if req.Metadata != nil {
@@ -146,7 +144,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
} }
} }
baseURL := vertexBaseURL(location) baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, action) url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
if opts.Alt != "" && action != "countTokens" { if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt) url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
} }
@@ -220,24 +218,27 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil { if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm budgetOverride = &norm
} }
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
} }
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(req.Model, body) body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
action := "generateContent" action := "generateContent"
if req.Metadata != nil { if req.Metadata != nil {
@@ -250,7 +251,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
if baseURL == "" { if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com" baseURL = "https://generativelanguage.googleapis.com"
} }
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, action) url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action)
if opts.Alt != "" && action != "countTokens" { if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt) url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
} }
@@ -321,8 +322,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
@@ -338,10 +337,10 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", req.Model)
baseURL := vertexBaseURL(location) baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "streamGenerateContent") url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
if opts.Alt == "" { if opts.Alt == "" {
url = url + "?alt=sse" url = url + "?alt=sse"
} else { } else {
@@ -438,30 +437,33 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil { if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm budgetOverride = &norm
} }
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
} }
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(req.Model, body) body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
// For API key auth, use simpler URL format without project/location // For API key auth, use simpler URL format without project/location
if baseURL == "" { if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com" baseURL = "https://generativelanguage.googleapis.com"
} }
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "streamGenerateContent") url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "streamGenerateContent")
if opts.Alt == "" { if opts.Alt == "" {
url = url + "?alt=sse" url = url + "?alt=sse"
} else { } else {
@@ -552,8 +554,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
// countTokensWithServiceAccount counts tokens using service account credentials. // countTokensWithServiceAccount counts tokens using service account credentials.
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) { func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
@@ -566,14 +566,14 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
} }
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model)
respCtx := context.WithValue(ctx, "alt", opts.Alt) respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
baseURL := vertexBaseURL(location) baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens") url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil { if errNewReq != nil {
@@ -641,21 +641,24 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
// countTokensWithAPIKey handles token counting using API key credentials. // countTokensWithAPIKey handles token counting using API key credentials.
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) { func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil { if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm budgetOverride = &norm
} }
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
} }
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
respCtx := context.WithValue(ctx, "alt", opts.Alt) respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
@@ -665,7 +668,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
if baseURL == "" { if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com" baseURL = "https://generativelanguage.googleapis.com"
} }
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens") url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil { if errNewReq != nil {
@@ -808,3 +811,90 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau
} }
return tok.AccessToken, nil return tok.AccessToken, nil
} }
// resolveUpstreamModel resolves the upstream model name from vertex-api-key configuration.
// It matches the requested model alias against configured models and returns the actual upstream name.
func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveVertexConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth.
func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey {
if auth == nil || e.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range e.cfg.VertexCompatAPIKey {
entry := &e.cfg.VertexCompatAPIKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range e.cfg.VertexCompatAPIKey {
entry := &e.cfg.VertexCompatAPIKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}

View File

@@ -58,15 +58,13 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) body, _ = sjson.SetBytes(body, "model", req.Model)
if upstreamModel != "" { body = NormalizeThinkingConfig(body, req.Model, false)
body, _ = sjson.SetBytes(body, "model", upstreamModel) if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
return resp, errValidate return resp, errValidate
} }
body = applyIFlowThinkingConfig(body) body = applyIFlowThinkingConfig(body)
body = preserveReasoningContentInMessages(body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -150,15 +148,13 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) body, _ = sjson.SetBytes(body, "model", req.Model)
if upstreamModel != "" { body = NormalizeThinkingConfig(body, req.Model, false)
body, _ = sjson.SetBytes(body, "model", upstreamModel) if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
return nil, errValidate return nil, errValidate
} }
body = applyIFlowThinkingConfig(body) body = applyIFlowThinkingConfig(body)
body = preserveReasoningContentInMessages(body)
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour. // Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
toolsResult := gjson.GetBytes(body, "tools") toolsResult := gjson.GetBytes(body, "tools")
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 { if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
@@ -445,20 +441,98 @@ func ensureToolsArray(body []byte) []byte {
return updated return updated
} }
// applyIFlowThinkingConfig converts normalized reasoning_effort to iFlow chat_template_kwargs.enable_thinking. // preserveReasoningContentInMessages ensures reasoning_content from assistant messages in the
// This should be called after NormalizeThinkingConfig has processed the payload. // conversation history is preserved when sending to iFlow models that support thinking.
// iFlow only supports boolean enable_thinking, so any non-"none" effort enables thinking. // This is critical for multi-turn conversations where the model needs to see its previous
func applyIFlowThinkingConfig(body []byte) []byte { // reasoning to maintain coherent thought chains across tool calls and conversation turns.
effort := gjson.GetBytes(body, "reasoning_effort") //
if !effort.Exists() { // For GLM-4.7 and MiniMax-M2.1, the full assistant response (including reasoning) must be
// appended back into message history before the next call.
func preserveReasoningContentInMessages(body []byte) []byte {
model := strings.ToLower(gjson.GetBytes(body, "model").String())
// Only apply to models that support thinking with history preservation
needsPreservation := strings.HasPrefix(model, "glm-4.7") ||
strings.HasPrefix(model, "glm-4-7") ||
strings.HasPrefix(model, "minimax-m2.1") ||
strings.HasPrefix(model, "minimax-m2-1")
if !needsPreservation {
return body return body
} }
val := strings.ToLower(strings.TrimSpace(effort.String())) messages := gjson.GetBytes(body, "messages")
enableThinking := val != "none" && val != "" if !messages.Exists() || !messages.IsArray() {
return body
}
body, _ = sjson.DeleteBytes(body, "reasoning_effort") // Check if any assistant message already has reasoning_content preserved
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking) hasReasoningContent := false
messages.ForEach(func(_, msg gjson.Result) bool {
role := msg.Get("role").String()
if role == "assistant" {
rc := msg.Get("reasoning_content")
if rc.Exists() && rc.String() != "" {
hasReasoningContent = true
return false // stop iteration
}
}
return true
})
// If reasoning content is already present, the messages are properly formatted
// No need to modify - the client has correctly preserved reasoning in history
if hasReasoningContent {
log.Debugf("iflow executor: reasoning_content found in message history for %s", model)
}
return body
}
// applyIFlowThinkingConfig converts normalized reasoning_effort to model-specific thinking configurations.
// This should be called after NormalizeThinkingConfig has processed the payload.
//
// Model-specific handling:
// - GLM-4.7: Uses extra_body={"thinking": {"type": "enabled"}, "clear_thinking": false}
// - MiniMax-M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
// - Other iFlow models: Uses chat_template_kwargs.enable_thinking (boolean)
func applyIFlowThinkingConfig(body []byte) []byte {
effort := gjson.GetBytes(body, "reasoning_effort")
model := strings.ToLower(gjson.GetBytes(body, "model").String())
// Check if thinking should be enabled
val := ""
if effort.Exists() {
val = strings.ToLower(strings.TrimSpace(effort.String()))
}
enableThinking := effort.Exists() && val != "none" && val != ""
// Remove reasoning_effort as we'll convert to model-specific format
if effort.Exists() {
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
}
// GLM-4.7: Use extra_body with thinking config and clear_thinking: false
if strings.HasPrefix(model, "glm-4.7") || strings.HasPrefix(model, "glm-4-7") {
if enableThinking {
body, _ = sjson.SetBytes(body, "extra_body.thinking.type", "enabled")
body, _ = sjson.SetBytes(body, "extra_body.clear_thinking", false)
}
return body
}
// MiniMax-M2.1: Use reasoning_split=true for interleaved thinking
if strings.HasPrefix(model, "minimax-m2.1") || strings.HasPrefix(model, "minimax-m2-1") {
if enableThinking {
body, _ = sjson.SetBytes(body, "reasoning_split", true)
}
return body
}
// Other iFlow models (including GLM-4.6): Use chat_template_kwargs.enable_thinking
if effort.Exists() {
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
}
return body return body
} }

View File

@@ -10,6 +10,8 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"os"
"path/filepath"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -43,10 +45,15 @@ const (
// Event Stream error type constants // Event Stream error type constants
ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable
ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed
// kiroUserAgent matches amq2api format for User-Agent header // kiroUserAgent matches amq2api format for User-Agent header (Amazon Q CLI style)
kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0" kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0"
// kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api // kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api (Amazon Q CLI style)
kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI" kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
// Kiro IDE style headers (from kiro2api - for IDC auth)
kiroIDEUserAgent = "aws-sdk-js/1.0.18 ua/2.1 os/darwin#25.0.0 lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"
kiroIDEAmzUserAgent = "aws-sdk-js/1.0.18 KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"
kiroIDEAgentModeSpec = "spec"
) )
// Real-time usage estimation configuration // Real-time usage estimation configuration
@@ -101,11 +108,24 @@ var kiroEndpointConfigs = []kiroEndpointConfig{
// getKiroEndpointConfigs returns the list of Kiro API endpoint configurations to try in order. // getKiroEndpointConfigs returns the list of Kiro API endpoint configurations to try in order.
// Supports reordering based on "preferred_endpoint" in auth metadata/attributes. // Supports reordering based on "preferred_endpoint" in auth metadata/attributes.
// For IDC auth method, automatically uses CodeWhisperer endpoint with CLI origin.
func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
if auth == nil { if auth == nil {
return kiroEndpointConfigs return kiroEndpointConfigs
} }
// For IDC auth, use CodeWhisperer endpoint with AI_EDITOR origin (same as Social auth)
// Based on kiro2api analysis: IDC tokens work with CodeWhisperer endpoint using Bearer auth
// The difference is only in how tokens are refreshed (OIDC with clientId/clientSecret for IDC)
// NOT in how API calls are made - both Social and IDC use the same endpoint/origin
if auth.Metadata != nil {
authMethod, _ := auth.Metadata["auth_method"].(string)
if authMethod == "idc" {
log.Debugf("kiro: IDC auth, using CodeWhisperer endpoint")
return kiroEndpointConfigs
}
}
// Check for preference // Check for preference
var preference string var preference string
if auth.Metadata != nil { if auth.Metadata != nil {
@@ -162,6 +182,15 @@ type KiroExecutor struct {
refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions
} }
// isIDCAuth checks if the auth uses IDC (Identity Center) authentication method.
func isIDCAuth(auth *cliproxyauth.Auth) bool {
if auth == nil || auth.Metadata == nil {
return false
}
authMethod, _ := auth.Metadata["auth_method"].(string)
return authMethod == "idc"
}
// buildKiroPayloadForFormat builds the Kiro API payload based on the source format. // buildKiroPayloadForFormat builds the Kiro API payload based on the source format.
// This is critical because OpenAI and Claude formats have different tool structures: // This is critical because OpenAI and Claude formats have different tool structures:
// - OpenAI: tools[].function.name, tools[].function.description // - OpenAI: tools[].function.name, tools[].function.description
@@ -210,6 +239,10 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
} else if refreshedAuth != nil { } else if refreshedAuth != nil {
auth = refreshedAuth auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
}
accessToken, profileArn = kiroCredentials(auth) accessToken, profileArn = kiroCredentials(auth)
log.Infof("kiro: token refreshed successfully before request") log.Infof("kiro: token refreshed successfully before request")
} }
@@ -262,15 +295,28 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
} }
httpReq.Header.Set("Content-Type", kiroContentType) httpReq.Header.Set("Content-Type", kiroContentType)
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
httpReq.Header.Set("Accept", kiroAcceptStream) httpReq.Header.Set("Accept", kiroAcceptStream)
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
httpReq.Header.Set("User-Agent", kiroUserAgent)
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) // Use different headers based on auth type
// IDC auth uses Kiro IDE style headers (from kiro2api)
// Other auth types use Amazon Q CLI style headers
if isIDCAuth(auth) {
httpReq.Header.Set("User-Agent", kiroIDEUserAgent)
httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent)
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
log.Debugf("kiro: using Kiro IDE headers for IDC auth")
} else {
httpReq.Header.Set("User-Agent", kiroUserAgent)
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
}
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
// Bearer token authentication for all auth types (Builder ID, IDC, social, etc.)
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
var attrs map[string]string var attrs map[string]string
if auth != nil { if auth != nil {
attrs = auth.Attributes attrs = auth.Attributes
@@ -358,6 +404,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
if refreshedAuth != nil { if refreshedAuth != nil {
auth = refreshedAuth auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
// Continue anyway - the token is valid for this request
}
accessToken, profileArn = kiroCredentials(auth) accessToken, profileArn = kiroCredentials(auth)
// Rebuild payload with new profile ARN if changed // Rebuild payload with new profile ARN if changed
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
@@ -416,6 +467,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
} }
if refreshedAuth != nil { if refreshedAuth != nil {
auth = refreshedAuth auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
// Continue anyway - the token is valid for this request
}
accessToken, profileArn = kiroCredentials(auth) accessToken, profileArn = kiroCredentials(auth)
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
log.Infof("kiro: token refreshed for 403, retrying request") log.Infof("kiro: token refreshed for 403, retrying request")
@@ -518,6 +574,10 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
} else if refreshedAuth != nil { } else if refreshedAuth != nil {
auth = refreshedAuth auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
}
accessToken, profileArn = kiroCredentials(auth) accessToken, profileArn = kiroCredentials(auth)
log.Infof("kiro: token refreshed successfully before stream request") log.Infof("kiro: token refreshed successfully before stream request")
} }
@@ -568,15 +628,28 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
} }
httpReq.Header.Set("Content-Type", kiroContentType) httpReq.Header.Set("Content-Type", kiroContentType)
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
httpReq.Header.Set("Accept", kiroAcceptStream) httpReq.Header.Set("Accept", kiroAcceptStream)
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
httpReq.Header.Set("User-Agent", kiroUserAgent)
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) // Use different headers based on auth type
// IDC auth uses Kiro IDE style headers (from kiro2api)
// Other auth types use Amazon Q CLI style headers
if isIDCAuth(auth) {
httpReq.Header.Set("User-Agent", kiroIDEUserAgent)
httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent)
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
log.Debugf("kiro: using Kiro IDE headers for IDC auth")
} else {
httpReq.Header.Set("User-Agent", kiroUserAgent)
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
}
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
// Bearer token authentication for all auth types (Builder ID, IDC, social, etc.)
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
var attrs map[string]string var attrs map[string]string
if auth != nil { if auth != nil {
attrs = auth.Attributes attrs = auth.Attributes
@@ -677,6 +750,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
if refreshedAuth != nil { if refreshedAuth != nil {
auth = refreshedAuth auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
// Continue anyway - the token is valid for this request
}
accessToken, profileArn = kiroCredentials(auth) accessToken, profileArn = kiroCredentials(auth)
// Rebuild payload with new profile ARN if changed // Rebuild payload with new profile ARN if changed
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
@@ -735,6 +813,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
} }
if refreshedAuth != nil { if refreshedAuth != nil {
auth = refreshedAuth auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
// Continue anyway - the token is valid for this request
}
accessToken, profileArn = kiroCredentials(auth) accessToken, profileArn = kiroCredentials(auth)
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
log.Infof("kiro: token refreshed for 403, retrying stream request") log.Infof("kiro: token refreshed for 403, retrying stream request")
@@ -1001,12 +1084,12 @@ func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string {
// This consolidates the auth_method check that was previously done separately. // This consolidates the auth_method check that was previously done separately.
func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string { func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string {
if auth != nil && auth.Metadata != nil { if auth != nil && auth.Metadata != nil {
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
// builder-id auth doesn't need profileArn // builder-id and idc auth don't need profileArn
return "" return ""
} }
} }
// For non-builder-id auth (social auth), profileArn is required // For non-builder-id/idc auth (social auth), profileArn is required
if profileArn == "" { if profileArn == "" {
log.Warnf("kiro: profile ARN not found in auth, API calls may fail") log.Warnf("kiro: profile ARN not found in auth, API calls may fail")
} }
@@ -3010,6 +3093,7 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
var refreshToken string var refreshToken string
var clientID, clientSecret string var clientID, clientSecret string
var authMethod string var authMethod string
var region, startURL string
if auth.Metadata != nil { if auth.Metadata != nil {
if rt, ok := auth.Metadata["refresh_token"].(string); ok { if rt, ok := auth.Metadata["refresh_token"].(string); ok {
@@ -3024,6 +3108,12 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
if am, ok := auth.Metadata["auth_method"].(string); ok { if am, ok := auth.Metadata["auth_method"].(string); ok {
authMethod = am authMethod = am
} }
if r, ok := auth.Metadata["region"].(string); ok {
region = r
}
if su, ok := auth.Metadata["start_url"].(string); ok {
startURL = su
}
} }
if refreshToken == "" { if refreshToken == "" {
@@ -3033,12 +3123,20 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
var tokenData *kiroauth.KiroTokenData var tokenData *kiroauth.KiroTokenData
var err error var err error
// Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint ssoClient := kiroauth.NewSSOOIDCClient(e.cfg)
if clientID != "" && clientSecret != "" && authMethod == "builder-id" {
// Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint
switch {
case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
// IDC refresh with region-specific endpoint
log.Debugf("kiro executor: using SSO OIDC refresh for IDC (region=%s)", region)
tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
case clientID != "" && clientSecret != "" && authMethod == "builder-id":
// Builder ID refresh with default endpoint
log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID") log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID")
ssoClient := kiroauth.NewSSOOIDCClient(e.cfg)
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
} else { default:
// Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub)
log.Debugf("kiro executor: using Kiro OAuth refresh endpoint") log.Debugf("kiro executor: using Kiro OAuth refresh endpoint")
oauth := kiroauth.NewKiroOAuth(e.cfg) oauth := kiroauth.NewKiroOAuth(e.cfg)
tokenData, err = oauth.RefreshToken(ctx, refreshToken) tokenData, err = oauth.RefreshToken(ctx, refreshToken)
@@ -3094,6 +3192,53 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
return updated, nil return updated, nil
} }
// persistRefreshedAuth persists a refreshed auth record to disk.
// This ensures token refreshes from inline retry are saved to the auth file.
func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error {
if auth == nil || auth.Metadata == nil {
return fmt.Errorf("kiro executor: cannot persist nil auth or metadata")
}
// Determine the file path from auth attributes or filename
var authPath string
if auth.Attributes != nil {
if p := strings.TrimSpace(auth.Attributes["path"]); p != "" {
authPath = p
}
}
if authPath == "" {
fileName := strings.TrimSpace(auth.FileName)
if fileName == "" {
return fmt.Errorf("kiro executor: auth has no file path or filename")
}
if filepath.IsAbs(fileName) {
authPath = fileName
} else if e.cfg != nil && e.cfg.AuthDir != "" {
authPath = filepath.Join(e.cfg.AuthDir, fileName)
} else {
return fmt.Errorf("kiro executor: cannot determine auth file path")
}
}
// Marshal metadata to JSON
raw, err := json.Marshal(auth.Metadata)
if err != nil {
return fmt.Errorf("kiro executor: marshal metadata failed: %w", err)
}
// Write to temp file first, then rename (atomic write)
tmp := authPath + ".tmp"
if err := os.WriteFile(tmp, raw, 0o600); err != nil {
return fmt.Errorf("kiro executor: write temp auth file failed: %w", err)
}
if err := os.Rename(tmp, authPath); err != nil {
return fmt.Errorf("kiro executor: rename auth file failed: %w", err)
}
log.Debugf("kiro executor: persisted refreshed auth to %s", authPath)
return nil
}
// isTokenExpired checks if a JWT access token has expired. // isTokenExpired checks if a JWT access token has expired.
// Returns true if the token is expired or cannot be parsed. // Returns true if the token is expired or cannot be parsed.
func (e *KiroExecutor) isTokenExpired(accessToken string) bool { func (e *KiroExecutor) isTokenExpired(accessToken string) bool {

View File

@@ -61,12 +61,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
allowCompat := e.allowCompatReasoningEffort(req.Model, auth) allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat) translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
if upstreamModel != "" && modelOverride == "" { if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
}
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
return resp, errValidate return resp, errValidate
} }
@@ -157,12 +153,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
allowCompat := e.allowCompatReasoningEffort(req.Model, auth) allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat) translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
if upstreamModel != "" && modelOverride == "" { if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
}
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
return nil, errValidate return nil, errValidate
} }

View File

@@ -12,7 +12,6 @@ import (
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -52,12 +51,9 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) body, _ = sjson.SetBytes(body, "model", req.Model)
if upstreamModel != "" { body = NormalizeThinkingConfig(body, req.Model, false)
body, _ = sjson.SetBytes(body, "model", upstreamModel) if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
return resp, errValidate return resp, errValidate
} }
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body)
@@ -132,12 +128,9 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) body, _ = sjson.SetBytes(body, "model", req.Model)
if upstreamModel != "" { body = NormalizeThinkingConfig(body, req.Model, false)
body, _ = sjson.SetBytes(body, "model", upstreamModel) if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
return nil, errValidate return nil, errValidate
} }
toolsResult := gjson.GetBytes(body, "tools") toolsResult := gjson.GetBytes(body, "tools")

View File

@@ -19,7 +19,7 @@ type usageReporter struct {
provider string provider string
model string model string
authID string authID string
authIndex uint64 authIndex string
apiKey string apiKey string
source string source string
requestedAt time.Time requestedAt time.Time
@@ -275,6 +275,20 @@ func parseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
return detail, true return detail, true
} }
func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
CachedTokens: node.Get("cachedContentTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
}
func parseGeminiCLIUsage(data []byte) usage.Detail { func parseGeminiCLIUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data) usageNode := gjson.ParseBytes(data)
node := usageNode.Get("response.usageMetadata") node := usageNode.Get("response.usageMetadata")
@@ -284,16 +298,7 @@ func parseGeminiCLIUsage(data []byte) usage.Detail {
if !node.Exists() { if !node.Exists() {
return usage.Detail{} return usage.Detail{}
} }
detail := usage.Detail{ return parseGeminiFamilyUsageDetail(node)
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
} }
func parseGeminiUsage(data []byte) usage.Detail { func parseGeminiUsage(data []byte) usage.Detail {
@@ -305,16 +310,7 @@ func parseGeminiUsage(data []byte) usage.Detail {
if !node.Exists() { if !node.Exists() {
return usage.Detail{} return usage.Detail{}
} }
detail := usage.Detail{ return parseGeminiFamilyUsageDetail(node)
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
} }
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) { func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
@@ -329,16 +325,7 @@ func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
if !node.Exists() { if !node.Exists() {
return usage.Detail{}, false return usage.Detail{}, false
} }
detail := usage.Detail{ return parseGeminiFamilyUsageDetail(node), true
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail, true
} }
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) { func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
@@ -353,16 +340,7 @@ func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
if !node.Exists() { if !node.Exists() {
return usage.Detail{}, false return usage.Detail{}, false
} }
detail := usage.Detail{ return parseGeminiFamilyUsageDetail(node), true
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail, true
} }
func parseAntigravityUsage(data []byte) usage.Detail { func parseAntigravityUsage(data []byte) usage.Detail {
@@ -377,16 +355,7 @@ func parseAntigravityUsage(data []byte) usage.Detail {
if !node.Exists() { if !node.Exists() {
return usage.Detail{} return usage.Detail{}
} }
detail := usage.Detail{ return parseGeminiFamilyUsageDetail(node)
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
} }
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) { func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
@@ -404,16 +373,7 @@ func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
if !node.Exists() { if !node.Exists() {
return usage.Detail{}, false return usage.Detail{}, false
} }
detail := usage.Detail{ return parseGeminiFamilyUsageDetail(node), true
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail, true
} }
var stopChunkWithoutUsage sync.Map var stopChunkWithoutUsage sync.Map
@@ -522,12 +482,16 @@ func StripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
cleaned := jsonBytes cleaned := jsonBytes
var changed bool var changed bool
if gjson.GetBytes(cleaned, "usageMetadata").Exists() { if usageMetadata = gjson.GetBytes(cleaned, "usageMetadata"); usageMetadata.Exists() {
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
cleaned, _ = sjson.SetRawBytes(cleaned, "cpaUsageMetadata", []byte(usageMetadata.Raw))
cleaned, _ = sjson.DeleteBytes(cleaned, "usageMetadata") cleaned, _ = sjson.DeleteBytes(cleaned, "usageMetadata")
changed = true changed = true
} }
if gjson.GetBytes(cleaned, "response.usageMetadata").Exists() { if usageMetadata = gjson.GetBytes(cleaned, "response.usageMetadata"); usageMetadata.Exists() {
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
cleaned, _ = sjson.SetRawBytes(cleaned, "response.cpaUsageMetadata", []byte(usageMetadata.Raw))
cleaned, _ = sjson.DeleteBytes(cleaned, "response.usageMetadata") cleaned, _ = sjson.DeleteBytes(cleaned, "response.usageMetadata")
changed = true changed = true
} }

View File

@@ -86,6 +86,10 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
hasSystemInstruction = true hasSystemInstruction = true
} }
} }
} else if systemResult.Type == gjson.String {
systemInstructionJSON = `{"role":"user","parts":[{"text":""}]}`
systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.0.text", systemResult.String())
hasSystemInstruction = true
} }
// contents // contents
@@ -121,32 +125,31 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Use GetThinkingText to handle wrapped thinking objects // Use GetThinkingText to handle wrapped thinking objects
thinkingText := util.GetThinkingText(contentResult) thinkingText := util.GetThinkingText(contentResult)
signatureResult := contentResult.Get("signature") signatureResult := contentResult.Get("signature")
clientSignature := "" clientSignature := ""
if signatureResult.Exists() && signatureResult.String() != "" { if signatureResult.Exists() && signatureResult.String() != "" {
clientSignature = signatureResult.String() clientSignature = signatureResult.String()
}
// Always try cached signature first (more reliable than client-provided)
// Client may send stale or invalid signatures from different sessions
signature := ""
if sessionID != "" && thinkingText != "" {
if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" {
signature = cachedSig
log.Debugf("Using cached signature for thinking block")
} }
}
// Fallback to client signature only if cache miss and client signature is valid // Always try cached signature first (more reliable than client-provided)
if signature == "" && cache.HasValidSignature(clientSignature) { // Client may send stale or invalid signatures from different sessions
signature = clientSignature signature := ""
log.Debugf("Using client-provided signature for thinking block") if sessionID != "" && thinkingText != "" {
} if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" {
signature = cachedSig
log.Debugf("Using cached signature for thinking block")
}
}
// Store for subsequent tool_use in the same message // Fallback to client signature only if cache miss and client signature is valid
if cache.HasValidSignature(signature) { if signature == "" && cache.HasValidSignature(clientSignature) {
currentMessageThinkingSignature = signature signature = clientSignature
} log.Debugf("Using client-provided signature for thinking block")
}
// Store for subsequent tool_use in the same message
if cache.HasValidSignature(signature) {
currentMessageThinkingSignature = signature
}
// Skip trailing unsigned thinking blocks on last assistant message // Skip trailing unsigned thinking blocks on last assistant message
isUnsigned := !cache.HasValidSignature(signature) isUnsigned := !cache.HasValidSignature(signature)
@@ -321,6 +324,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// tools // tools
toolsJSON := "" toolsJSON := ""
toolDeclCount := 0 toolDeclCount := 0
allowedToolKeys := []string{"name", "description", "behavior", "parameters", "parametersJsonSchema", "response", "responseJsonSchema"}
toolsResult := gjson.GetBytes(rawJSON, "tools") toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() { if toolsResult.IsArray() {
toolsJSON = `[{"functionDeclarations":[]}]` toolsJSON = `[{"functionDeclarations":[]}]`
@@ -333,10 +337,12 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw) inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
tool, _ := sjson.Delete(toolResult.Raw, "input_schema") tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema) tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
tool, _ = sjson.Delete(tool, "strict") for toolKey := range gjson.Parse(tool).Map() {
tool, _ = sjson.Delete(tool, "input_examples") if util.InArray(allowedToolKeys, toolKey) {
tool, _ = sjson.Delete(tool, "type") continue
tool, _ = sjson.Delete(tool, "cache_control") }
tool, _ = sjson.Delete(tool, toolKey)
}
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool) toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
toolDeclCount++ toolDeclCount++
} }

View File

@@ -35,6 +35,7 @@ type Params struct {
CandidatesTokenCount int64 // Cached candidate token count from usage metadata CandidatesTokenCount int64 // Cached candidate token count from usage metadata
ThoughtsTokenCount int64 // Cached thinking token count from usage metadata ThoughtsTokenCount int64 // Cached thinking token count from usage metadata
TotalTokenCount int64 // Cached total token count from usage metadata TotalTokenCount int64 // Cached total token count from usage metadata
CachedTokenCount int64 // Cached content token count (indicates prompt caching)
HasSentFinalEvents bool // Indicates if final content/message events have been sent HasSentFinalEvents bool // Indicates if final content/message events have been sent
HasToolUse bool // Indicates if tool use was observed in the stream HasToolUse bool // Indicates if tool use was observed in the stream
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
@@ -98,6 +99,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// This follows the Claude Code API specification for streaming message initialization // This follows the Claude Code API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}` messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
// Use cpaUsageMetadata within the message_start event for Claude.
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
}
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
}
// Override default values with actual response metadata if available from the Gemini CLI response // Override default values with actual response metadata if available from the Gemini CLI response
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String()) messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
@@ -270,7 +279,8 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() { if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
params.HasUsageMetadata = true params.HasUsageMetadata = true
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int() params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int()
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int() - params.CachedTokenCount
params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int() params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int()
params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int() params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int()
params.TotalTokenCount = usageResult.Get("totalTokenCount").Int() params.TotalTokenCount = usageResult.Get("totalTokenCount").Int()
@@ -322,6 +332,14 @@ func appendFinalEvents(params *Params, output *string, force bool) {
*output = *output + "event: message_delta\n" *output = *output + "event: message_delta\n"
*output = *output + "data: " *output = *output + "data: "
delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens) delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
if params.CachedTokenCount > 0 {
var err error
delta, err = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
if err != nil {
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
}
}
*output = *output + delta + "\n\n\n" *output = *output + delta + "\n\n\n"
params.HasSentFinalEvents = true params.HasSentFinalEvents = true
@@ -361,6 +379,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
candidateTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int() candidateTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int()
thoughtTokens := root.Get("response.usageMetadata.thoughtsTokenCount").Int() thoughtTokens := root.Get("response.usageMetadata.thoughtsTokenCount").Int()
totalTokens := root.Get("response.usageMetadata.totalTokenCount").Int() totalTokens := root.Get("response.usageMetadata.totalTokenCount").Int()
cachedTokens := root.Get("response.usageMetadata.cachedContentTokenCount").Int()
outputTokens := candidateTokens + thoughtTokens outputTokens := candidateTokens + thoughtTokens
if outputTokens == 0 && totalTokens > 0 { if outputTokens == 0 && totalTokens > 0 {
outputTokens = totalTokens - promptTokens outputTokens = totalTokens - promptTokens
@@ -374,6 +393,14 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String()) responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens) responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens) responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
if cachedTokens > 0 {
var err error
responseJSON, err = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
if err != nil {
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
}
}
contentArrayInitialized := false contentArrayInitialized := false
ensureContentArray := func() { ensureContentArray := func() {
@@ -452,7 +479,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter)) toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "name", name) toolBlock, _ = sjson.Set(toolBlock, "name", name)
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) { if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) && args.IsObject() {
toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw) toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw)
} }

View File

@@ -7,7 +7,6 @@ package gemini
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
@@ -135,41 +134,31 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
// FunctionCallGroup represents a group of function calls and their responses // FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct { type FunctionCallGroup struct {
ModelContent map[string]interface{}
FunctionCalls []gjson.Result
ResponsesNeeded int ResponsesNeeded int
} }
// parseFunctionResponse attempts to unmarshal a function response part. // parseFunctionResponseRaw attempts to normalize a function response part into a JSON object string.
// Falls back to gjson extraction if standard json.Unmarshal fails. // Falls back to a minimal "functionResponse" object when parsing fails.
func parseFunctionResponse(response gjson.Result) map[string]interface{} { func parseFunctionResponseRaw(response gjson.Result) string {
var responseMap map[string]interface{} if response.IsObject() && gjson.Valid(response.Raw) {
err := json.Unmarshal([]byte(response.Raw), &responseMap) return response.Raw
if err == nil {
return responseMap
} }
log.Debugf("unmarshal function response failed, using fallback: %v", err) log.Debugf("parse function response failed, using fallback")
funcResp := response.Get("functionResponse") funcResp := response.Get("functionResponse")
if funcResp.Exists() { if funcResp.Exists() {
fr := map[string]interface{}{ fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
"name": funcResp.Get("name").String(), fr, _ = sjson.Set(fr, "functionResponse.name", funcResp.Get("name").String())
"response": map[string]interface{}{ fr, _ = sjson.Set(fr, "functionResponse.response.result", funcResp.Get("response").String())
"result": funcResp.Get("response").String(),
},
}
if id := funcResp.Get("id").String(); id != "" { if id := funcResp.Get("id").String(); id != "" {
fr["id"] = id fr, _ = sjson.Set(fr, "functionResponse.id", id)
} }
return map[string]interface{}{"functionResponse": fr} return fr
} }
return map[string]interface{}{ fr := `{"functionResponse":{"name":"unknown","response":{"result":""}}}`
"functionResponse": map[string]interface{}{ fr, _ = sjson.Set(fr, "functionResponse.response.result", response.String())
"name": "unknown", return fr
"response": map[string]interface{}{"result": response.String()},
},
}
} }
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping. // fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
@@ -196,7 +185,7 @@ func fixCLIToolResponse(input string) (string, error) {
} }
// Initialize data structures for processing and grouping // Initialize data structures for processing and grouping
var newContents []interface{} // Final processed contents array contentsWrapper := `{"contents":[]}`
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched var collectedResponses []gjson.Result // Standalone responses to be matched
@@ -228,17 +217,16 @@ func fixCLIToolResponse(input string) (string, error) {
collectedResponses = collectedResponses[group.ResponsesNeeded:] collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content // Create merged function response content
var responseParts []interface{} functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses { for _, response := range groupResponses {
responseParts = append(responseParts, parseFunctionResponse(response)) partRaw := parseFunctionResponseRaw(response)
if partRaw != "" {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
}
} }
if len(responseParts) > 0 { if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
functionResponseContent := map[string]interface{}{ contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
} }
// Remove this group as it's been satisfied // Remove this group as it's been satisfied
@@ -252,50 +240,42 @@ func fixCLIToolResponse(input string) (string, error) {
// If this is a model with function calls, create a new group // If this is a model with function calls, create a new group
if role == "model" { if role == "model" {
var functionCallsInThisModel []gjson.Result functionCallsCount := 0
parts.ForEach(func(_, part gjson.Result) bool { parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() { if part.Get("functionCall").Exists() {
functionCallsInThisModel = append(functionCallsInThisModel, part) functionCallsCount++
} }
return true return true
}) })
if len(functionCallsInThisModel) > 0 { if functionCallsCount > 0 {
// Add the model content // Add the model content
var contentMap map[string]interface{} if !value.IsObject() {
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) log.Warnf("failed to parse model content")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
return true return true
} }
newContents = append(newContents, contentMap) contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
// Create a new group for tracking responses // Create a new group for tracking responses
group := &FunctionCallGroup{ group := &FunctionCallGroup{
ModelContent: contentMap, ResponsesNeeded: functionCallsCount,
FunctionCalls: functionCallsInThisModel,
ResponsesNeeded: len(functionCallsInThisModel),
} }
pendingGroups = append(pendingGroups, group) pendingGroups = append(pendingGroups, group)
} else { } else {
// Regular model content without function calls // Regular model content without function calls
var contentMap map[string]interface{} if !value.IsObject() {
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) log.Warnf("failed to parse content")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true return true
} }
newContents = append(newContents, contentMap) contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
} }
} else { } else {
// Non-model content (user, etc.) // Non-model content (user, etc.)
var contentMap map[string]interface{} if !value.IsObject() {
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) log.Warnf("failed to parse content")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true return true
} }
newContents = append(newContents, contentMap) contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
} }
return true return true
@@ -307,25 +287,23 @@ func fixCLIToolResponse(input string) (string, error) {
groupResponses := collectedResponses[:group.ResponsesNeeded] groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:] collectedResponses = collectedResponses[group.ResponsesNeeded:]
var responseParts []interface{} functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses { for _, response := range groupResponses {
responseParts = append(responseParts, parseFunctionResponse(response)) partRaw := parseFunctionResponseRaw(response)
if partRaw != "" {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
}
} }
if len(responseParts) > 0 { if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
functionResponseContent := map[string]interface{}{ contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
} }
} }
} }
// Update the original JSON with the new contents // Update the original JSON with the new contents
result := input result := input
newContentsJSON, _ := json.Marshal(newContents) result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
return result, nil return result, nil
} }

View File

@@ -192,6 +192,14 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else if content.IsObject() && content.Get("type").String() == "text" { } else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user") out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.Get("text").String()) out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.Get("text").String())
} else if content.IsArray() {
contents := content.Array()
if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", j), contents[j].Get("text").String())
}
}
} }
} else if role == "user" || (role == "system" && len(arr) == 1) { } else if role == "user" || (role == "system" && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents // Build single user content node to avoid splitting into multiple contents
@@ -239,10 +247,30 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else if role == "assistant" { } else if role == "assistant" {
node := []byte(`{"role":"model","parts":[]}`) node := []byte(`{"role":"model","parts":[]}`)
p := 0 p := 0
if content.Type == gjson.String { if content.Type == gjson.String && content.String() != "" {
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String()) node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
p++ p++
} else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.
imageURL := item.Get("image_url.url").String()
if len(imageURL) > 5 { // expect data:...
pieces := strings.SplitN(imageURL[5:], ";", 2)
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
p++
}
}
}
}
} }
// Tool calls -> single model content with functionCall parts // Tool calls -> single model content with functionCall parts
@@ -258,7 +286,11 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
fargs := tc.Get("function.arguments").String() fargs := tc.Get("function.arguments").String()
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs)) if gjson.Valid(fargs) {
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
} else {
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.args.params", []byte(fargs))
}
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++ p++
if fid != "" { if fid != "" {
@@ -293,6 +325,8 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if pp > 0 { if pp > 0 {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode) out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
} }
} else {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
} }
} }
} }
@@ -319,7 +353,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{}) fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
@@ -334,7 +368,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{}) fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue

View File

@@ -8,12 +8,13 @@ package chat_completions
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
log "github.com/sirupsen/logrus"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions" . "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
@@ -86,18 +87,27 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
// Extract and set usage metadata (token counts). // Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() { if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
} }
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
} }
promptTokenCount := usageResult.Get("promptTokenCount").Int() promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 { if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
} }
// Include cached token count if present (indicates prompt caching is working)
if cachedTokenCount > 0 {
var err error
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
if err != nil {
log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
}
}
} }
// Process the main content part of the response. // Process the main content part of the response.
@@ -171,21 +181,16 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
mimeType = "image/png" mimeType = "image/png"
} }
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload, err := json.Marshal(map[string]any{
"type": "image_url",
"image_url": map[string]string{
"url": imageURL,
},
})
if err != nil {
continue
}
imagesResult := gjson.Get(template, "choices.0.delta.images") imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() { if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`) template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
} }
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", string(imagePayload)) template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
} }
} }
} }

View File

@@ -194,7 +194,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
if name := fc.Get("name"); name.Exists() { if name := fc.Get("name"); name.Exists() {
toolUse, _ = sjson.Set(toolUse, "name", name.String()) toolUse, _ = sjson.Set(toolUse, "name", name.String())
} }
if args := fc.Get("args"); args.Exists() { if args := fc.Get("args"); args.Exists() && args.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", args.Raw) toolUse, _ = sjson.SetRaw(toolUse, "input", args.Raw)
} }
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse) msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
@@ -314,11 +314,11 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
if mode := funcCalling.Get("mode"); mode.Exists() { if mode := funcCalling.Get("mode"); mode.Exists() {
switch mode.String() { switch mode.String() {
case "AUTO": case "AUTO":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "auto"}) out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
case "NONE": case "NONE":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "none"}) out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"none"}`)
case "ANY": case "ANY":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "any"}) out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
} }
} }
} }

View File

@@ -263,51 +263,6 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
} }
} }
// convertArrayToJSON converts []interface{} to JSON array string
func convertArrayToJSON(arr []interface{}) string {
result := "[]"
for _, item := range arr {
switch itemData := item.(type) {
case map[string]interface{}:
itemJSON := convertMapToJSON(itemData)
result, _ = sjson.SetRaw(result, "-1", itemJSON)
case string:
result, _ = sjson.Set(result, "-1", itemData)
case bool:
result, _ = sjson.Set(result, "-1", itemData)
case float64, int, int64:
result, _ = sjson.Set(result, "-1", itemData)
default:
result, _ = sjson.Set(result, "-1", itemData)
}
}
return result
}
// convertMapToJSON converts map[string]interface{} to JSON object string
func convertMapToJSON(m map[string]interface{}) string {
result := "{}"
for key, value := range m {
switch val := value.(type) {
case map[string]interface{}:
nestedJSON := convertMapToJSON(val)
result, _ = sjson.SetRaw(result, key, nestedJSON)
case []interface{}:
arrayJSON := convertArrayToJSON(val)
result, _ = sjson.SetRaw(result, key, arrayJSON)
case string:
result, _ = sjson.Set(result, key, val)
case bool:
result, _ = sjson.Set(result, key, val)
case float64, int, int64:
result, _ = sjson.Set(result, key, val)
default:
result, _ = sjson.Set(result, key, val)
}
}
return result
}
// ConvertClaudeResponseToGeminiNonStream converts a non-streaming Claude Code response to a non-streaming Gemini response. // ConvertClaudeResponseToGeminiNonStream converts a non-streaming Claude Code response to a non-streaming Gemini response.
// This function processes the complete Claude Code response and transforms it into a single Gemini-compatible // This function processes the complete Claude Code response and transforms it into a single Gemini-compatible
// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all // JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all
@@ -356,8 +311,8 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
} }
// Process each streaming event and collect parts // Process each streaming event and collect parts
var allParts []interface{} var allParts []string
var finalUsage map[string]interface{} var finalUsageJSON string
var responseID string var responseID string
var createdAt int64 var createdAt int64
@@ -407,16 +362,14 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
if text := delta.Get("text"); text.Exists() && text.String() != "" { if text := delta.Get("text"); text.Exists() && text.String() != "" {
partJSON := `{"text":""}` partJSON := `{"text":""}`
partJSON, _ = sjson.Set(partJSON, "text", text.String()) partJSON, _ = sjson.Set(partJSON, "text", text.String())
part := gjson.Parse(partJSON).Value().(map[string]interface{}) allParts = append(allParts, partJSON)
allParts = append(allParts, part)
} }
case "thinking_delta": case "thinking_delta":
// Process reasoning/thinking content // Process reasoning/thinking content
if text := delta.Get("thinking"); text.Exists() && text.String() != "" { if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
partJSON := `{"thought":true,"text":""}` partJSON := `{"thought":true,"text":""}`
partJSON, _ = sjson.Set(partJSON, "text", text.String()) partJSON, _ = sjson.Set(partJSON, "text", text.String())
part := gjson.Parse(partJSON).Value().(map[string]interface{}) allParts = append(allParts, partJSON)
allParts = append(allParts, part)
} }
case "input_json_delta": case "input_json_delta":
// accumulate args partial_json for this index // accumulate args partial_json for this index
@@ -456,9 +409,7 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
if argsTrim != "" { if argsTrim != "" {
functionCallJSON, _ = sjson.SetRaw(functionCallJSON, "functionCall.args", argsTrim) functionCallJSON, _ = sjson.SetRaw(functionCallJSON, "functionCall.args", argsTrim)
} }
// Parse back to interface{} for allParts allParts = append(allParts, functionCallJSON)
functionCall := gjson.Parse(functionCallJSON).Value().(map[string]interface{})
allParts = append(allParts, functionCall)
// cleanup used state for this index // cleanup used state for this index
if newParam.ToolUseArgs != nil { if newParam.ToolUseArgs != nil {
delete(newParam.ToolUseArgs, idx) delete(newParam.ToolUseArgs, idx)
@@ -501,8 +452,7 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
// Set traffic type (required by Gemini API) // Set traffic type (required by Gemini API)
usageJSON, _ = sjson.Set(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT") usageJSON, _ = sjson.Set(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT")
// Convert to map[string]interface{} using gjson finalUsageJSON = usageJSON
finalUsage = gjson.Parse(usageJSON).Value().(map[string]interface{})
} }
} }
} }
@@ -520,12 +470,16 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
// Set the consolidated parts array // Set the consolidated parts array
if len(consolidatedParts) > 0 { if len(consolidatedParts) > 0 {
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", convertToJSONString(consolidatedParts)) partsJSON := "[]"
for _, partJSON := range consolidatedParts {
partsJSON, _ = sjson.SetRaw(partsJSON, "-1", partJSON)
}
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", partsJSON)
} }
// Set usage metadata // Set usage metadata
if finalUsage != nil { if finalUsageJSON != "" {
template, _ = sjson.SetRaw(template, "usageMetadata", convertToJSONString(finalUsage)) template, _ = sjson.SetRaw(template, "usageMetadata", finalUsageJSON)
} }
return template return template
@@ -539,12 +493,12 @@ func GeminiTokenCount(ctx context.Context, count int64) string {
// This function processes the parts array to combine adjacent text elements and thinking elements // This function processes the parts array to combine adjacent text elements and thinking elements
// into single consolidated parts, which results in a more readable and efficient response structure. // into single consolidated parts, which results in a more readable and efficient response structure.
// Tool calls and other non-text parts are preserved as separate elements. // Tool calls and other non-text parts are preserved as separate elements.
func consolidateParts(parts []interface{}) []interface{} { func consolidateParts(parts []string) []string {
if len(parts) == 0 { if len(parts) == 0 {
return parts return parts
} }
var consolidated []interface{} var consolidated []string
var currentTextPart strings.Builder var currentTextPart strings.Builder
var currentThoughtPart strings.Builder var currentThoughtPart strings.Builder
var hasText, hasThought bool var hasText, hasThought bool
@@ -554,8 +508,7 @@ func consolidateParts(parts []interface{}) []interface{} {
if hasText && currentTextPart.Len() > 0 { if hasText && currentTextPart.Len() > 0 {
textPartJSON := `{"text":""}` textPartJSON := `{"text":""}`
textPartJSON, _ = sjson.Set(textPartJSON, "text", currentTextPart.String()) textPartJSON, _ = sjson.Set(textPartJSON, "text", currentTextPart.String())
textPart := gjson.Parse(textPartJSON).Value().(map[string]interface{}) consolidated = append(consolidated, textPartJSON)
consolidated = append(consolidated, textPart)
currentTextPart.Reset() currentTextPart.Reset()
hasText = false hasText = false
} }
@@ -566,42 +519,42 @@ func consolidateParts(parts []interface{}) []interface{} {
if hasThought && currentThoughtPart.Len() > 0 { if hasThought && currentThoughtPart.Len() > 0 {
thoughtPartJSON := `{"thought":true,"text":""}` thoughtPartJSON := `{"thought":true,"text":""}`
thoughtPartJSON, _ = sjson.Set(thoughtPartJSON, "text", currentThoughtPart.String()) thoughtPartJSON, _ = sjson.Set(thoughtPartJSON, "text", currentThoughtPart.String())
thoughtPart := gjson.Parse(thoughtPartJSON).Value().(map[string]interface{}) consolidated = append(consolidated, thoughtPartJSON)
consolidated = append(consolidated, thoughtPart)
currentThoughtPart.Reset() currentThoughtPart.Reset()
hasThought = false hasThought = false
} }
} }
for _, part := range parts { for _, partJSON := range parts {
partMap, ok := part.(map[string]interface{}) part := gjson.Parse(partJSON)
if !ok { if !part.Exists() || !part.IsObject() {
// Flush any pending parts and add this non-text part // Flush any pending parts and add this non-text part
flushText() flushText()
flushThought() flushThought()
consolidated = append(consolidated, part) consolidated = append(consolidated, partJSON)
continue continue
} }
if thought, isThought := partMap["thought"]; isThought && thought == true { thought := part.Get("thought")
if thought.Exists() && thought.Type == gjson.True {
// This is a thinking part - flush any pending text first // This is a thinking part - flush any pending text first
flushText() // Flush any pending text first flushText() // Flush any pending text first
if text, hasTextContent := partMap["text"].(string); hasTextContent { if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
currentThoughtPart.WriteString(text) currentThoughtPart.WriteString(text.String())
hasThought = true hasThought = true
} }
} else if text, hasTextContent := partMap["text"].(string); hasTextContent { } else if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
// This is a regular text part - flush any pending thought first // This is a regular text part - flush any pending thought first
flushThought() // Flush any pending thought first flushThought() // Flush any pending thought first
currentTextPart.WriteString(text) currentTextPart.WriteString(text.String())
hasText = true hasText = true
} else { } else {
// This is some other type of part (like function call) - flush both text and thought // This is some other type of part (like function call) - flush both text and thought
flushText() flushText()
flushThought() flushThought()
consolidated = append(consolidated, part) consolidated = append(consolidated, partJSON)
} }
} }
@@ -611,20 +564,3 @@ func consolidateParts(parts []interface{}) []interface{} {
return consolidated return consolidated
} }
// convertToJSONString converts interface{} to JSON string using sjson/gjson.
// This function provides a consistent way to serialize different data types to JSON strings
// for inclusion in the Gemini API response structure.
func convertToJSONString(v interface{}) string {
switch val := v.(type) {
case []interface{}:
return convertArrayToJSON(val)
case map[string]interface{}:
return convertMapToJSON(val)
default:
// For simple types, create a temporary JSON and extract the value
temp := `{"temp":null}`
temp, _ = sjson.Set(temp, "temp", val)
return gjson.Get(temp, "temp").Raw
}
}

View File

@@ -10,7 +10,6 @@ import (
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"math/big" "math/big"
"strings" "strings"
@@ -137,9 +136,6 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
out, _ = sjson.Set(out, "stream", stream) out, _ = sjson.Set(out, "stream", stream)
// Process messages and transform them to Claude Code format // Process messages and transform them to Claude Code format
var anthropicMessages []interface{}
var toolCallIDs []string // Track tool call IDs for matching with tool results
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() { if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
messages.ForEach(func(_, message gjson.Result) bool { messages.ForEach(func(_, message gjson.Result) bool {
role := message.Get("role").String() role := message.Get("role").String()
@@ -152,33 +148,23 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
role = "user" role = "user"
} }
msg := map[string]interface{}{ msg := `{"role":"","content":[]}`
"role": role, msg, _ = sjson.Set(msg, "role", role)
"content": []interface{}{},
}
// Handle content based on its type (string or array) // Handle content based on its type (string or array)
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" { if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
// Simple text content conversion part := `{"type":"text","text":""}`
msg["content"] = []interface{}{ part, _ = sjson.Set(part, "text", contentResult.String())
map[string]interface{}{ msg, _ = sjson.SetRaw(msg, "content.-1", part)
"type": "text",
"text": contentResult.String(),
},
}
} else if contentResult.Exists() && contentResult.IsArray() { } else if contentResult.Exists() && contentResult.IsArray() {
// Array of content parts processing
var contentParts []interface{}
contentResult.ForEach(func(_, part gjson.Result) bool { contentResult.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String() partType := part.Get("type").String()
switch partType { switch partType {
case "text": case "text":
// Text part conversion textPart := `{"type":"text","text":""}`
contentParts = append(contentParts, map[string]interface{}{ textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
"type": "text", msg, _ = sjson.SetRaw(msg, "content.-1", textPart)
"text": part.Get("text").String(),
})
case "image_url": case "image_url":
// Convert OpenAI image format to Claude Code format // Convert OpenAI image format to Claude Code format
@@ -191,132 +177,95 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
mediaType := strings.TrimPrefix(mediaTypePart, "data:") mediaType := strings.TrimPrefix(mediaTypePart, "data:")
data := parts[1] data := parts[1]
contentParts = append(contentParts, map[string]interface{}{ imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
"type": "image", imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
"source": map[string]interface{}{ imagePart, _ = sjson.Set(imagePart, "source.data", data)
"type": "base64", msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
"media_type": mediaType,
"data": data,
},
})
} }
} }
} }
return true return true
}) })
if len(contentParts) > 0 {
msg["content"] = contentParts
}
} else {
// Initialize empty content array for tool calls
msg["content"] = []interface{}{}
} }
// Handle tool calls (for assistant messages) // Handle tool calls (for assistant messages)
if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() && role == "assistant" { if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() && role == "assistant" {
var contentParts []interface{}
// Add existing text content if any
if existingContent, ok := msg["content"].([]interface{}); ok {
contentParts = existingContent
}
toolCalls.ForEach(func(_, toolCall gjson.Result) bool { toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
if toolCall.Get("type").String() == "function" { if toolCall.Get("type").String() == "function" {
toolCallID := toolCall.Get("id").String() toolCallID := toolCall.Get("id").String()
if toolCallID == "" { if toolCallID == "" {
toolCallID = genToolCallID() toolCallID = genToolCallID()
} }
toolCallIDs = append(toolCallIDs, toolCallID)
function := toolCall.Get("function") function := toolCall.Get("function")
toolUse := map[string]interface{}{ toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
"type": "tool_use", toolUse, _ = sjson.Set(toolUse, "id", toolCallID)
"id": toolCallID, toolUse, _ = sjson.Set(toolUse, "name", function.Get("name").String())
"name": function.Get("name").String(),
}
// Parse arguments for the tool call // Parse arguments for the tool call
if args := function.Get("arguments"); args.Exists() { if args := function.Get("arguments"); args.Exists() {
argsStr := args.String() argsStr := args.String()
if argsStr != "" { if argsStr != "" && gjson.Valid(argsStr) {
var argsMap map[string]interface{} argsJSON := gjson.Parse(argsStr)
if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil { if argsJSON.IsObject() {
toolUse["input"] = argsMap toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
} else { } else {
toolUse["input"] = map[string]interface{}{} toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
} }
} else { } else {
toolUse["input"] = map[string]interface{}{} toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
} }
} else { } else {
toolUse["input"] = map[string]interface{}{} toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
} }
contentParts = append(contentParts, toolUse) msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
} }
return true return true
}) })
msg["content"] = contentParts
} }
anthropicMessages = append(anthropicMessages, msg) out, _ = sjson.SetRaw(out, "messages.-1", msg)
case "tool": case "tool":
// Handle tool result messages conversion // Handle tool result messages conversion
toolCallID := message.Get("tool_call_id").String() toolCallID := message.Get("tool_call_id").String()
content := message.Get("content").String() content := message.Get("content").String()
// Create tool result message in Claude Code format msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
msg := map[string]interface{}{ msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
"role": "user", msg, _ = sjson.Set(msg, "content.0.content", content)
"content": []interface{}{ out, _ = sjson.SetRaw(out, "messages.-1", msg)
map[string]interface{}{
"type": "tool_result",
"tool_use_id": toolCallID,
"content": content,
},
},
}
anthropicMessages = append(anthropicMessages, msg)
} }
return true return true
}) })
} }
// Set messages in the output template
if len(anthropicMessages) > 0 {
messagesJSON, _ := json.Marshal(anthropicMessages)
out, _ = sjson.SetRaw(out, "messages", string(messagesJSON))
}
// Tools mapping: OpenAI tools -> Claude Code tools // Tools mapping: OpenAI tools -> Claude Code tools
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 { if tools := root.Get("tools"); tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
var anthropicTools []interface{} hasAnthropicTools := false
tools.ForEach(func(_, tool gjson.Result) bool { tools.ForEach(func(_, tool gjson.Result) bool {
if tool.Get("type").String() == "function" { if tool.Get("type").String() == "function" {
function := tool.Get("function") function := tool.Get("function")
anthropicTool := map[string]interface{}{ anthropicTool := `{"name":"","description":""}`
"name": function.Get("name").String(), anthropicTool, _ = sjson.Set(anthropicTool, "name", function.Get("name").String())
"description": function.Get("description").String(), anthropicTool, _ = sjson.Set(anthropicTool, "description", function.Get("description").String())
}
// Convert parameters schema for the tool // Convert parameters schema for the tool
if parameters := function.Get("parameters"); parameters.Exists() { if parameters := function.Get("parameters"); parameters.Exists() {
anthropicTool["input_schema"] = parameters.Value() anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
} else if parameters = function.Get("parametersJsonSchema"); parameters.Exists() { } else if parameters := function.Get("parametersJsonSchema"); parameters.Exists() {
anthropicTool["input_schema"] = parameters.Value() anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
} }
anthropicTools = append(anthropicTools, anthropicTool) out, _ = sjson.SetRaw(out, "tools.-1", anthropicTool)
hasAnthropicTools = true
} }
return true return true
}) })
if len(anthropicTools) > 0 { if !hasAnthropicTools {
toolsJSON, _ := json.Marshal(anthropicTools) out, _ = sjson.Delete(out, "tools")
out, _ = sjson.SetRaw(out, "tools", string(toolsJSON))
} }
} }
@@ -329,18 +278,17 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
case "none": case "none":
// Don't set tool_choice, Claude Code will not use tools // Don't set tool_choice, Claude Code will not use tools
case "auto": case "auto":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "auto"}) out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
case "required": case "required":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "any"}) out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
} }
case gjson.JSON: case gjson.JSON:
// Specific tool choice mapping // Specific tool choice mapping
if toolChoice.Get("type").String() == "function" { if toolChoice.Get("type").String() == "function" {
functionName := toolChoice.Get("function.name").String() functionName := toolChoice.Get("function.name").String()
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{ toolChoiceJSON := `{"type":"tool","name":""}`
"type": "tool", toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", functionName)
"name": functionName, out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
})
} }
default: default:
} }

View File

@@ -8,7 +8,7 @@ package chat_completions
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json" "fmt"
"strings" "strings"
"time" "time"
@@ -182,18 +182,11 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
if arguments == "" { if arguments == "" {
arguments = "{}" arguments = "{}"
} }
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.index", index)
toolCall := map[string]interface{}{ template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.id", accumulator.ID)
"index": index, template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.type", "function")
"id": accumulator.ID, template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name)
"type": "function", template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.arguments", arguments)
"function": map[string]interface{}{
"name": accumulator.Name,
"arguments": arguments,
},
}
template, _ = sjson.Set(template, "choices.0.delta.tool_calls", []interface{}{toolCall})
// Clean up the accumulator for this index // Clean up the accumulator for this index
delete((*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator, index) delete((*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator, index)
@@ -214,12 +207,14 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
// Handle usage information for token counts // Handle usage information for token counts
if usage := root.Get("usage"); usage.Exists() { if usage := root.Get("usage"); usage.Exists() {
usageObj := map[string]interface{}{ inputTokens := usage.Get("input_tokens").Int()
"prompt_tokens": usage.Get("input_tokens").Int(), outputTokens := usage.Get("output_tokens").Int()
"completion_tokens": usage.Get("output_tokens").Int(), cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
"total_tokens": usage.Get("input_tokens").Int() + usage.Get("output_tokens").Int(), cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
} template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
template, _ = sjson.Set(template, "usage", usageObj) template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
} }
return []string{template} return []string{template}
@@ -234,14 +229,10 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
case "error": case "error":
// Error event - format and return error response // Error event - format and return error response
if errorData := root.Get("error"); errorData.Exists() { if errorData := root.Get("error"); errorData.Exists() {
errorResponse := map[string]interface{}{ errorJSON := `{"error":{"message":"","type":""}}`
"error": map[string]interface{}{ errorJSON, _ = sjson.Set(errorJSON, "error.message", errorData.Get("message").String())
"message": errorData.Get("message").String(), errorJSON, _ = sjson.Set(errorJSON, "error.type", errorData.Get("type").String())
"type": errorData.Get("type").String(), return []string{errorJSON}
},
}
errorJSON, _ := json.Marshal(errorResponse)
return []string{string(errorJSON)}
} }
return []string{} return []string{}
@@ -297,15 +288,10 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
var messageID string var messageID string
var model string var model string
var createdAt int64 var createdAt int64
var inputTokens, outputTokens int64
var reasoningTokens int64
var stopReason string var stopReason string
var contentParts []string var contentParts []string
var reasoningParts []string var reasoningParts []string
// Use map to track tool calls by index for proper merging toolCallsAccumulator := make(map[int]*ToolCallAccumulator)
toolCallsMap := make(map[int]map[string]interface{})
// Track tool call arguments accumulation
toolCallArgsMap := make(map[int]strings.Builder)
for _, chunk := range chunks { for _, chunk := range chunks {
root := gjson.ParseBytes(chunk) root := gjson.ParseBytes(chunk)
@@ -318,9 +304,6 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
messageID = message.Get("id").String() messageID = message.Get("id").String()
model = message.Get("model").String() model = message.Get("model").String()
createdAt = time.Now().Unix() createdAt = time.Now().Unix()
if usage := message.Get("usage"); usage.Exists() {
inputTokens = usage.Get("input_tokens").Int()
}
} }
case "content_block_start": case "content_block_start":
@@ -331,18 +314,12 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
// Start of thinking/reasoning content - skip for now as it's handled in delta // Start of thinking/reasoning content - skip for now as it's handled in delta
continue continue
} else if blockType == "tool_use" { } else if blockType == "tool_use" {
// Initialize tool call tracking for this index // Initialize tool call accumulator for this index
index := int(root.Get("index").Int()) index := int(root.Get("index").Int())
toolCallsMap[index] = map[string]interface{}{ toolCallsAccumulator[index] = &ToolCallAccumulator{
"id": contentBlock.Get("id").String(), ID: contentBlock.Get("id").String(),
"type": "function", Name: contentBlock.Get("name").String(),
"function": map[string]interface{}{
"name": contentBlock.Get("name").String(),
"arguments": "",
},
} }
// Initialize arguments builder for this tool call
toolCallArgsMap[index] = strings.Builder{}
} }
} }
@@ -365,9 +342,8 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
// Accumulate tool call arguments // Accumulate tool call arguments
if partialJSON := delta.Get("partial_json"); partialJSON.Exists() { if partialJSON := delta.Get("partial_json"); partialJSON.Exists() {
index := int(root.Get("index").Int()) index := int(root.Get("index").Int())
if builder, exists := toolCallArgsMap[index]; exists { if accumulator, exists := toolCallsAccumulator[index]; exists {
builder.WriteString(partialJSON.String()) accumulator.Arguments.WriteString(partialJSON.String())
toolCallArgsMap[index] = builder
} }
} }
} }
@@ -376,14 +352,9 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
case "content_block_stop": case "content_block_stop":
// Finalize tool call arguments for this index when content block ends // Finalize tool call arguments for this index when content block ends
index := int(root.Get("index").Int()) index := int(root.Get("index").Int())
if toolCall, exists := toolCallsMap[index]; exists { if accumulator, exists := toolCallsAccumulator[index]; exists {
if builder, argsExists := toolCallArgsMap[index]; argsExists { if accumulator.Arguments.Len() == 0 {
// Set the accumulated arguments for the tool call accumulator.Arguments.WriteString("{}")
arguments := builder.String()
if arguments == "" {
arguments = "{}"
}
toolCall["function"].(map[string]interface{})["arguments"] = arguments
} }
} }
@@ -395,11 +366,14 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
} }
} }
if usage := root.Get("usage"); usage.Exists() { if usage := root.Get("usage"); usage.Exists() {
outputTokens = usage.Get("output_tokens").Int() inputTokens := usage.Get("input_tokens").Int()
// Estimate reasoning tokens from accumulated thinking content outputTokens := usage.Get("output_tokens").Int()
if len(reasoningParts) > 0 { cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
reasoningTokens = int64(len(strings.Join(reasoningParts, "")) / 4) // Rough estimation cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
} out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", inputTokens+outputTokens)
out, _ = sjson.Set(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
} }
} }
} }
@@ -421,24 +395,35 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
} }
// Set tool calls if any were accumulated during processing // Set tool calls if any were accumulated during processing
if len(toolCallsMap) > 0 { if len(toolCallsAccumulator) > 0 {
// Convert tool calls map to array, preserving order by index toolCallsCount := 0
var toolCallsArray []interface{}
// Find the maximum index to determine the range
maxIndex := -1 maxIndex := -1
for index := range toolCallsMap { for index := range toolCallsAccumulator {
if index > maxIndex { if index > maxIndex {
maxIndex = index maxIndex = index
} }
} }
// Iterate through all possible indices up to maxIndex
for i := 0; i <= maxIndex; i++ { for i := 0; i <= maxIndex; i++ {
if toolCall, exists := toolCallsMap[i]; exists { accumulator, exists := toolCallsAccumulator[i]
toolCallsArray = append(toolCallsArray, toolCall) if !exists {
continue
} }
arguments := accumulator.Arguments.String()
idPath := fmt.Sprintf("choices.0.message.tool_calls.%d.id", toolCallsCount)
typePath := fmt.Sprintf("choices.0.message.tool_calls.%d.type", toolCallsCount)
namePath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.name", toolCallsCount)
argumentsPath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.arguments", toolCallsCount)
out, _ = sjson.Set(out, idPath, accumulator.ID)
out, _ = sjson.Set(out, typePath, "function")
out, _ = sjson.Set(out, namePath, accumulator.Name)
out, _ = sjson.Set(out, argumentsPath, arguments)
toolCallsCount++
} }
if len(toolCallsArray) > 0 { if toolCallsCount > 0 {
out, _ = sjson.Set(out, "choices.0.message.tool_calls", toolCallsArray)
out, _ = sjson.Set(out, "choices.0.finish_reason", "tool_calls") out, _ = sjson.Set(out, "choices.0.finish_reason", "tool_calls")
} else { } else {
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason)) out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
@@ -447,16 +432,5 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason)) out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
} }
// Set usage information including prompt tokens, completion tokens, and total tokens
totalTokens := inputTokens + outputTokens
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", totalTokens)
// Add reasoning tokens to usage details if any reasoning content was processed
if reasoningTokens > 0 {
out, _ = sjson.Set(out, "usage.completion_tokens_details.reasoning_tokens", reasoningTokens)
}
return out return out
} }

View File

@@ -114,13 +114,16 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
var builder strings.Builder var builder strings.Builder
if parts := item.Get("content"); parts.Exists() && parts.IsArray() { if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool { parts.ForEach(func(_, part gjson.Result) bool {
text := part.Get("text").String() textResult := part.Get("text")
text := textResult.String()
if builder.Len() > 0 && text != "" { if builder.Len() > 0 && text != "" {
builder.WriteByte('\n') builder.WriteByte('\n')
} }
builder.WriteString(text) builder.WriteString(text)
return true return true
}) })
} else if parts.Type == gjson.String {
builder.WriteString(parts.String())
} }
instructionsText = builder.String() instructionsText = builder.String()
if instructionsText != "" { if instructionsText != "" {
@@ -207,6 +210,8 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
} }
return true return true
}) })
} else if parts.Type == gjson.String {
textAggregate.WriteString(parts.String())
} }
// Fallback to given role if content types not decisive // Fallback to given role if content types not decisive
@@ -254,7 +259,10 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
toolUse, _ = sjson.Set(toolUse, "id", callID) toolUse, _ = sjson.Set(toolUse, "id", callID)
toolUse, _ = sjson.Set(toolUse, "name", name) toolUse, _ = sjson.Set(toolUse, "name", name)
if argsStr != "" && gjson.Valid(argsStr) { if argsStr != "" && gjson.Valid(argsStr) {
toolUse, _ = sjson.SetRaw(toolUse, "input", argsStr) argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
}
} }
asst := `{"role":"assistant","content":[]}` asst := `{"role":"assistant","content":[]}`
@@ -309,16 +317,18 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
case gjson.String: case gjson.String:
switch toolChoice.String() { switch toolChoice.String() {
case "auto": case "auto":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "auto"}) out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
case "none": case "none":
// Leave unset; implies no tools // Leave unset; implies no tools
case "required": case "required":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "any"}) out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
} }
case gjson.JSON: case gjson.JSON:
if toolChoice.Get("type").String() == "function" { if toolChoice.Get("type").String() == "function" {
fn := toolChoice.Get("function.name").String() fn := toolChoice.Get("function.name").String()
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "tool", "name": fn}) toolChoiceJSON := `{"name":"","type":"tool"}`
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", fn)
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
} }
default: default:

View File

@@ -344,31 +344,20 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
} }
// Build response.output from aggregated state // Build response.output from aggregated state
var outputs []interface{} outputsWrapper := `{"arr":[]}`
// reasoning item (if any) // reasoning item (if any)
if st.ReasoningBuf.Len() > 0 || st.ReasoningPartAdded { if st.ReasoningBuf.Len() > 0 || st.ReasoningPartAdded {
r := map[string]interface{}{ item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
"id": st.ReasoningItemID, item, _ = sjson.Set(item, "id", st.ReasoningItemID)
"type": "reasoning", item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
"summary": []interface{}{map[string]interface{}{"type": "summary_text", "text": st.ReasoningBuf.String()}}, outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
outputs = append(outputs, r)
} }
// assistant message item (if any text) // assistant message item (if any text)
if st.TextBuf.Len() > 0 || st.InTextBlock || st.CurrentMsgID != "" { if st.TextBuf.Len() > 0 || st.InTextBlock || st.CurrentMsgID != "" {
m := map[string]interface{}{ item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
"id": st.CurrentMsgID, item, _ = sjson.Set(item, "id", st.CurrentMsgID)
"type": "message", item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
"status": "completed", outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": st.TextBuf.String(),
}},
"role": "assistant",
}
outputs = append(outputs, m)
} }
// function_call items (in ascending index order for determinism) // function_call items (in ascending index order for determinism)
if len(st.FuncArgsBuf) > 0 { if len(st.FuncArgsBuf) > 0 {
@@ -395,19 +384,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
if callID == "" && st.CurrentFCID != "" { if callID == "" && st.CurrentFCID != "" {
callID = st.CurrentFCID callID = st.CurrentFCID
} }
item := map[string]interface{}{ item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
"id": fmt.Sprintf("fc_%s", callID), item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
"type": "function_call", item, _ = sjson.Set(item, "arguments", args)
"status": "completed", item, _ = sjson.Set(item, "call_id", callID)
"arguments": args, item, _ = sjson.Set(item, "name", name)
"call_id": callID, outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"name": name,
}
outputs = append(outputs, item)
} }
} }
if len(outputs) > 0 { if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.Set(completed, "response.output", outputs) completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
} }
reasoningTokens := int64(0) reasoningTokens := int64(0)
@@ -628,27 +614,18 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
} }
// Build output array // Build output array
var outputs []interface{} outputsWrapper := `{"arr":[]}`
if reasoningBuf.Len() > 0 { if reasoningBuf.Len() > 0 {
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
"id": reasoningItemID, item, _ = sjson.Set(item, "id", reasoningItemID)
"type": "reasoning", item, _ = sjson.Set(item, "summary.0.text", reasoningBuf.String())
"summary": []interface{}{map[string]interface{}{"type": "summary_text", "text": reasoningBuf.String()}}, outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
})
} }
if currentMsgID != "" || textBuf.Len() > 0 { if currentMsgID != "" || textBuf.Len() > 0 {
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
"id": currentMsgID, item, _ = sjson.Set(item, "id", currentMsgID)
"type": "message", item, _ = sjson.Set(item, "content.0.text", textBuf.String())
"status": "completed", outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": textBuf.String(),
}},
"role": "assistant",
})
} }
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
// Preserve index order // Preserve index order
@@ -669,18 +646,16 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
if args == "" { if args == "" {
args = "{}" args = "{}"
} }
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
"id": fmt.Sprintf("fc_%s", st.id), item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.id))
"type": "function_call", item, _ = sjson.Set(item, "arguments", args)
"status": "completed", item, _ = sjson.Set(item, "call_id", st.id)
"arguments": args, item, _ = sjson.Set(item, "name", st.name)
"call_id": st.id, outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"name": st.name,
})
} }
} }
if len(outputs) > 0 { if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
out, _ = sjson.Set(out, "output", outputs) out, _ = sjson.SetRaw(out, "output", gjson.Get(outputsWrapper, "arr").Raw)
} }
// Usage // Usage

View File

@@ -9,7 +9,6 @@ package claude
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
@@ -191,21 +190,12 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
return "" return ""
} }
response := map[string]interface{}{ out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
"id": responseData.Get("id").String(), out, _ = sjson.Set(out, "id", responseData.Get("id").String())
"type": "message", out, _ = sjson.Set(out, "model", responseData.Get("model").String())
"role": "assistant", out, _ = sjson.Set(out, "usage.input_tokens", responseData.Get("usage.input_tokens").Int())
"model": responseData.Get("model").String(), out, _ = sjson.Set(out, "usage.output_tokens", responseData.Get("usage.output_tokens").Int())
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": responseData.Get("usage.input_tokens").Int(),
"output_tokens": responseData.Get("usage.output_tokens").Int(),
},
}
var contentBlocks []interface{}
hasToolCall := false hasToolCall := false
if output := responseData.Get("output"); output.Exists() && output.IsArray() { if output := responseData.Get("output"); output.Exists() && output.IsArray() {
@@ -244,10 +234,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
} }
} }
if thinkingBuilder.Len() > 0 { if thinkingBuilder.Len() > 0 {
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"thinking","thinking":""}`
"type": "thinking", block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
"thinking": thinkingBuilder.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
})
} }
case "message": case "message":
if content := item.Get("content"); content.Exists() { if content := item.Get("content"); content.Exists() {
@@ -256,10 +245,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
if part.Get("type").String() == "output_text" { if part.Get("type").String() == "output_text" {
text := part.Get("text").String() text := part.Get("text").String()
if text != "" { if text != "" {
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"text","text":""}`
"type": "text", block, _ = sjson.Set(block, "text", text)
"text": text, out, _ = sjson.SetRaw(out, "content.-1", block)
})
} }
} }
return true return true
@@ -267,10 +255,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
} else { } else {
text := content.String() text := content.String()
if text != "" { if text != "" {
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"text","text":""}`
"type": "text", block, _ = sjson.Set(block, "text", text)
"text": text, out, _ = sjson.SetRaw(out, "content.-1", block)
})
} }
} }
} }
@@ -281,54 +268,41 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
name = original name = original
} }
toolBlock := map[string]interface{}{ toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
"type": "tool_use", toolBlock, _ = sjson.Set(toolBlock, "id", item.Get("call_id").String())
"id": item.Get("call_id").String(), toolBlock, _ = sjson.Set(toolBlock, "name", name)
"name": name, inputRaw := "{}"
"input": map[string]interface{}{}, if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
} argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
if argsStr := item.Get("arguments").String(); argsStr != "" { inputRaw = argsJSON.Raw
var args interface{}
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
toolBlock["input"] = args
} }
} }
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
contentBlocks = append(contentBlocks, toolBlock) out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
} }
return true return true
}) })
} }
if len(contentBlocks) > 0 {
response["content"] = contentBlocks
}
if stopReason := responseData.Get("stop_reason"); stopReason.Exists() && stopReason.String() != "" { if stopReason := responseData.Get("stop_reason"); stopReason.Exists() && stopReason.String() != "" {
response["stop_reason"] = stopReason.String() out, _ = sjson.Set(out, "stop_reason", stopReason.String())
} else if hasToolCall { } else if hasToolCall {
response["stop_reason"] = "tool_use" out, _ = sjson.Set(out, "stop_reason", "tool_use")
} else { } else {
response["stop_reason"] = "end_turn" out, _ = sjson.Set(out, "stop_reason", "end_turn")
} }
if stopSequence := responseData.Get("stop_sequence"); stopSequence.Exists() && stopSequence.String() != "" { if stopSequence := responseData.Get("stop_sequence"); stopSequence.Exists() && stopSequence.String() != "" {
response["stop_sequence"] = stopSequence.Value() out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw)
} }
if responseData.Get("usage.input_tokens").Exists() || responseData.Get("usage.output_tokens").Exists() { if responseData.Get("usage.input_tokens").Exists() || responseData.Get("usage.output_tokens").Exists() {
response["usage"] = map[string]interface{}{ out, _ = sjson.Set(out, "usage.input_tokens", responseData.Get("usage.input_tokens").Int())
"input_tokens": responseData.Get("usage.input_tokens").Int(), out, _ = sjson.Set(out, "usage.output_tokens", responseData.Get("usage.output_tokens").Int())
"output_tokens": responseData.Get("usage.output_tokens").Int(),
}
} }
responseJSON, err := json.Marshal(response) return out
if err != nil {
return ""
}
return string(responseJSON)
} }
// buildReverseMapFromClaudeOriginalShortToOriginal builds a map[short]original from original Claude request tools. // buildReverseMapFromClaudeOriginalShortToOriginal builds a map[short]original from original Claude request tools.

View File

@@ -7,7 +7,6 @@ package gemini
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"time" "time"
@@ -190,19 +189,19 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
} }
// Process output content to build parts array // Process output content to build parts array
var parts []interface{}
hasToolCall := false hasToolCall := false
var pendingFunctionCalls []interface{} var pendingFunctionCalls []string
flushPendingFunctionCalls := func() { flushPendingFunctionCalls := func() {
if len(pendingFunctionCalls) > 0 { if len(pendingFunctionCalls) == 0 {
// Add all pending function calls as individual parts return
// This maintains the original Gemini API format while ensuring consecutive calls are grouped together
for _, fc := range pendingFunctionCalls {
parts = append(parts, fc)
}
pendingFunctionCalls = nil
} }
// Add all pending function calls as individual parts
// This maintains the original Gemini API format while ensuring consecutive calls are grouped together
for _, fc := range pendingFunctionCalls {
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", fc)
}
pendingFunctionCalls = nil
} }
if output := responseData.Get("output"); output.Exists() && output.IsArray() { if output := responseData.Get("output"); output.Exists() && output.IsArray() {
@@ -216,11 +215,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
// Add thinking content // Add thinking content
if content := value.Get("content"); content.Exists() { if content := value.Get("content"); content.Exists() {
part := map[string]interface{}{ part := `{"text":"","thought":true}`
"thought": true, part, _ = sjson.Set(part, "text", content.String())
"text": content.String(), template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
}
parts = append(parts, part)
} }
case "message": case "message":
@@ -232,10 +229,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
content.ForEach(func(_, contentItem gjson.Result) bool { content.ForEach(func(_, contentItem gjson.Result) bool {
if contentItem.Get("type").String() == "output_text" { if contentItem.Get("type").String() == "output_text" {
if text := contentItem.Get("text"); text.Exists() { if text := contentItem.Get("text"); text.Exists() {
part := map[string]interface{}{ part := `{"text":""}`
"text": text.String(), part, _ = sjson.Set(part, "text", text.String())
} template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
parts = append(parts, part)
} }
} }
return true return true
@@ -245,28 +241,21 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
case "function_call": case "function_call":
// Collect function call for potential merging with consecutive ones // Collect function call for potential merging with consecutive ones
hasToolCall = true hasToolCall = true
functionCall := map[string]interface{}{ functionCall := `{"functionCall":{"args":{},"name":""}}`
"functionCall": map[string]interface{}{ {
"name": func() string { n := value.Get("name").String()
n := value.Get("name").String() rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON)
rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON) if orig, ok := rev[n]; ok {
if orig, ok := rev[n]; ok { n = orig
return orig }
} functionCall, _ = sjson.Set(functionCall, "functionCall.name", n)
return n
}(),
"args": map[string]interface{}{},
},
} }
// Parse and set arguments // Parse and set arguments
if argsStr := value.Get("arguments").String(); argsStr != "" { if argsStr := value.Get("arguments").String(); argsStr != "" {
argsResult := gjson.Parse(argsStr) argsResult := gjson.Parse(argsStr)
if argsResult.IsObject() { if argsResult.IsObject() {
var args map[string]interface{} functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
functionCall["functionCall"].(map[string]interface{})["args"] = args
}
} }
} }
@@ -279,11 +268,6 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
flushPendingFunctionCalls() flushPendingFunctionCalls()
} }
// Set the parts array
if len(parts) > 0 {
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", mustMarshalJSON(parts))
}
// Set finish reason based on whether there were tool calls // Set finish reason based on whether there were tool calls
if hasToolCall { if hasToolCall {
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP") template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
@@ -323,15 +307,6 @@ func buildReverseMapFromGeminiOriginal(original []byte) map[string]string {
return rev return rev
} }
// mustMarshalJSON marshals a value to JSON, panicking on error.
func mustMarshalJSON(v interface{}) string {
data, err := json.Marshal(v)
if err != nil {
return ""
}
return string(data)
}
func GeminiTokenCount(ctx context.Context, count int64) string { func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count) return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
} }

View File

@@ -7,10 +7,8 @@ package claude
import ( import (
"bytes" "bytes"
"encoding/json"
"strings" "strings"
client "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -41,92 +39,102 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
rawJSON := bytes.Clone(inputRawJSON) rawJSON := bytes.Clone(inputRawJSON)
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1) rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}`
out, _ = sjson.Set(out, "model", modelName)
// system instruction // system instruction
var systemInstruction *client.Content if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() {
systemResult := gjson.GetBytes(rawJSON, "system") systemInstruction := `{"role":"user","parts":[]}`
if systemResult.IsArray() { hasSystemParts := false
systemResults := systemResult.Array() systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool {
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}} if systemPromptResult.Get("type").String() == "text" {
for i := 0; i < len(systemResults); i++ { textResult := systemPromptResult.Get("text")
systemPromptResult := systemResults[i] if textResult.Type == gjson.String {
systemTypePromptResult := systemPromptResult.Get("type") part := `{"text":""}`
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" { part, _ = sjson.Set(part, "text", textResult.String())
systemPrompt := systemPromptResult.Get("text").String() systemInstruction, _ = sjson.SetRaw(systemInstruction, "parts.-1", part)
systemPart := client.Part{Text: systemPrompt} hasSystemParts = true
systemInstruction.Parts = append(systemInstruction.Parts, systemPart) }
} }
return true
})
if hasSystemParts {
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstruction)
} }
if len(systemInstruction.Parts) == 0 { } else if systemResult.Type == gjson.String {
systemInstruction = nil out, _ = sjson.Set(out, "request.systemInstruction.parts.-1.text", systemResult.String())
}
} }
// contents // contents
contents := make([]client.Content, 0) if messagesResult := gjson.GetBytes(rawJSON, "messages"); messagesResult.IsArray() {
messagesResult := gjson.GetBytes(rawJSON, "messages") messagesResult.ForEach(func(_, messageResult gjson.Result) bool {
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
for i := 0; i < len(messageResults); i++ {
messageResult := messageResults[i]
roleResult := messageResult.Get("role") roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String { if roleResult.Type != gjson.String {
continue return true
} }
role := roleResult.String() role := roleResult.String()
if role == "assistant" { if role == "assistant" {
role = "model" role = "model"
} }
clientContent := client.Content{Role: role, Parts: []client.Part{}}
contentJSON := `{"role":"","parts":[]}`
contentJSON, _ = sjson.Set(contentJSON, "role", role)
contentsResult := messageResult.Get("content") contentsResult := messageResult.Get("content")
if contentsResult.IsArray() { if contentsResult.IsArray() {
contentResults := contentsResult.Array() contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
for j := 0; j < len(contentResults); j++ { switch contentResult.Get("type").String() {
contentResult := contentResults[j] case "text":
contentTypeResult := contentResult.Get("type") part := `{"text":""}`
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" { part, _ = sjson.Set(part, "text", contentResult.Get("text").String())
prompt := contentResult.Get("text").String() contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" { case "tool_use":
functionName := contentResult.Get("name").String() functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String() functionArgs := contentResult.Get("input").String()
var args map[string]any argsResult := gjson.Parse(functionArgs)
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil { if argsResult.IsObject() && gjson.Valid(functionArgs) {
clientContent.Parts = append(clientContent.Parts, client.Part{ part := `{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`
FunctionCall: &client.FunctionCall{Name: functionName, Args: args}, part, _ = sjson.Set(part, "thoughtSignature", geminiCLIClaudeThoughtSignature)
ThoughtSignature: geminiCLIClaudeThoughtSignature, part, _ = sjson.Set(part, "functionCall.name", functionName)
}) part, _ = sjson.SetRaw(part, "functionCall.args", functionArgs)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
} }
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
case "tool_result":
toolCallID := contentResult.Get("tool_use_id").String() toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" { if toolCallID == "" {
funcName := toolCallID return true
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").Raw
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
} }
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").Raw
part := `{"functionResponse":{"name":"","response":{"result":""}}}`
part, _ = sjson.Set(part, "functionResponse.name", funcName)
part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
} }
} return true
contents = append(contents, clientContent) })
out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON)
} else if contentsResult.Type == gjson.String { } else if contentsResult.Type == gjson.String {
prompt := contentsResult.String() part := `{"text":""}`
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}}) part, _ = sjson.Set(part, "text", contentsResult.String())
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON)
} }
} return true
})
} }
// tools // tools
var tools []client.ToolDeclaration if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() {
toolsResult := gjson.GetBytes(rawJSON, "tools") hasTools := false
if toolsResult.IsArray() { toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
inputSchemaResult := toolResult.Get("input_schema") inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw inputSchema := inputSchemaResult.Raw
@@ -136,30 +144,19 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
tool, _ = sjson.Delete(tool, "input_examples") tool, _ = sjson.Delete(tool, "input_examples")
tool, _ = sjson.Delete(tool, "type") tool, _ = sjson.Delete(tool, "type")
tool, _ = sjson.Delete(tool, "cache_control") tool, _ = sjson.Delete(tool, "cache_control")
var toolDeclaration any if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil { if !hasTools {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration) out, _ = sjson.SetRaw(out, "request.tools", `[{"functionDeclarations":[]}]`)
hasTools = true
}
out, _ = sjson.SetRaw(out, "request.tools.0.functionDeclarations.-1", tool)
} }
} }
return true
})
if !hasTools {
out, _ = sjson.Delete(out, "request.tools")
} }
} else {
tools = make([]client.ToolDeclaration, 0)
}
// Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}`
out, _ = sjson.Set(out, "model", modelName)
if systemInstruction != nil {
b, _ := json.Marshal(systemInstruction)
out, _ = sjson.SetRaw(out, "request.systemInstruction", string(b))
}
if len(contents) > 0 {
b, _ := json.Marshal(contents)
out, _ = sjson.SetRaw(out, "request.contents", string(b))
}
if len(tools) > 0 && len(tools[0].FunctionDeclarations) > 0 {
b, _ := json.Marshal(tools)
out, _ = sjson.SetRaw(out, "request.tools", string(b))
} }
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled // Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled

View File

@@ -9,7 +9,6 @@ package claude
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -276,22 +275,16 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
response := map[string]interface{}{ out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
"id": root.Get("response.responseId").String(), out, _ = sjson.Set(out, "id", root.Get("response.responseId").String())
"type": "message", out, _ = sjson.Set(out, "model", root.Get("response.modelVersion").String())
"role": "assistant",
"model": root.Get("response.modelVersion").String(), inputTokens := root.Get("response.usageMetadata.promptTokenCount").Int()
"content": []interface{}{}, outputTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int() + root.Get("response.usageMetadata.thoughtsTokenCount").Int()
"stop_reason": nil, out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
"stop_sequence": nil, out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
"usage": map[string]interface{}{
"input_tokens": root.Get("response.usageMetadata.promptTokenCount").Int(),
"output_tokens": root.Get("response.usageMetadata.candidatesTokenCount").Int() + root.Get("response.usageMetadata.thoughtsTokenCount").Int(),
},
}
parts := root.Get("response.candidates.0.content.parts") parts := root.Get("response.candidates.0.content.parts")
var contentBlocks []interface{}
textBuilder := strings.Builder{} textBuilder := strings.Builder{}
thinkingBuilder := strings.Builder{} thinkingBuilder := strings.Builder{}
toolIDCounter := 0 toolIDCounter := 0
@@ -301,10 +294,9 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
if textBuilder.Len() == 0 { if textBuilder.Len() == 0 {
return return
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"text","text":""}`
"type": "text", block, _ = sjson.Set(block, "text", textBuilder.String())
"text": textBuilder.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
})
textBuilder.Reset() textBuilder.Reset()
} }
@@ -312,10 +304,9 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
if thinkingBuilder.Len() == 0 { if thinkingBuilder.Len() == 0 {
return return
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"thinking","thinking":""}`
"type": "thinking", block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
"thinking": thinkingBuilder.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
})
thinkingBuilder.Reset() thinkingBuilder.Reset()
} }
@@ -339,21 +330,15 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
name := functionCall.Get("name").String() name := functionCall.Get("name").String()
toolIDCounter++ toolIDCounter++
toolBlock := map[string]interface{}{ toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
"type": "tool_use", toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
"id": fmt.Sprintf("tool_%d", toolIDCounter), toolBlock, _ = sjson.Set(toolBlock, "name", name)
"name": name, inputRaw := "{}"
"input": map[string]interface{}{}, if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
inputRaw = args.Raw
} }
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
if args := functionCall.Get("args"); args.Exists() { out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
var parsed interface{}
if err := json.Unmarshal([]byte(args.Raw), &parsed); err == nil {
toolBlock["input"] = parsed
}
}
contentBlocks = append(contentBlocks, toolBlock)
continue continue
} }
} }
@@ -362,8 +347,6 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
flushThinking() flushThinking()
flushText() flushText()
response["content"] = contentBlocks
stopReason := "end_turn" stopReason := "end_turn"
if hasToolCall { if hasToolCall {
stopReason = "tool_use" stopReason = "tool_use"
@@ -379,19 +362,13 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
} }
} }
} }
response["stop_reason"] = stopReason out, _ = sjson.Set(out, "stop_reason", stopReason)
if usage := response["usage"].(map[string]interface{}); usage["input_tokens"] == int64(0) && usage["output_tokens"] == int64(0) { if inputTokens == int64(0) && outputTokens == int64(0) && !root.Get("response.usageMetadata").Exists() {
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() { out, _ = sjson.Delete(out, "usage")
delete(response, "usage")
}
} }
encoded, err := json.Marshal(response) return out
if err != nil {
return ""
}
return string(encoded)
} }
func ClaudeTokenCount(ctx context.Context, count int64) string { func ClaudeTokenCount(ctx context.Context, count int64) string {

View File

@@ -7,7 +7,6 @@ package gemini
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
@@ -117,8 +116,6 @@ func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []by
// FunctionCallGroup represents a group of function calls and their responses // FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct { type FunctionCallGroup struct {
ModelContent map[string]interface{}
FunctionCalls []gjson.Result
ResponsesNeeded int ResponsesNeeded int
} }
@@ -146,7 +143,7 @@ func fixCLIToolResponse(input string) (string, error) {
} }
// Initialize data structures for processing and grouping // Initialize data structures for processing and grouping
var newContents []interface{} // Final processed contents array contentsWrapper := `{"contents":[]}`
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched var collectedResponses []gjson.Result // Standalone responses to be matched
@@ -178,23 +175,17 @@ func fixCLIToolResponse(input string) (string, error) {
collectedResponses = collectedResponses[group.ResponsesNeeded:] collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content // Create merged function response content
var responseParts []interface{} functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses { for _, response := range groupResponses {
var responseMap map[string]interface{} if !response.IsObject() {
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap) log.Warnf("failed to parse function response")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue continue
} }
responseParts = append(responseParts, responseMap) functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
} }
if len(responseParts) > 0 { if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
functionResponseContent := map[string]interface{}{ contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
} }
// Remove this group as it's been satisfied // Remove this group as it's been satisfied
@@ -208,50 +199,42 @@ func fixCLIToolResponse(input string) (string, error) {
// If this is a model with function calls, create a new group // If this is a model with function calls, create a new group
if role == "model" { if role == "model" {
var functionCallsInThisModel []gjson.Result functionCallsCount := 0
parts.ForEach(func(_, part gjson.Result) bool { parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() { if part.Get("functionCall").Exists() {
functionCallsInThisModel = append(functionCallsInThisModel, part) functionCallsCount++
} }
return true return true
}) })
if len(functionCallsInThisModel) > 0 { if functionCallsCount > 0 {
// Add the model content // Add the model content
var contentMap map[string]interface{} if !value.IsObject() {
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) log.Warnf("failed to parse model content")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
return true return true
} }
newContents = append(newContents, contentMap) contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
// Create a new group for tracking responses // Create a new group for tracking responses
group := &FunctionCallGroup{ group := &FunctionCallGroup{
ModelContent: contentMap, ResponsesNeeded: functionCallsCount,
FunctionCalls: functionCallsInThisModel,
ResponsesNeeded: len(functionCallsInThisModel),
} }
pendingGroups = append(pendingGroups, group) pendingGroups = append(pendingGroups, group)
} else { } else {
// Regular model content without function calls // Regular model content without function calls
var contentMap map[string]interface{} if !value.IsObject() {
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) log.Warnf("failed to parse content")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true return true
} }
newContents = append(newContents, contentMap) contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
} }
} else { } else {
// Non-model content (user, etc.) // Non-model content (user, etc.)
var contentMap map[string]interface{} if !value.IsObject() {
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) log.Warnf("failed to parse content")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true return true
} }
newContents = append(newContents, contentMap) contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
} }
return true return true
@@ -263,31 +246,24 @@ func fixCLIToolResponse(input string) (string, error) {
groupResponses := collectedResponses[:group.ResponsesNeeded] groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:] collectedResponses = collectedResponses[group.ResponsesNeeded:]
var responseParts []interface{} functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses { for _, response := range groupResponses {
var responseMap map[string]interface{} if !response.IsObject() {
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap) log.Warnf("failed to parse function response")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue continue
} }
responseParts = append(responseParts, responseMap) functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
} }
if len(responseParts) > 0 { if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
functionResponseContent := map[string]interface{}{ contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
} }
} }
} }
// Update the original JSON with the new contents // Update the original JSON with the new contents
result := input result := input
newContentsJSON, _ := json.Marshal(newContents) result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
return result, nil return result, nil
} }

View File

@@ -160,6 +160,14 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
} else if content.IsObject() && content.Get("type").String() == "text" { } else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user") out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.Get("text").String()) out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.Get("text").String())
} else if content.IsArray() {
contents := content.Array()
if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", j), contents[j].Get("text").String())
}
}
} }
} else if role == "user" || (role == "system" && len(arr) == 1) { } else if role == "user" || (role == "system" && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents // Build single user content node to avoid splitting into multiple contents
@@ -210,8 +218,29 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if content.Type == gjson.String { if content.Type == gjson.String {
// Assistant text -> single model content // Assistant text -> single model content
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String()) node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
p++ p++
} else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.
imageURL := item.Get("image_url.url").String()
if len(imageURL) > 5 { // expect data:...
pieces := strings.SplitN(imageURL[5:], ";", 2)
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
p++
}
}
}
}
} }
// Tool calls -> single model content with functionCall parts // Tool calls -> single model content with functionCall parts
@@ -236,7 +265,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node) out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
// Append a single tool content combining name + response per function // Append a single tool content combining name + response per function
toolNode := []byte(`{"role":"tool","parts":[]}`) toolNode := []byte(`{"role":"user","parts":[]}`)
pp := 0 pp := 0
for _, fid := range fIDs { for _, fid := range fIDs {
if name, ok := tcID2Name[fid]; ok { if name, ok := tcID2Name[fid]; ok {
@@ -252,6 +281,8 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if pp > 0 { if pp > 0 {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode) out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
} }
} else {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
} }
} }
} }
@@ -278,7 +309,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{}) fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
@@ -293,7 +324,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{}) fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue

View File

@@ -8,7 +8,6 @@ package chat_completions
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -171,21 +170,16 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
mimeType = "image/png" mimeType = "image/png"
} }
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload, err := json.Marshal(map[string]any{
"type": "image_url",
"image_url": map[string]string{
"url": imageURL,
},
})
if err != nil {
continue
}
imagesResult := gjson.Get(template, "choices.0.delta.images") imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() { if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`) template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
} }
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", string(imagePayload)) template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
} }
} }
} }

View File

@@ -7,10 +7,8 @@ package claude
import ( import (
"bytes" "bytes"
"encoding/json"
"strings" "strings"
client "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -34,92 +32,102 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
rawJSON := bytes.Clone(inputRawJSON) rawJSON := bytes.Clone(inputRawJSON)
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1) rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// Build output Gemini CLI request JSON
out := `{"contents":[]}`
out, _ = sjson.Set(out, "model", modelName)
// system instruction // system instruction
var systemInstruction *client.Content if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() {
systemResult := gjson.GetBytes(rawJSON, "system") systemInstruction := `{"role":"user","parts":[]}`
if systemResult.IsArray() { hasSystemParts := false
systemResults := systemResult.Array() systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool {
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}} if systemPromptResult.Get("type").String() == "text" {
for i := 0; i < len(systemResults); i++ { textResult := systemPromptResult.Get("text")
systemPromptResult := systemResults[i] if textResult.Type == gjson.String {
systemTypePromptResult := systemPromptResult.Get("type") part := `{"text":""}`
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" { part, _ = sjson.Set(part, "text", textResult.String())
systemPrompt := systemPromptResult.Get("text").String() systemInstruction, _ = sjson.SetRaw(systemInstruction, "parts.-1", part)
systemPart := client.Part{Text: systemPrompt} hasSystemParts = true
systemInstruction.Parts = append(systemInstruction.Parts, systemPart) }
} }
return true
})
if hasSystemParts {
out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction)
} }
if len(systemInstruction.Parts) == 0 { } else if systemResult.Type == gjson.String {
systemInstruction = nil out, _ = sjson.Set(out, "system_instruction.parts.-1.text", systemResult.String())
}
} }
// contents // contents
contents := make([]client.Content, 0) if messagesResult := gjson.GetBytes(rawJSON, "messages"); messagesResult.IsArray() {
messagesResult := gjson.GetBytes(rawJSON, "messages") messagesResult.ForEach(func(_, messageResult gjson.Result) bool {
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
for i := 0; i < len(messageResults); i++ {
messageResult := messageResults[i]
roleResult := messageResult.Get("role") roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String { if roleResult.Type != gjson.String {
continue return true
} }
role := roleResult.String() role := roleResult.String()
if role == "assistant" { if role == "assistant" {
role = "model" role = "model"
} }
clientContent := client.Content{Role: role, Parts: []client.Part{}}
contentJSON := `{"role":"","parts":[]}`
contentJSON, _ = sjson.Set(contentJSON, "role", role)
contentsResult := messageResult.Get("content") contentsResult := messageResult.Get("content")
if contentsResult.IsArray() { if contentsResult.IsArray() {
contentResults := contentsResult.Array() contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
for j := 0; j < len(contentResults); j++ { switch contentResult.Get("type").String() {
contentResult := contentResults[j] case "text":
contentTypeResult := contentResult.Get("type") part := `{"text":""}`
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" { part, _ = sjson.Set(part, "text", contentResult.Get("text").String())
prompt := contentResult.Get("text").String() contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" { case "tool_use":
functionName := contentResult.Get("name").String() functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String() functionArgs := contentResult.Get("input").String()
var args map[string]any argsResult := gjson.Parse(functionArgs)
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil { if argsResult.IsObject() && gjson.Valid(functionArgs) {
clientContent.Parts = append(clientContent.Parts, client.Part{ part := `{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`
FunctionCall: &client.FunctionCall{Name: functionName, Args: args}, part, _ = sjson.Set(part, "thoughtSignature", geminiClaudeThoughtSignature)
ThoughtSignature: geminiClaudeThoughtSignature, part, _ = sjson.Set(part, "functionCall.name", functionName)
}) part, _ = sjson.SetRaw(part, "functionCall.args", functionArgs)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
} }
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
case "tool_result":
toolCallID := contentResult.Get("tool_use_id").String() toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" { if toolCallID == "" {
funcName := toolCallID return true
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").Raw
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
} }
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").Raw
part := `{"functionResponse":{"name":"","response":{"result":""}}}`
part, _ = sjson.Set(part, "functionResponse.name", funcName)
part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
} }
} return true
contents = append(contents, clientContent) })
out, _ = sjson.SetRaw(out, "contents.-1", contentJSON)
} else if contentsResult.Type == gjson.String { } else if contentsResult.Type == gjson.String {
prompt := contentsResult.String() part := `{"text":""}`
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}}) part, _ = sjson.Set(part, "text", contentsResult.String())
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
out, _ = sjson.SetRaw(out, "contents.-1", contentJSON)
} }
} return true
})
} }
// tools // tools
var tools []client.ToolDeclaration if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() {
toolsResult := gjson.GetBytes(rawJSON, "tools") hasTools := false
if toolsResult.IsArray() { toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
inputSchemaResult := toolResult.Get("input_schema") inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw inputSchema := inputSchemaResult.Raw
@@ -129,30 +137,19 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
tool, _ = sjson.Delete(tool, "input_examples") tool, _ = sjson.Delete(tool, "input_examples")
tool, _ = sjson.Delete(tool, "type") tool, _ = sjson.Delete(tool, "type")
tool, _ = sjson.Delete(tool, "cache_control") tool, _ = sjson.Delete(tool, "cache_control")
var toolDeclaration any if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil { if !hasTools {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration) out, _ = sjson.SetRaw(out, "tools", `[{"functionDeclarations":[]}]`)
hasTools = true
}
out, _ = sjson.SetRaw(out, "tools.0.functionDeclarations.-1", tool)
} }
} }
return true
})
if !hasTools {
out, _ = sjson.Delete(out, "tools")
} }
} else {
tools = make([]client.ToolDeclaration, 0)
}
// Build output Gemini CLI request JSON
out := `{"contents":[]}`
out, _ = sjson.Set(out, "model", modelName)
if systemInstruction != nil {
b, _ := json.Marshal(systemInstruction)
out, _ = sjson.SetRaw(out, "system_instruction", string(b))
}
if len(contents) > 0 {
b, _ := json.Marshal(contents)
out, _ = sjson.SetRaw(out, "contents", string(b))
}
if len(tools) > 0 && len(tools[0].FunctionDeclarations) > 0 {
b, _ := json.Marshal(tools)
out, _ = sjson.SetRaw(out, "tools", string(b))
} }
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled // Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled

View File

@@ -9,7 +9,6 @@ package claude
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -282,22 +281,16 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
response := map[string]interface{}{ out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
"id": root.Get("responseId").String(), out, _ = sjson.Set(out, "id", root.Get("responseId").String())
"type": "message", out, _ = sjson.Set(out, "model", root.Get("modelVersion").String())
"role": "assistant",
"model": root.Get("modelVersion").String(), inputTokens := root.Get("usageMetadata.promptTokenCount").Int()
"content": []interface{}{}, outputTokens := root.Get("usageMetadata.candidatesTokenCount").Int() + root.Get("usageMetadata.thoughtsTokenCount").Int()
"stop_reason": nil, out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
"stop_sequence": nil, out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
"usage": map[string]interface{}{
"input_tokens": root.Get("usageMetadata.promptTokenCount").Int(),
"output_tokens": root.Get("usageMetadata.candidatesTokenCount").Int() + root.Get("usageMetadata.thoughtsTokenCount").Int(),
},
}
parts := root.Get("candidates.0.content.parts") parts := root.Get("candidates.0.content.parts")
var contentBlocks []interface{}
textBuilder := strings.Builder{} textBuilder := strings.Builder{}
thinkingBuilder := strings.Builder{} thinkingBuilder := strings.Builder{}
toolIDCounter := 0 toolIDCounter := 0
@@ -307,10 +300,9 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
if textBuilder.Len() == 0 { if textBuilder.Len() == 0 {
return return
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"text","text":""}`
"type": "text", block, _ = sjson.Set(block, "text", textBuilder.String())
"text": textBuilder.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
})
textBuilder.Reset() textBuilder.Reset()
} }
@@ -318,10 +310,9 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
if thinkingBuilder.Len() == 0 { if thinkingBuilder.Len() == 0 {
return return
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"thinking","thinking":""}`
"type": "thinking", block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
"thinking": thinkingBuilder.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
})
thinkingBuilder.Reset() thinkingBuilder.Reset()
} }
@@ -345,21 +336,15 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
name := functionCall.Get("name").String() name := functionCall.Get("name").String()
toolIDCounter++ toolIDCounter++
toolBlock := map[string]interface{}{ toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
"type": "tool_use", toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
"id": fmt.Sprintf("tool_%d", toolIDCounter), toolBlock, _ = sjson.Set(toolBlock, "name", name)
"name": name, inputRaw := "{}"
"input": map[string]interface{}{}, if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
inputRaw = args.Raw
} }
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
if args := functionCall.Get("args"); args.Exists() { out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
var parsed interface{}
if err := json.Unmarshal([]byte(args.Raw), &parsed); err == nil {
toolBlock["input"] = parsed
}
}
contentBlocks = append(contentBlocks, toolBlock)
continue continue
} }
} }
@@ -368,8 +353,6 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
flushThinking() flushThinking()
flushText() flushText()
response["content"] = contentBlocks
stopReason := "end_turn" stopReason := "end_turn"
if hasToolCall { if hasToolCall {
stopReason = "tool_use" stopReason = "tool_use"
@@ -385,19 +368,13 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
} }
} }
} }
response["stop_reason"] = stopReason out, _ = sjson.Set(out, "stop_reason", stopReason)
if usage := response["usage"].(map[string]interface{}); usage["input_tokens"] == int64(0) && usage["output_tokens"] == int64(0) { if inputTokens == int64(0) && outputTokens == int64(0) && !root.Get("usageMetadata").Exists() {
if usageMeta := root.Get("usageMetadata"); !usageMeta.Exists() { out, _ = sjson.Delete(out, "usage")
delete(response, "usage")
}
} }
encoded, err := json.Marshal(response) return out
if err != nil {
return ""
}
return string(encoded)
} }
func ClaudeTokenCount(ctx context.Context, count int64) string { func ClaudeTokenCount(ctx context.Context, count int64) string {

View File

@@ -178,6 +178,14 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
} else if content.IsObject() && content.Get("type").String() == "text" { } else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user") out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
out, _ = sjson.SetBytes(out, "system_instruction.parts.0.text", content.Get("text").String()) out, _ = sjson.SetBytes(out, "system_instruction.parts.0.text", content.Get("text").String())
} else if content.IsArray() {
contents := content.Array()
if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", j), contents[j].Get("text").String())
}
}
} }
} else if role == "user" || (role == "system" && len(arr) == 1) { } else if role == "user" || (role == "system" && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents // Build single user content node to avoid splitting into multiple contents
@@ -225,18 +233,15 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
} else if role == "assistant" { } else if role == "assistant" {
node := []byte(`{"role":"model","parts":[]}`) node := []byte(`{"role":"model","parts":[]}`)
p := 0 p := 0
if content.Type == gjson.String { if content.Type == gjson.String {
// Assistant text -> single model content // Assistant text -> single model content
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String()) node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
p++ p++
} else if content.IsArray() { } else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts // Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() { for _, item := range content.Array() {
switch item.Get("type").String() { switch item.Get("type").String() {
case "text": case "text":
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
p++ p++
case "image_url": case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity. // If the assistant returned an inline data URL, preserve it for history fidelity.
@@ -253,7 +258,6 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
} }
} }
} }
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
} }
// Tool calls -> single model content with functionCall parts // Tool calls -> single model content with functionCall parts
@@ -278,7 +282,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
out, _ = sjson.SetRawBytes(out, "contents.-1", node) out, _ = sjson.SetRawBytes(out, "contents.-1", node)
// Append a single tool content combining name + response per function // Append a single tool content combining name + response per function
toolNode := []byte(`{"role":"tool","parts":[]}`) toolNode := []byte(`{"role":"user","parts":[]}`)
pp := 0 pp := 0
for _, fid := range fIDs { for _, fid := range fIDs {
if name, ok := tcID2Name[fid]; ok { if name, ok := tcID2Name[fid]; ok {
@@ -294,6 +298,8 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
if pp > 0 { if pp > 0 {
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode) out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
} }
} else {
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
} }
} }
} }
@@ -320,7 +326,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{}) fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
@@ -335,7 +341,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{}) fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue

View File

@@ -8,12 +8,12 @@ package chat_completions
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -89,18 +89,27 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
// Extract and set usage metadata (token counts). // Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() { if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
} }
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
} }
promptTokenCount := usageResult.Get("promptTokenCount").Int() promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 { if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
} }
// Include cached token count if present (indicates prompt caching is working)
if cachedTokenCount > 0 {
var err error
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
if err != nil {
log.Warnf("gemini openai response: failed to set cached_tokens in streaming: %v", err)
}
}
} }
// Process the main content part of the response. // Process the main content part of the response.
@@ -173,21 +182,16 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
mimeType = "image/png" mimeType = "image/png"
} }
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload, err := json.Marshal(map[string]any{
"type": "image_url",
"image_url": map[string]string{
"url": imageURL,
},
})
if err != nil {
continue
}
imagesResult := gjson.Get(template, "choices.0.delta.images") imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() { if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`) template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
} }
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", string(imagePayload)) template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
} }
} }
} }
@@ -248,10 +252,19 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
} }
promptTokenCount := usageResult.Get("promptTokenCount").Int() promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 { if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
} }
// Include cached token count if present (indicates prompt caching is working)
if cachedTokenCount > 0 {
var err error
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
if err != nil {
log.Warnf("gemini openai response: failed to set cached_tokens in non-streaming: %v", err)
}
}
} }
// Process the main content part of the response. // Process the main content part of the response.
@@ -305,21 +318,16 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
mimeType = "image/png" mimeType = "image/png"
} }
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload, err := json.Marshal(map[string]any{
"type": "image_url",
"image_url": map[string]string{
"url": imageURL,
},
})
if err != nil {
continue
}
imagesResult := gjson.Get(template, "choices.0.message.images") imagesResult := gjson.Get(template, "choices.0.message.images")
if !imagesResult.Exists() || !imagesResult.IsArray() { if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.message.images", `[]`) template, _ = sjson.SetRaw(template, "choices.0.message.images", `[]`)
} }
imageIndex := len(gjson.Get(template, "choices.0.message.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.message.role", "assistant") template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", string(imagePayload)) template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", imagePayload)
} }
} }
} }

View File

@@ -23,6 +23,7 @@ type geminiToResponsesState struct {
MsgIndex int MsgIndex int
CurrentMsgID string CurrentMsgID string
TextBuf strings.Builder TextBuf strings.Builder
ItemTextBuf strings.Builder
// reasoning aggregation // reasoning aggregation
ReasoningOpened bool ReasoningOpened bool
@@ -189,6 +190,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
partAdded, _ = sjson.Set(partAdded, "item_id", st.CurrentMsgID) partAdded, _ = sjson.Set(partAdded, "item_id", st.CurrentMsgID)
partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex) partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex)
out = append(out, emitEvent("response.content_part.added", partAdded)) out = append(out, emitEvent("response.content_part.added", partAdded))
st.ItemTextBuf.Reset()
st.ItemTextBuf.WriteString(t.String())
} }
st.TextBuf.WriteString(t.String()) st.TextBuf.WriteString(t.String())
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}` msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
@@ -250,20 +253,24 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
finalizeReasoning() finalizeReasoning()
// Close message output if opened // Close message output if opened
if st.MsgOpened { if st.MsgOpened {
fullText := st.ItemTextBuf.String()
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}` done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
done, _ = sjson.Set(done, "sequence_number", nextSeq()) done, _ = sjson.Set(done, "sequence_number", nextSeq())
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID) done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
done, _ = sjson.Set(done, "output_index", st.MsgIndex) done, _ = sjson.Set(done, "output_index", st.MsgIndex)
done, _ = sjson.Set(done, "text", fullText)
out = append(out, emitEvent("response.output_text.done", done)) out = append(out, emitEvent("response.output_text.done", done))
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}` partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq()) partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID) partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex) partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
partDone, _ = sjson.Set(partDone, "part.text", fullText)
out = append(out, emitEvent("response.content_part.done", partDone)) out = append(out, emitEvent("response.content_part.done", partDone))
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}` final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
final, _ = sjson.Set(final, "sequence_number", nextSeq()) final, _ = sjson.Set(final, "sequence_number", nextSeq())
final, _ = sjson.Set(final, "output_index", st.MsgIndex) final, _ = sjson.Set(final, "output_index", st.MsgIndex)
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID) final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
final, _ = sjson.Set(final, "item.content.0.text", fullText)
out = append(out, emitEvent("response.output_item.done", final)) out = append(out, emitEvent("response.output_item.done", final))
} }
@@ -377,27 +384,18 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
} }
// Compose outputs in encountered order: reasoning, message, function_calls // Compose outputs in encountered order: reasoning, message, function_calls
var outputs []interface{} outputsWrapper := `{"arr":[]}`
if st.ReasoningOpened { if st.ReasoningOpened {
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
"id": st.ReasoningItemID, item, _ = sjson.Set(item, "id", st.ReasoningItemID)
"type": "reasoning", item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
"summary": []interface{}{map[string]interface{}{"type": "summary_text", "text": st.ReasoningBuf.String()}}, outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
})
} }
if st.MsgOpened { if st.MsgOpened {
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
"id": st.CurrentMsgID, item, _ = sjson.Set(item, "id", st.CurrentMsgID)
"type": "message", item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
"status": "completed", outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": st.TextBuf.String(),
}},
"role": "assistant",
})
} }
if len(st.FuncArgsBuf) > 0 { if len(st.FuncArgsBuf) > 0 {
idxs := make([]int, 0, len(st.FuncArgsBuf)) idxs := make([]int, 0, len(st.FuncArgsBuf))
@@ -416,18 +414,16 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
if b := st.FuncArgsBuf[idx]; b != nil { if b := st.FuncArgsBuf[idx]; b != nil {
args = b.String() args = b.String()
} }
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
"id": fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]), item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
"type": "function_call", item, _ = sjson.Set(item, "arguments", args)
"status": "completed", item, _ = sjson.Set(item, "call_id", st.FuncCallIDs[idx])
"arguments": args, item, _ = sjson.Set(item, "name", st.FuncNames[idx])
"call_id": st.FuncCallIDs[idx], outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"name": st.FuncNames[idx],
})
} }
} }
if len(outputs) > 0 { if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.Set(completed, "response.output", outputs) completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
} }
// usage mapping // usage mapping
@@ -558,11 +554,24 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
} }
// Build outputs from candidates[0].content.parts // Build outputs from candidates[0].content.parts
var outputs []interface{}
var reasoningText strings.Builder var reasoningText strings.Builder
var reasoningEncrypted string var reasoningEncrypted string
var messageText strings.Builder var messageText strings.Builder
var haveMessage bool var haveMessage bool
haveOutput := false
ensureOutput := func() {
if haveOutput {
return
}
resp, _ = sjson.SetRaw(resp, "output", "[]")
haveOutput = true
}
appendOutput := func(itemJSON string) {
ensureOutput()
resp, _ = sjson.SetRaw(resp, "output.-1", itemJSON)
}
if parts := root.Get("candidates.0.content.parts"); parts.Exists() && parts.IsArray() { if parts := root.Get("candidates.0.content.parts"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, p gjson.Result) bool { parts.ForEach(func(_, p gjson.Result) bool {
if p.Get("thought").Bool() { if p.Get("thought").Bool() {
@@ -583,19 +592,16 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
name := fc.Get("name").String() name := fc.Get("name").String()
args := fc.Get("args") args := fc.Get("args")
callID := fmt.Sprintf("call_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1)) callID := fmt.Sprintf("call_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1))
outputs = append(outputs, map[string]interface{}{ itemJSON := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
"id": fmt.Sprintf("fc_%s", callID), itemJSON, _ = sjson.Set(itemJSON, "id", fmt.Sprintf("fc_%s", callID))
"type": "function_call", itemJSON, _ = sjson.Set(itemJSON, "call_id", callID)
"status": "completed", itemJSON, _ = sjson.Set(itemJSON, "name", name)
"arguments": func() string { argsStr := ""
if args.Exists() { if args.Exists() {
return args.Raw argsStr = args.Raw
} }
return "" itemJSON, _ = sjson.Set(itemJSON, "arguments", argsStr)
}(), appendOutput(itemJSON)
"call_id": callID,
"name": name,
})
return true return true
} }
return true return true
@@ -605,42 +611,24 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
// Reasoning output item // Reasoning output item
if reasoningText.Len() > 0 || reasoningEncrypted != "" { if reasoningText.Len() > 0 || reasoningEncrypted != "" {
rid := strings.TrimPrefix(id, "resp_") rid := strings.TrimPrefix(id, "resp_")
item := map[string]interface{}{ itemJSON := `{"id":"","type":"reasoning","encrypted_content":""}`
"id": fmt.Sprintf("rs_%s", rid), itemJSON, _ = sjson.Set(itemJSON, "id", fmt.Sprintf("rs_%s", rid))
"type": "reasoning", itemJSON, _ = sjson.Set(itemJSON, "encrypted_content", reasoningEncrypted)
"encrypted_content": reasoningEncrypted,
}
var summaries []interface{}
if reasoningText.Len() > 0 { if reasoningText.Len() > 0 {
summaries = append(summaries, map[string]interface{}{ summaryJSON := `{"type":"summary_text","text":""}`
"type": "summary_text", summaryJSON, _ = sjson.Set(summaryJSON, "text", reasoningText.String())
"text": reasoningText.String(), itemJSON, _ = sjson.SetRaw(itemJSON, "summary", "[]")
}) itemJSON, _ = sjson.SetRaw(itemJSON, "summary.-1", summaryJSON)
} }
if summaries != nil { appendOutput(itemJSON)
item["summary"] = summaries
}
outputs = append(outputs, item)
} }
// Assistant message output item // Assistant message output item
if haveMessage { if haveMessage {
outputs = append(outputs, map[string]interface{}{ itemJSON := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
"id": fmt.Sprintf("msg_%s_0", strings.TrimPrefix(id, "resp_")), itemJSON, _ = sjson.Set(itemJSON, "id", fmt.Sprintf("msg_%s_0", strings.TrimPrefix(id, "resp_")))
"type": "message", itemJSON, _ = sjson.Set(itemJSON, "content.0.text", messageText.String())
"status": "completed", appendOutput(itemJSON)
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": messageText.String(),
}},
"role": "assistant",
})
}
if len(outputs) > 0 {
resp, _ = sjson.Set(resp, "output", outputs)
} }
// usage mapping // usage mapping

View File

@@ -10,6 +10,7 @@ import (
// MergeAdjacentMessages merges adjacent messages with the same role. // MergeAdjacentMessages merges adjacent messages with the same role.
// This reduces API call complexity and improves compatibility. // This reduces API call complexity and improves compatibility.
// Based on AIClient-2-API implementation. // Based on AIClient-2-API implementation.
// NOTE: Tool messages are NOT merged because each has a unique tool_call_id that must be preserved.
func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result { func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
if len(messages) <= 1 { if len(messages) <= 1 {
return messages return messages
@@ -26,6 +27,12 @@ func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
currentRole := msg.Get("role").String() currentRole := msg.Get("role").String()
lastRole := lastMsg.Get("role").String() lastRole := lastMsg.Get("role").String()
// Don't merge tool messages - each has a unique tool_call_id
if currentRole == "tool" || lastRole == "tool" {
merged = append(merged, msg)
continue
}
if currentRole == lastRole { if currentRole == lastRole {
// Merge content from current message into last message // Merge content from current message into last message
mergedContent := mergeMessageContent(lastMsg, msg) mergedContent := mergeMessageContent(lastMsg, msg)

View File

@@ -450,24 +450,10 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
// Merge adjacent messages with the same role // Merge adjacent messages with the same role
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array()) messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
// Build tool_call_id to name mapping from assistant messages // Track pending tool results that should be attached to the next user message
toolCallIDToName := make(map[string]string) // This is critical for LiteLLM-translated requests where tool results appear
for _, msg := range messagesArray { // as separate "tool" role messages between assistant and user messages
if msg.Get("role").String() == "assistant" { var pendingToolResults []KiroToolResult
toolCalls := msg.Get("tool_calls")
if toolCalls.IsArray() {
for _, tc := range toolCalls.Array() {
if tc.Get("type").String() == "function" {
id := tc.Get("id").String()
name := tc.Get("function.name").String()
if id != "" && name != "" {
toolCallIDToName[id] = name
}
}
}
}
}
}
for i, msg := range messagesArray { for i, msg := range messagesArray {
role := msg.Get("role").String() role := msg.Get("role").String()
@@ -480,6 +466,10 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
case "user": case "user":
userMsg, toolResults := buildUserMessageFromOpenAI(msg, modelID, origin) userMsg, toolResults := buildUserMessageFromOpenAI(msg, modelID, origin)
// Merge any pending tool results from preceding "tool" role messages
toolResults = append(pendingToolResults, toolResults...)
pendingToolResults = nil // Reset pending tool results
if isLastMessage { if isLastMessage {
currentUserMsg = &userMsg currentUserMsg = &userMsg
currentToolResults = toolResults currentToolResults = toolResults
@@ -505,6 +495,24 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
case "assistant": case "assistant":
assistantMsg := buildAssistantMessageFromOpenAI(msg) assistantMsg := buildAssistantMessageFromOpenAI(msg)
// If there are pending tool results, we need to insert a synthetic user message
// before this assistant message to maintain proper conversation structure
if len(pendingToolResults) > 0 {
syntheticUserMsg := KiroUserInputMessage{
Content: "Tool results provided.",
ModelID: modelID,
Origin: origin,
UserInputMessageContext: &KiroUserInputMessageContext{
ToolResults: pendingToolResults,
},
}
history = append(history, KiroHistoryMessage{
UserInputMessage: &syntheticUserMsg,
})
pendingToolResults = nil
}
if isLastMessage { if isLastMessage {
history = append(history, KiroHistoryMessage{ history = append(history, KiroHistoryMessage{
AssistantResponseMessage: &assistantMsg, AssistantResponseMessage: &assistantMsg,
@@ -524,7 +532,7 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
case "tool": case "tool":
// Tool messages in OpenAI format provide results for tool_calls // Tool messages in OpenAI format provide results for tool_calls
// These are typically followed by user or assistant messages // These are typically followed by user or assistant messages
// Process them and merge into the next user message's tool results // Collect them as pending and attach to the next user message
toolCallID := msg.Get("tool_call_id").String() toolCallID := msg.Get("tool_call_id").String()
content := msg.Get("content").String() content := msg.Get("content").String()
@@ -534,9 +542,21 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
Content: []KiroTextContent{{Text: content}}, Content: []KiroTextContent{{Text: content}},
Status: "success", Status: "success",
} }
// Tool results should be included in the next user message // Collect pending tool results to attach to the next user message
// For now, collect them and they'll be handled when we build the current message pendingToolResults = append(pendingToolResults, toolResult)
currentToolResults = append(currentToolResults, toolResult) }
}
}
// Handle case where tool results are at the end with no following user message
if len(pendingToolResults) > 0 {
currentToolResults = append(currentToolResults, pendingToolResults...)
// If there's no current user message, create a synthetic one for the tool results
if currentUserMsg == nil {
currentUserMsg = &KiroUserInputMessage{
Content: "Tool results provided.",
ModelID: modelID,
Origin: origin,
} }
} }
} }
@@ -551,9 +571,6 @@ func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroU
var toolResults []KiroToolResult var toolResults []KiroToolResult
var images []KiroImage var images []KiroImage
// Track seen toolCallIds to deduplicate
seenToolCallIDs := make(map[string]bool)
if content.IsArray() { if content.IsArray() {
for _, part := range content.Array() { for _, part := range content.Array() {
partType := part.Get("type").String() partType := part.Get("type").String()
@@ -589,9 +606,6 @@ func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroU
contentBuilder.WriteString(content.String()) contentBuilder.WriteString(content.String())
} }
// Check for tool_calls in the message (shouldn't be in user messages, but handle edge cases)
_ = seenToolCallIDs // Used for deduplication if needed
userMsg := KiroUserInputMessage{ userMsg := KiroUserInputMessage{
Content: contentBuilder.String(), Content: contentBuilder.String(),
ModelID: modelID, ModelID: modelID,

View File

@@ -0,0 +1,386 @@
package openai
import (
"encoding/json"
"testing"
)
// TestToolResultsAttachedToCurrentMessage verifies that tool results from "tool" role messages
// are properly attached to the current user message (the last message in the conversation).
// This is critical for LiteLLM-translated requests where tool results appear as separate messages.
func TestToolResultsAttachedToCurrentMessage(t *testing.T) {
// OpenAI format request simulating LiteLLM's translation from Anthropic format
// Sequence: user -> assistant (with tool_calls) -> tool (result) -> user
// The last user message should have the tool results attached
input := []byte(`{
"model": "kiro-claude-opus-4-5-agentic",
"messages": [
{"role": "user", "content": "Hello, can you read a file for me?"},
{
"role": "assistant",
"content": "I'll read that file for you.",
"tool_calls": [
{
"id": "call_abc123",
"type": "function",
"function": {
"name": "Read",
"arguments": "{\"file_path\": \"/tmp/test.txt\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_abc123",
"content": "File contents: Hello World!"
},
{"role": "user", "content": "What did the file say?"}
]
}`)
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
var payload KiroPayload
if err := json.Unmarshal(result, &payload); err != nil {
t.Fatalf("Failed to unmarshal result: %v", err)
}
// The last user message becomes currentMessage
// History should have: user (first), assistant (with tool_calls)
t.Logf("History count: %d", len(payload.ConversationState.History))
if len(payload.ConversationState.History) != 2 {
t.Errorf("Expected 2 history entries (user + assistant), got %d", len(payload.ConversationState.History))
}
// Tool results should be attached to currentMessage (the last user message)
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
if ctx == nil {
t.Fatal("Expected currentMessage to have UserInputMessageContext with tool results")
}
if len(ctx.ToolResults) != 1 {
t.Fatalf("Expected 1 tool result in currentMessage, got %d", len(ctx.ToolResults))
}
tr := ctx.ToolResults[0]
if tr.ToolUseID != "call_abc123" {
t.Errorf("Expected toolUseId 'call_abc123', got '%s'", tr.ToolUseID)
}
if len(tr.Content) == 0 || tr.Content[0].Text != "File contents: Hello World!" {
t.Errorf("Tool result content mismatch, got: %+v", tr.Content)
}
}
// TestToolResultsInHistoryUserMessage verifies that when there are multiple user messages
// after tool results, the tool results are attached to the correct user message in history.
func TestToolResultsInHistoryUserMessage(t *testing.T) {
// Sequence: user -> assistant (with tool_calls) -> tool (result) -> user -> assistant -> user
// The first user after tool should have tool results in history
input := []byte(`{
"model": "kiro-claude-opus-4-5-agentic",
"messages": [
{"role": "user", "content": "Hello"},
{
"role": "assistant",
"content": "I'll read the file.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "Read",
"arguments": "{}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": "File result"
},
{"role": "user", "content": "Thanks for the file"},
{"role": "assistant", "content": "You're welcome"},
{"role": "user", "content": "Bye"}
]
}`)
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
var payload KiroPayload
if err := json.Unmarshal(result, &payload); err != nil {
t.Fatalf("Failed to unmarshal result: %v", err)
}
// History should have: user, assistant, user (with tool results), assistant
// CurrentMessage should be: last user "Bye"
t.Logf("History count: %d", len(payload.ConversationState.History))
// Find the user message in history with tool results
foundToolResults := false
for i, h := range payload.ConversationState.History {
if h.UserInputMessage != nil {
t.Logf("History[%d]: user message content=%q", i, h.UserInputMessage.Content)
if h.UserInputMessage.UserInputMessageContext != nil {
if len(h.UserInputMessage.UserInputMessageContext.ToolResults) > 0 {
foundToolResults = true
t.Logf(" Found %d tool results", len(h.UserInputMessage.UserInputMessageContext.ToolResults))
tr := h.UserInputMessage.UserInputMessageContext.ToolResults[0]
if tr.ToolUseID != "call_1" {
t.Errorf("Expected toolUseId 'call_1', got '%s'", tr.ToolUseID)
}
}
}
}
if h.AssistantResponseMessage != nil {
t.Logf("History[%d]: assistant message content=%q", i, h.AssistantResponseMessage.Content)
}
}
if !foundToolResults {
t.Error("Tool results were not attached to any user message in history")
}
}
// TestToolResultsWithMultipleToolCalls verifies handling of multiple tool calls
func TestToolResultsWithMultipleToolCalls(t *testing.T) {
input := []byte(`{
"model": "kiro-claude-opus-4-5-agentic",
"messages": [
{"role": "user", "content": "Read two files for me"},
{
"role": "assistant",
"content": "I'll read both files.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "Read",
"arguments": "{\"file_path\": \"/tmp/file1.txt\"}"
}
},
{
"id": "call_2",
"type": "function",
"function": {
"name": "Read",
"arguments": "{\"file_path\": \"/tmp/file2.txt\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": "Content of file 1"
},
{
"role": "tool",
"tool_call_id": "call_2",
"content": "Content of file 2"
},
{"role": "user", "content": "What do they say?"}
]
}`)
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
var payload KiroPayload
if err := json.Unmarshal(result, &payload); err != nil {
t.Fatalf("Failed to unmarshal result: %v", err)
}
t.Logf("History count: %d", len(payload.ConversationState.History))
t.Logf("CurrentMessage content: %q", payload.ConversationState.CurrentMessage.UserInputMessage.Content)
// Check if there are any tool results anywhere
var totalToolResults int
for i, h := range payload.ConversationState.History {
if h.UserInputMessage != nil && h.UserInputMessage.UserInputMessageContext != nil {
count := len(h.UserInputMessage.UserInputMessageContext.ToolResults)
t.Logf("History[%d] user message has %d tool results", i, count)
totalToolResults += count
}
}
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
if ctx != nil {
t.Logf("CurrentMessage has %d tool results", len(ctx.ToolResults))
totalToolResults += len(ctx.ToolResults)
} else {
t.Logf("CurrentMessage has no UserInputMessageContext")
}
if totalToolResults != 2 {
t.Errorf("Expected 2 tool results total, got %d", totalToolResults)
}
}
// TestToolResultsAtEndOfConversation verifies tool results are handled when
// the conversation ends with tool results (no following user message)
func TestToolResultsAtEndOfConversation(t *testing.T) {
input := []byte(`{
"model": "kiro-claude-opus-4-5-agentic",
"messages": [
{"role": "user", "content": "Read a file"},
{
"role": "assistant",
"content": "Reading the file.",
"tool_calls": [
{
"id": "call_end",
"type": "function",
"function": {
"name": "Read",
"arguments": "{\"file_path\": \"/tmp/test.txt\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_end",
"content": "File contents here"
}
]
}`)
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
var payload KiroPayload
if err := json.Unmarshal(result, &payload); err != nil {
t.Fatalf("Failed to unmarshal result: %v", err)
}
// When the last message is a tool result, a synthetic user message is created
// and tool results should be attached to it
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
if ctx == nil || len(ctx.ToolResults) == 0 {
t.Error("Expected tool results to be attached to current message when conversation ends with tool result")
} else {
if ctx.ToolResults[0].ToolUseID != "call_end" {
t.Errorf("Expected toolUseId 'call_end', got '%s'", ctx.ToolResults[0].ToolUseID)
}
}
}
// TestToolResultsFollowedByAssistant verifies handling when tool results are followed
// by an assistant message (no intermediate user message).
// This is the pattern from LiteLLM translation of Anthropic format where:
// user message has ONLY tool_result blocks -> LiteLLM creates tool messages
// then the next message is assistant
func TestToolResultsFollowedByAssistant(t *testing.T) {
// Sequence: user -> assistant (with tool_calls) -> tool -> tool -> assistant -> user
// This simulates LiteLLM's translation of:
// user: "Read files"
// assistant: [tool_use, tool_use]
// user: [tool_result, tool_result] <- becomes multiple "tool" role messages
// assistant: "I've read them"
// user: "What did they say?"
input := []byte(`{
"model": "kiro-claude-opus-4-5-agentic",
"messages": [
{"role": "user", "content": "Read two files for me"},
{
"role": "assistant",
"content": "I'll read both files.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "Read",
"arguments": "{\"file_path\": \"/tmp/a.txt\"}"
}
},
{
"id": "call_2",
"type": "function",
"function": {
"name": "Read",
"arguments": "{\"file_path\": \"/tmp/b.txt\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": "Contents of file A"
},
{
"role": "tool",
"tool_call_id": "call_2",
"content": "Contents of file B"
},
{
"role": "assistant",
"content": "I've read both files."
},
{"role": "user", "content": "What did they say?"}
]
}`)
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
var payload KiroPayload
if err := json.Unmarshal(result, &payload); err != nil {
t.Fatalf("Failed to unmarshal result: %v", err)
}
t.Logf("History count: %d", len(payload.ConversationState.History))
// Tool results should be attached to a synthetic user message or the history should be valid
var totalToolResults int
for i, h := range payload.ConversationState.History {
if h.UserInputMessage != nil {
t.Logf("History[%d]: user message content=%q", i, h.UserInputMessage.Content)
if h.UserInputMessage.UserInputMessageContext != nil {
count := len(h.UserInputMessage.UserInputMessageContext.ToolResults)
t.Logf(" Has %d tool results", count)
totalToolResults += count
}
}
if h.AssistantResponseMessage != nil {
t.Logf("History[%d]: assistant message content=%q", i, h.AssistantResponseMessage.Content)
}
}
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
if ctx != nil {
t.Logf("CurrentMessage has %d tool results", len(ctx.ToolResults))
totalToolResults += len(ctx.ToolResults)
}
if totalToolResults != 2 {
t.Errorf("Expected 2 tool results total, got %d", totalToolResults)
}
}
// TestAssistantEndsConversation verifies handling when assistant is the last message
func TestAssistantEndsConversation(t *testing.T) {
input := []byte(`{
"model": "kiro-claude-opus-4-5-agentic",
"messages": [
{"role": "user", "content": "Hello"},
{
"role": "assistant",
"content": "Hi there!"
}
]
}`)
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
var payload KiroPayload
if err := json.Unmarshal(result, &payload); err != nil {
t.Fatalf("Failed to unmarshal result: %v", err)
}
// When assistant is last, a "Continue" user message should be created
if payload.ConversationState.CurrentMessage.UserInputMessage.Content == "" {
t.Error("Expected a 'Continue' message to be created when assistant is last")
}
}

View File

@@ -7,7 +7,6 @@ package claude
import ( import (
"bytes" "bytes"
"encoding/json"
"strings" "strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -138,11 +137,7 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
// Convert input to arguments JSON string // Convert input to arguments JSON string
if input := part.Get("input"); input.Exists() { if input := part.Get("input"); input.Exists() {
if inputJSON, err := json.Marshal(input.Value()); err == nil { toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", input.Raw)
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", string(inputJSON))
} else {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}")
}
} else { } else {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}") toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}")
} }
@@ -191,8 +186,7 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
// Emit tool calls in a separate assistant message // Emit tool calls in a separate assistant message
if role == "assistant" && len(toolCalls) > 0 { if role == "assistant" && len(toolCalls) > 0 {
toolCallMsgJSON := `{"role":"assistant","tool_calls":[]}` toolCallMsgJSON := `{"role":"assistant","tool_calls":[]}`
toolCallsJSON, _ := json.Marshal(toolCalls) toolCallMsgJSON, _ = sjson.Set(toolCallMsgJSON, "tool_calls", toolCalls)
toolCallMsgJSON, _ = sjson.SetRaw(toolCallMsgJSON, "tool_calls", string(toolCallsJSON))
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolCallMsgJSON).Value()) messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolCallMsgJSON).Value())
} }

View File

@@ -8,7 +8,6 @@ package claude
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
@@ -133,24 +132,10 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
if delta := root.Get("choices.0.delta"); delta.Exists() { if delta := root.Get("choices.0.delta"); delta.Exists() {
if !param.MessageStarted { if !param.MessageStarted {
// Send message_start event // Send message_start event
messageStart := map[string]interface{}{ messageStartJSON := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
"type": "message_start", messageStartJSON, _ = sjson.Set(messageStartJSON, "message.id", param.MessageID)
"message": map[string]interface{}{ messageStartJSON, _ = sjson.Set(messageStartJSON, "message.model", param.Model)
"id": param.MessageID, results = append(results, "event: message_start\ndata: "+messageStartJSON+"\n\n")
"type": "message",
"role": "assistant",
"model": param.Model,
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": 0,
"output_tokens": 0,
},
},
}
messageStartJSON, _ := json.Marshal(messageStart)
results = append(results, "event: message_start\ndata: "+string(messageStartJSON)+"\n\n")
param.MessageStarted = true param.MessageStarted = true
// Don't send content_block_start for text here - wait for actual content // Don't send content_block_start for text here - wait for actual content
@@ -168,29 +153,16 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
param.ThinkingContentBlockIndex = param.NextContentBlockIndex param.ThinkingContentBlockIndex = param.NextContentBlockIndex
param.NextContentBlockIndex++ param.NextContentBlockIndex++
} }
contentBlockStart := map[string]interface{}{ contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
"type": "content_block_start", contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", param.ThinkingContentBlockIndex)
"index": param.ThinkingContentBlockIndex, results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
"content_block": map[string]interface{}{
"type": "thinking",
"thinking": "",
},
}
contentBlockStartJSON, _ := json.Marshal(contentBlockStart)
results = append(results, "event: content_block_start\ndata: "+string(contentBlockStartJSON)+"\n\n")
param.ThinkingContentBlockStarted = true param.ThinkingContentBlockStarted = true
} }
thinkingDelta := map[string]interface{}{ thinkingDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
"type": "content_block_delta", thinkingDeltaJSON, _ = sjson.Set(thinkingDeltaJSON, "index", param.ThinkingContentBlockIndex)
"index": param.ThinkingContentBlockIndex, thinkingDeltaJSON, _ = sjson.Set(thinkingDeltaJSON, "delta.thinking", reasoningText)
"delta": map[string]interface{}{ results = append(results, "event: content_block_delta\ndata: "+thinkingDeltaJSON+"\n\n")
"type": "thinking_delta",
"thinking": reasoningText,
},
}
thinkingDeltaJSON, _ := json.Marshal(thinkingDelta)
results = append(results, "event: content_block_delta\ndata: "+string(thinkingDeltaJSON)+"\n\n")
} }
} }
@@ -203,29 +175,16 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
param.TextContentBlockIndex = param.NextContentBlockIndex param.TextContentBlockIndex = param.NextContentBlockIndex
param.NextContentBlockIndex++ param.NextContentBlockIndex++
} }
contentBlockStart := map[string]interface{}{ contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
"type": "content_block_start", contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", param.TextContentBlockIndex)
"index": param.TextContentBlockIndex, results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
"content_block": map[string]interface{}{
"type": "text",
"text": "",
},
}
contentBlockStartJSON, _ := json.Marshal(contentBlockStart)
results = append(results, "event: content_block_start\ndata: "+string(contentBlockStartJSON)+"\n\n")
param.TextContentBlockStarted = true param.TextContentBlockStarted = true
} }
contentDelta := map[string]interface{}{ contentDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
"type": "content_block_delta", contentDeltaJSON, _ = sjson.Set(contentDeltaJSON, "index", param.TextContentBlockIndex)
"index": param.TextContentBlockIndex, contentDeltaJSON, _ = sjson.Set(contentDeltaJSON, "delta.text", content.String())
"delta": map[string]interface{}{ results = append(results, "event: content_block_delta\ndata: "+contentDeltaJSON+"\n\n")
"type": "text_delta",
"text": content.String(),
},
}
contentDeltaJSON, _ := json.Marshal(contentDelta)
results = append(results, "event: content_block_delta\ndata: "+string(contentDeltaJSON)+"\n\n")
// Accumulate content // Accumulate content
param.ContentAccumulator.WriteString(content.String()) param.ContentAccumulator.WriteString(content.String())
@@ -263,18 +222,11 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
stopTextContentBlock(param, &results) stopTextContentBlock(param, &results)
// Send content_block_start for tool_use // Send content_block_start for tool_use
contentBlockStart := map[string]interface{}{ contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
"type": "content_block_start", contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", blockIndex)
"index": blockIndex, contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", accumulator.ID)
"content_block": map[string]interface{}{ contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.name", accumulator.Name)
"type": "tool_use", results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
"id": accumulator.ID,
"name": accumulator.Name,
"input": map[string]interface{}{},
},
}
contentBlockStartJSON, _ := json.Marshal(contentBlockStart)
results = append(results, "event: content_block_start\ndata: "+string(contentBlockStartJSON)+"\n\n")
} }
// Handle function arguments // Handle function arguments
@@ -298,12 +250,9 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Send content_block_stop for thinking content if needed // Send content_block_stop for thinking content if needed
if param.ThinkingContentBlockStarted { if param.ThinkingContentBlockStarted {
contentBlockStop := map[string]interface{}{ contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
"type": "content_block_stop", contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", param.ThinkingContentBlockIndex)
"index": param.ThinkingContentBlockIndex, results = append(results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
param.ThinkingContentBlockStarted = false param.ThinkingContentBlockStarted = false
param.ThinkingContentBlockIndex = -1 param.ThinkingContentBlockIndex = -1
} }
@@ -319,24 +268,15 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Send complete input_json_delta with all accumulated arguments // Send complete input_json_delta with all accumulated arguments
if accumulator.Arguments.Len() > 0 { if accumulator.Arguments.Len() > 0 {
inputDelta := map[string]interface{}{ inputDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
"type": "content_block_delta", inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "index", blockIndex)
"index": blockIndex, inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "delta.partial_json", util.FixJSON(accumulator.Arguments.String()))
"delta": map[string]interface{}{ results = append(results, "event: content_block_delta\ndata: "+inputDeltaJSON+"\n\n")
"type": "input_json_delta",
"partial_json": util.FixJSON(accumulator.Arguments.String()),
},
}
inputDeltaJSON, _ := json.Marshal(inputDelta)
results = append(results, "event: content_block_delta\ndata: "+string(inputDeltaJSON)+"\n\n")
} }
contentBlockStop := map[string]interface{}{ contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
"type": "content_block_stop", contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", blockIndex)
"index": blockIndex, results = append(results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
delete(param.ToolCallBlockIndexes, index) delete(param.ToolCallBlockIndexes, index)
} }
param.ContentBlocksStopped = true param.ContentBlocksStopped = true
@@ -361,20 +301,11 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
} }
} }
// Send message_delta with usage // Send message_delta with usage
messageDelta := map[string]interface{}{ messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
"type": "message_delta", messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
"delta": map[string]interface{}{ messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
"stop_reason": mapOpenAIFinishReasonToAnthropic(param.FinishReason), messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
"stop_sequence": nil, results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
},
"usage": map[string]interface{}{
"input_tokens": inputTokens,
"output_tokens": outputTokens,
},
}
messageDeltaJSON, _ := json.Marshal(messageDelta)
results = append(results, "event: message_delta\ndata: "+string(messageDeltaJSON)+"\n\n")
param.MessageDeltaSent = true param.MessageDeltaSent = true
emitMessageStopIfNeeded(param, &results) emitMessageStopIfNeeded(param, &results)
@@ -390,12 +321,9 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
// Ensure all content blocks are stopped before final events // Ensure all content blocks are stopped before final events
if param.ThinkingContentBlockStarted { if param.ThinkingContentBlockStarted {
contentBlockStop := map[string]interface{}{ contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
"type": "content_block_stop", contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", param.ThinkingContentBlockIndex)
"index": param.ThinkingContentBlockIndex, results = append(results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
param.ThinkingContentBlockStarted = false param.ThinkingContentBlockStarted = false
param.ThinkingContentBlockIndex = -1 param.ThinkingContentBlockIndex = -1
} }
@@ -408,24 +336,15 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
blockIndex := param.toolContentBlockIndex(index) blockIndex := param.toolContentBlockIndex(index)
if accumulator.Arguments.Len() > 0 { if accumulator.Arguments.Len() > 0 {
inputDelta := map[string]interface{}{ inputDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
"type": "content_block_delta", inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "index", blockIndex)
"index": blockIndex, inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "delta.partial_json", util.FixJSON(accumulator.Arguments.String()))
"delta": map[string]interface{}{ results = append(results, "event: content_block_delta\ndata: "+inputDeltaJSON+"\n\n")
"type": "input_json_delta",
"partial_json": util.FixJSON(accumulator.Arguments.String()),
},
}
inputDeltaJSON, _ := json.Marshal(inputDelta)
results = append(results, "event: content_block_delta\ndata: "+string(inputDeltaJSON)+"\n\n")
} }
contentBlockStop := map[string]interface{}{ contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
"type": "content_block_stop", contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", blockIndex)
"index": blockIndex, results = append(results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
delete(param.ToolCallBlockIndexes, index) delete(param.ToolCallBlockIndexes, index)
} }
param.ContentBlocksStopped = true param.ContentBlocksStopped = true
@@ -433,16 +352,9 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
// If we haven't sent message_delta yet (no usage info was received), send it now // If we haven't sent message_delta yet (no usage info was received), send it now
if param.FinishReason != "" && !param.MessageDeltaSent { if param.FinishReason != "" && !param.MessageDeltaSent {
messageDelta := map[string]interface{}{ messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null}}`
"type": "message_delta", messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
"delta": map[string]interface{}{ results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
"stop_reason": mapOpenAIFinishReasonToAnthropic(param.FinishReason),
"stop_sequence": nil,
},
}
messageDeltaJSON, _ := json.Marshal(messageDelta)
results = append(results, "event: message_delta\ndata: "+string(messageDeltaJSON)+"\n\n")
param.MessageDeltaSent = true param.MessageDeltaSent = true
} }
@@ -455,105 +367,73 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string { func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
// Build Anthropic response out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
response := map[string]interface{}{ out, _ = sjson.Set(out, "id", root.Get("id").String())
"id": root.Get("id").String(), out, _ = sjson.Set(out, "model", root.Get("model").String())
"type": "message",
"role": "assistant",
"model": root.Get("model").String(),
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": 0,
"output_tokens": 0,
},
}
// Process message content and tool calls // Process message content and tool calls
var contentBlocks []interface{} if choices := root.Get("choices"); choices.Exists() && choices.IsArray() && len(choices.Array()) > 0 {
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
choice := choices.Array()[0] // Take first choice choice := choices.Array()[0] // Take first choice
reasoningNode := choice.Get("message.reasoning_content")
allReasoning := collectOpenAIReasoningTexts(reasoningNode)
for _, reasoningText := range allReasoning { reasoningNode := choice.Get("message.reasoning_content")
for _, reasoningText := range collectOpenAIReasoningTexts(reasoningNode) {
if reasoningText == "" { if reasoningText == "" {
continue continue
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"thinking","thinking":""}`
"type": "thinking", block, _ = sjson.Set(block, "thinking", reasoningText)
"thinking": reasoningText, out, _ = sjson.SetRaw(out, "content.-1", block)
})
} }
// Handle text content // Handle text content
if content := choice.Get("message.content"); content.Exists() && content.String() != "" { if content := choice.Get("message.content"); content.Exists() && content.String() != "" {
textBlock := map[string]interface{}{ block := `{"type":"text","text":""}`
"type": "text", block, _ = sjson.Set(block, "text", content.String())
"text": content.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
}
contentBlocks = append(contentBlocks, textBlock)
} }
// Handle tool calls // Handle tool calls
if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() { if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
toolCalls.ForEach(func(_, toolCall gjson.Result) bool { toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
toolUseBlock := map[string]interface{}{ toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
"type": "tool_use", toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
"id": toolCall.Get("id").String(), toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
"name": toolCall.Get("function.name").String(),
}
// Parse arguments argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
argsStr := toolCall.Get("function.arguments").String() if argsStr != "" && gjson.Valid(argsStr) {
argsStr = util.FixJSON(argsStr) argsJSON := gjson.Parse(argsStr)
if argsStr != "" { if argsJSON.IsObject() {
var args interface{} toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", argsJSON.Raw)
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
toolUseBlock["input"] = args
} else { } else {
toolUseBlock["input"] = map[string]interface{}{} toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", "{}")
} }
} else { } else {
toolUseBlock["input"] = map[string]interface{}{} toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", "{}")
} }
contentBlocks = append(contentBlocks, toolUseBlock) out, _ = sjson.SetRaw(out, "content.-1", toolUseBlock)
return true return true
}) })
} }
// Set stop reason // Set stop reason
if finishReason := choice.Get("finish_reason"); finishReason.Exists() { if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
response["stop_reason"] = mapOpenAIFinishReasonToAnthropic(finishReason.String()) out, _ = sjson.Set(out, "stop_reason", mapOpenAIFinishReasonToAnthropic(finishReason.String()))
} }
} }
response["content"] = contentBlocks
// Set usage information // Set usage information
if usage := root.Get("usage"); usage.Exists() { if usage := root.Get("usage"); usage.Exists() {
response["usage"] = map[string]interface{}{ out, _ = sjson.Set(out, "usage.input_tokens", usage.Get("prompt_tokens").Int())
"input_tokens": usage.Get("prompt_tokens").Int(), out, _ = sjson.Set(out, "usage.output_tokens", usage.Get("completion_tokens").Int())
"output_tokens": usage.Get("completion_tokens").Int(), reasoningTokens := int64(0)
"reasoning_tokens": func() int64 { if v := usage.Get("completion_tokens_details.reasoning_tokens"); v.Exists() {
if v := usage.Get("completion_tokens_details.reasoning_tokens"); v.Exists() { reasoningTokens = v.Int()
return v.Int()
}
return 0
}(),
}
} else {
response["usage"] = map[string]interface{}{
"input_tokens": 0,
"output_tokens": 0,
} }
out, _ = sjson.Set(out, "usage.reasoning_tokens", reasoningTokens)
} }
responseJSON, _ := json.Marshal(response) return []string{out}
return []string{string(responseJSON)}
} }
// mapOpenAIFinishReasonToAnthropic maps OpenAI finish reasons to Anthropic equivalents // mapOpenAIFinishReasonToAnthropic maps OpenAI finish reasons to Anthropic equivalents
@@ -620,12 +500,9 @@ func stopThinkingContentBlock(param *ConvertOpenAIResponseToAnthropicParams, res
if !param.ThinkingContentBlockStarted { if !param.ThinkingContentBlockStarted {
return return
} }
contentBlockStop := map[string]interface{}{ contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
"type": "content_block_stop", contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", param.ThinkingContentBlockIndex)
"index": param.ThinkingContentBlockIndex, *results = append(*results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
*results = append(*results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
param.ThinkingContentBlockStarted = false param.ThinkingContentBlockStarted = false
param.ThinkingContentBlockIndex = -1 param.ThinkingContentBlockIndex = -1
} }
@@ -642,12 +519,9 @@ func stopTextContentBlock(param *ConvertOpenAIResponseToAnthropicParams, results
if !param.TextContentBlockStarted { if !param.TextContentBlockStarted {
return return
} }
contentBlockStop := map[string]interface{}{ contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
"type": "content_block_stop", contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", param.TextContentBlockIndex)
"index": param.TextContentBlockIndex, *results = append(*results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
*results = append(*results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
param.TextContentBlockStarted = false param.TextContentBlockStarted = false
param.TextContentBlockIndex = -1 param.TextContentBlockIndex = -1
} }
@@ -667,29 +541,19 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
_ = requestRawJSON _ = requestRawJSON
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("id").String())
out, _ = sjson.Set(out, "model", root.Get("model").String())
response := map[string]interface{}{
"id": root.Get("id").String(),
"type": "message",
"role": "assistant",
"model": root.Get("model").String(),
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": 0,
"output_tokens": 0,
},
}
contentBlocks := make([]interface{}, 0)
hasToolCall := false hasToolCall := false
stopReasonSet := false
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() && len(choices.Array()) > 0 { if choices := root.Get("choices"); choices.Exists() && choices.IsArray() && len(choices.Array()) > 0 {
choice := choices.Array()[0] choice := choices.Array()[0]
if finishReason := choice.Get("finish_reason"); finishReason.Exists() { if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
response["stop_reason"] = mapOpenAIFinishReasonToAnthropic(finishReason.String()) out, _ = sjson.Set(out, "stop_reason", mapOpenAIFinishReasonToAnthropic(finishReason.String()))
stopReasonSet = true
} }
if message := choice.Get("message"); message.Exists() { if message := choice.Get("message"); message.Exists() {
@@ -702,10 +566,9 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
if textBuilder.Len() == 0 { if textBuilder.Len() == 0 {
return return
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"text","text":""}`
"type": "text", block, _ = sjson.Set(block, "text", textBuilder.String())
"text": textBuilder.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
})
textBuilder.Reset() textBuilder.Reset()
} }
@@ -713,16 +576,14 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
if thinkingBuilder.Len() == 0 { if thinkingBuilder.Len() == 0 {
return return
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"thinking","thinking":""}`
"type": "thinking", block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
"thinking": thinkingBuilder.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
})
thinkingBuilder.Reset() thinkingBuilder.Reset()
} }
for _, item := range contentResult.Array() { for _, item := range contentResult.Array() {
typeStr := item.Get("type").String() switch item.Get("type").String() {
switch typeStr {
case "text": case "text":
flushThinking() flushThinking()
textBuilder.WriteString(item.Get("text").String()) textBuilder.WriteString(item.Get("text").String())
@@ -733,25 +594,23 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
if toolCalls.IsArray() { if toolCalls.IsArray() {
toolCalls.ForEach(func(_, tc gjson.Result) bool { toolCalls.ForEach(func(_, tc gjson.Result) bool {
hasToolCall = true hasToolCall = true
toolUse := map[string]interface{}{ toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
"type": "tool_use", toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String())
"id": tc.Get("id").String(), toolUse, _ = sjson.Set(toolUse, "name", tc.Get("function.name").String())
"name": tc.Get("function.name").String(),
}
argsStr := util.FixJSON(tc.Get("function.arguments").String()) argsStr := util.FixJSON(tc.Get("function.arguments").String())
if argsStr != "" { if argsStr != "" && gjson.Valid(argsStr) {
var parsed interface{} argsJSON := gjson.Parse(argsStr)
if err := json.Unmarshal([]byte(argsStr), &parsed); err == nil { if argsJSON.IsObject() {
toolUse["input"] = parsed toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
} else { } else {
toolUse["input"] = map[string]interface{}{} toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
} }
} else { } else {
toolUse["input"] = map[string]interface{}{} toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
} }
contentBlocks = append(contentBlocks, toolUse) out, _ = sjson.SetRaw(out, "content.-1", toolUse)
return true return true
}) })
} }
@@ -771,10 +630,9 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
} else if contentResult.Type == gjson.String { } else if contentResult.Type == gjson.String {
textContent := contentResult.String() textContent := contentResult.String()
if textContent != "" { if textContent != "" {
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"text","text":""}`
"type": "text", block, _ = sjson.Set(block, "text", textContent)
"text": textContent, out, _ = sjson.SetRaw(out, "content.-1", block)
})
} }
} }
} }
@@ -784,81 +642,52 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
if reasoningText == "" { if reasoningText == "" {
continue continue
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"thinking","thinking":""}`
"type": "thinking", block, _ = sjson.Set(block, "thinking", reasoningText)
"thinking": reasoningText, out, _ = sjson.SetRaw(out, "content.-1", block)
})
} }
} }
if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() { if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
toolCalls.ForEach(func(_, toolCall gjson.Result) bool { toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
hasToolCall = true hasToolCall = true
toolUseBlock := map[string]interface{}{ toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
"type": "tool_use", toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
"id": toolCall.Get("id").String(), toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
"name": toolCall.Get("function.name").String(),
}
argsStr := toolCall.Get("function.arguments").String() argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
argsStr = util.FixJSON(argsStr) if argsStr != "" && gjson.Valid(argsStr) {
if argsStr != "" { argsJSON := gjson.Parse(argsStr)
var args interface{} if argsJSON.IsObject() {
if err := json.Unmarshal([]byte(argsStr), &args); err == nil { toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", argsJSON.Raw)
toolUseBlock["input"] = args
} else { } else {
toolUseBlock["input"] = map[string]interface{}{} toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", "{}")
} }
} else { } else {
toolUseBlock["input"] = map[string]interface{}{} toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", "{}")
} }
contentBlocks = append(contentBlocks, toolUseBlock) out, _ = sjson.SetRaw(out, "content.-1", toolUseBlock)
return true return true
}) })
} }
} }
} }
response["content"] = contentBlocks
if respUsage := root.Get("usage"); respUsage.Exists() { if respUsage := root.Get("usage"); respUsage.Exists() {
usageJSON := `{}` out, _ = sjson.Set(out, "usage.input_tokens", respUsage.Get("prompt_tokens").Int())
usageJSON, _ = sjson.Set(usageJSON, "input_tokens", respUsage.Get("prompt_tokens").Int()) out, _ = sjson.Set(out, "usage.output_tokens", respUsage.Get("completion_tokens").Int())
usageJSON, _ = sjson.Set(usageJSON, "output_tokens", respUsage.Get("completion_tokens").Int())
parsedUsage := gjson.Parse(usageJSON).Value().(map[string]interface{})
response["usage"] = parsedUsage
} else {
response["usage"] = `{"input_tokens":0,"output_tokens":0}`
} }
if response["stop_reason"] == nil { if !stopReasonSet {
if hasToolCall { if hasToolCall {
response["stop_reason"] = "tool_use" out, _ = sjson.Set(out, "stop_reason", "tool_use")
} else { } else {
response["stop_reason"] = "end_turn" out, _ = sjson.Set(out, "stop_reason", "end_turn")
} }
} }
if !hasToolCall { return out
if toolBlocks := response["content"].([]interface{}); len(toolBlocks) > 0 {
for _, block := range toolBlocks {
if m, ok := block.(map[string]interface{}); ok && m["type"] == "tool_use" {
hasToolCall = true
break
}
}
}
if hasToolCall {
response["stop_reason"] = "tool_use"
}
}
responseJSON, err := json.Marshal(response)
if err != nil {
return ""
}
return string(responseJSON)
} }
func ClaudeTokenCount(ctx context.Context, count int64) string { func ClaudeTokenCount(ctx context.Context, count int64) string {

View File

@@ -8,7 +8,6 @@ package gemini
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"encoding/json"
"fmt" "fmt"
"math/big" "math/big"
"strings" "strings"
@@ -94,7 +93,6 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
out, _ = sjson.Set(out, "stream", stream) out, _ = sjson.Set(out, "stream", stream)
// Process contents (Gemini messages) -> OpenAI messages // Process contents (Gemini messages) -> OpenAI messages
var openAIMessages []interface{}
var toolCallIDs []string // Track tool call IDs for matching with tool results var toolCallIDs []string // Track tool call IDs for matching with tool results
// System instruction -> OpenAI system message // System instruction -> OpenAI system message
@@ -105,22 +103,17 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
} }
if systemInstruction.Exists() { if systemInstruction.Exists() {
parts := systemInstruction.Get("parts") parts := systemInstruction.Get("parts")
msg := map[string]interface{}{ msg := `{"role":"system","content":[]}`
"role": "system", hasContent := false
"content": []interface{}{},
}
var aggregatedParts []interface{}
if parts.Exists() && parts.IsArray() { if parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool { parts.ForEach(func(_, part gjson.Result) bool {
// Handle text parts // Handle text parts
if text := part.Get("text"); text.Exists() { if text := part.Get("text"); text.Exists() {
formattedText := text.String() contentPart := `{"type":"text","text":""}`
aggregatedParts = append(aggregatedParts, map[string]interface{}{ contentPart, _ = sjson.Set(contentPart, "text", text.String())
"type": "text", msg, _ = sjson.SetRaw(msg, "content.-1", contentPart)
"text": formattedText, hasContent = true
})
} }
// Handle inline data (e.g., images) // Handle inline data (e.g., images)
@@ -132,20 +125,17 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
data := inlineData.Get("data").String() data := inlineData.Get("data").String()
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
aggregatedParts = append(aggregatedParts, map[string]interface{}{ contentPart := `{"type":"image_url","image_url":{"url":""}}`
"type": "image_url", contentPart, _ = sjson.Set(contentPart, "image_url.url", imageURL)
"image_url": map[string]interface{}{ msg, _ = sjson.SetRaw(msg, "content.-1", contentPart)
"url": imageURL, hasContent = true
},
})
} }
return true return true
}) })
} }
if len(aggregatedParts) > 0 { if hasContent {
msg["content"] = aggregatedParts out, _ = sjson.SetRaw(out, "messages.-1", msg)
openAIMessages = append(openAIMessages, msg)
} }
} }
@@ -159,16 +149,15 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
role = "assistant" role = "assistant"
} }
// Create OpenAI message msg := `{"role":"","content":""}`
msg := map[string]interface{}{ msg, _ = sjson.Set(msg, "role", role)
"role": role,
"content": "",
}
var textBuilder strings.Builder var textBuilder strings.Builder
var aggregatedParts []interface{} contentWrapper := `{"arr":[]}`
contentPartsCount := 0
onlyTextContent := true onlyTextContent := true
var toolCalls []interface{} toolCallsWrapper := `{"arr":[]}`
toolCallsCount := 0
if parts.Exists() && parts.IsArray() { if parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool { parts.ForEach(func(_, part gjson.Result) bool {
@@ -176,10 +165,10 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
if text := part.Get("text"); text.Exists() { if text := part.Get("text"); text.Exists() {
formattedText := text.String() formattedText := text.String()
textBuilder.WriteString(formattedText) textBuilder.WriteString(formattedText)
aggregatedParts = append(aggregatedParts, map[string]interface{}{ contentPart := `{"type":"text","text":""}`
"type": "text", contentPart, _ = sjson.Set(contentPart, "text", formattedText)
"text": formattedText, contentWrapper, _ = sjson.SetRaw(contentWrapper, "arr.-1", contentPart)
}) contentPartsCount++
} }
// Handle inline data (e.g., images) // Handle inline data (e.g., images)
@@ -193,12 +182,10 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
data := inlineData.Get("data").String() data := inlineData.Get("data").String()
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
aggregatedParts = append(aggregatedParts, map[string]interface{}{ contentPart := `{"type":"image_url","image_url":{"url":""}}`
"type": "image_url", contentPart, _ = sjson.Set(contentPart, "image_url.url", imageURL)
"image_url": map[string]interface{}{ contentWrapper, _ = sjson.SetRaw(contentWrapper, "arr.-1", contentPart)
"url": imageURL, contentPartsCount++
},
})
} }
// Handle function calls (Gemini) -> tool calls (OpenAI) // Handle function calls (Gemini) -> tool calls (OpenAI)
@@ -206,44 +193,32 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
toolCallID := genToolCallID() toolCallID := genToolCallID()
toolCallIDs = append(toolCallIDs, toolCallID) toolCallIDs = append(toolCallIDs, toolCallID)
toolCall := map[string]interface{}{ toolCall := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
"id": toolCallID, toolCall, _ = sjson.Set(toolCall, "id", toolCallID)
"type": "function", toolCall, _ = sjson.Set(toolCall, "function.name", functionCall.Get("name").String())
"function": map[string]interface{}{
"name": functionCall.Get("name").String(),
},
}
// Convert args to arguments JSON string // Convert args to arguments JSON string
if args := functionCall.Get("args"); args.Exists() { if args := functionCall.Get("args"); args.Exists() {
argsJSON, _ := json.Marshal(args.Value()) toolCall, _ = sjson.Set(toolCall, "function.arguments", args.Raw)
toolCall["function"].(map[string]interface{})["arguments"] = string(argsJSON)
} else { } else {
toolCall["function"].(map[string]interface{})["arguments"] = "{}" toolCall, _ = sjson.Set(toolCall, "function.arguments", "{}")
} }
toolCalls = append(toolCalls, toolCall) toolCallsWrapper, _ = sjson.SetRaw(toolCallsWrapper, "arr.-1", toolCall)
toolCallsCount++
} }
// Handle function responses (Gemini) -> tool role messages (OpenAI) // Handle function responses (Gemini) -> tool role messages (OpenAI)
if functionResponse := part.Get("functionResponse"); functionResponse.Exists() { if functionResponse := part.Get("functionResponse"); functionResponse.Exists() {
// Create tool message for function response // Create tool message for function response
toolMsg := map[string]interface{}{ toolMsg := `{"role":"tool","tool_call_id":"","content":""}`
"role": "tool",
"tool_call_id": "", // Will be set based on context
"content": "",
}
// Convert response.content to JSON string // Convert response.content to JSON string
if response := functionResponse.Get("response"); response.Exists() { if response := functionResponse.Get("response"); response.Exists() {
if content = response.Get("content"); content.Exists() { if contentField := response.Get("content"); contentField.Exists() {
// Use the content field from the response toolMsg, _ = sjson.Set(toolMsg, "content", contentField.Raw)
contentJSON, _ := json.Marshal(content.Value())
toolMsg["content"] = string(contentJSON)
} else { } else {
// Fallback to entire response toolMsg, _ = sjson.Set(toolMsg, "content", response.Raw)
responseJSON, _ := json.Marshal(response.Value())
toolMsg["content"] = string(responseJSON)
} }
} }
@@ -252,13 +227,13 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
if len(toolCallIDs) > 0 { if len(toolCallIDs) > 0 {
// Use the last tool call ID (simple matching by function name) // Use the last tool call ID (simple matching by function name)
// In a real implementation, you might want more sophisticated matching // In a real implementation, you might want more sophisticated matching
toolMsg["tool_call_id"] = toolCallIDs[len(toolCallIDs)-1] toolMsg, _ = sjson.Set(toolMsg, "tool_call_id", toolCallIDs[len(toolCallIDs)-1])
} else { } else {
// Generate a tool call ID if none available // Generate a tool call ID if none available
toolMsg["tool_call_id"] = genToolCallID() toolMsg, _ = sjson.Set(toolMsg, "tool_call_id", genToolCallID())
} }
openAIMessages = append(openAIMessages, toolMsg) out, _ = sjson.SetRaw(out, "messages.-1", toolMsg)
} }
return true return true
@@ -266,170 +241,46 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
} }
// Set content // Set content
if len(aggregatedParts) > 0 { if contentPartsCount > 0 {
if onlyTextContent { if onlyTextContent {
msg["content"] = textBuilder.String() msg, _ = sjson.Set(msg, "content", textBuilder.String())
} else { } else {
msg["content"] = aggregatedParts msg, _ = sjson.SetRaw(msg, "content", gjson.Get(contentWrapper, "arr").Raw)
} }
} }
// Set tool calls if any // Set tool calls if any
if len(toolCalls) > 0 { if toolCallsCount > 0 {
msg["tool_calls"] = toolCalls msg, _ = sjson.SetRaw(msg, "tool_calls", gjson.Get(toolCallsWrapper, "arr").Raw)
} }
openAIMessages = append(openAIMessages, msg) out, _ = sjson.SetRaw(out, "messages.-1", msg)
// switch role {
// case "user", "model":
// // Convert role: model -> assistant
// if role == "model" {
// role = "assistant"
// }
//
// // Create OpenAI message
// msg := map[string]interface{}{
// "role": role,
// "content": "",
// }
//
// var contentParts []string
// var toolCalls []interface{}
//
// if parts.Exists() && parts.IsArray() {
// parts.ForEach(func(_, part gjson.Result) bool {
// // Handle text parts
// if text := part.Get("text"); text.Exists() {
// contentParts = append(contentParts, text.String())
// }
//
// // Handle function calls (Gemini) -> tool calls (OpenAI)
// if functionCall := part.Get("functionCall"); functionCall.Exists() {
// toolCallID := genToolCallID()
// toolCallIDs = append(toolCallIDs, toolCallID)
//
// toolCall := map[string]interface{}{
// "id": toolCallID,
// "type": "function",
// "function": map[string]interface{}{
// "name": functionCall.Get("name").String(),
// },
// }
//
// // Convert args to arguments JSON string
// if args := functionCall.Get("args"); args.Exists() {
// argsJSON, _ := json.Marshal(args.Value())
// toolCall["function"].(map[string]interface{})["arguments"] = string(argsJSON)
// } else {
// toolCall["function"].(map[string]interface{})["arguments"] = "{}"
// }
//
// toolCalls = append(toolCalls, toolCall)
// }
//
// return true
// })
// }
//
// // Set content
// if len(contentParts) > 0 {
// msg["content"] = strings.Join(contentParts, "")
// }
//
// // Set tool calls if any
// if len(toolCalls) > 0 {
// msg["tool_calls"] = toolCalls
// }
//
// openAIMessages = append(openAIMessages, msg)
//
// case "function":
// // Handle Gemini function role -> OpenAI tool role
// if parts.Exists() && parts.IsArray() {
// parts.ForEach(func(_, part gjson.Result) bool {
// // Handle function responses (Gemini) -> tool role messages (OpenAI)
// if functionResponse := part.Get("functionResponse"); functionResponse.Exists() {
// // Create tool message for function response
// toolMsg := map[string]interface{}{
// "role": "tool",
// "tool_call_id": "", // Will be set based on context
// "content": "",
// }
//
// // Convert response.content to JSON string
// if response := functionResponse.Get("response"); response.Exists() {
// if content = response.Get("content"); content.Exists() {
// // Use the content field from the response
// contentJSON, _ := json.Marshal(content.Value())
// toolMsg["content"] = string(contentJSON)
// } else {
// // Fallback to entire response
// responseJSON, _ := json.Marshal(response.Value())
// toolMsg["content"] = string(responseJSON)
// }
// }
//
// // Try to match with previous tool call ID
// _ = functionResponse.Get("name").String() // functionName not used for now
// if len(toolCallIDs) > 0 {
// // Use the last tool call ID (simple matching by function name)
// // In a real implementation, you might want more sophisticated matching
// toolMsg["tool_call_id"] = toolCallIDs[len(toolCallIDs)-1]
// } else {
// // Generate a tool call ID if none available
// toolMsg["tool_call_id"] = genToolCallID()
// }
//
// openAIMessages = append(openAIMessages, toolMsg)
// }
//
// return true
// })
// }
// }
return true return true
}) })
} }
// Set messages
if len(openAIMessages) > 0 {
messagesJSON, _ := json.Marshal(openAIMessages)
out, _ = sjson.SetRaw(out, "messages", string(messagesJSON))
}
// Tools mapping: Gemini tools -> OpenAI tools // Tools mapping: Gemini tools -> OpenAI tools
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() { if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
var openAITools []interface{}
tools.ForEach(func(_, tool gjson.Result) bool { tools.ForEach(func(_, tool gjson.Result) bool {
if functionDeclarations := tool.Get("functionDeclarations"); functionDeclarations.Exists() && functionDeclarations.IsArray() { if functionDeclarations := tool.Get("functionDeclarations"); functionDeclarations.Exists() && functionDeclarations.IsArray() {
functionDeclarations.ForEach(func(_, funcDecl gjson.Result) bool { functionDeclarations.ForEach(func(_, funcDecl gjson.Result) bool {
openAITool := map[string]interface{}{ openAITool := `{"type":"function","function":{"name":"","description":""}}`
"type": "function", openAITool, _ = sjson.Set(openAITool, "function.name", funcDecl.Get("name").String())
"function": map[string]interface{}{ openAITool, _ = sjson.Set(openAITool, "function.description", funcDecl.Get("description").String())
"name": funcDecl.Get("name").String(),
"description": funcDecl.Get("description").String(),
},
}
// Convert parameters schema // Convert parameters schema
if parameters := funcDecl.Get("parameters"); parameters.Exists() { if parameters := funcDecl.Get("parameters"); parameters.Exists() {
openAITool["function"].(map[string]interface{})["parameters"] = parameters.Value() openAITool, _ = sjson.SetRaw(openAITool, "function.parameters", parameters.Raw)
} else if parameters = funcDecl.Get("parametersJsonSchema"); parameters.Exists() { } else if parameters := funcDecl.Get("parametersJsonSchema"); parameters.Exists() {
openAITool["function"].(map[string]interface{})["parameters"] = parameters.Value() openAITool, _ = sjson.SetRaw(openAITool, "function.parameters", parameters.Raw)
} }
openAITools = append(openAITools, openAITool) out, _ = sjson.SetRaw(out, "tools.-1", openAITool)
return true return true
}) })
} }
return true return true
}) })
if len(openAITools) > 0 {
toolsJSON, _ := json.Marshal(openAITools)
out, _ = sjson.SetRaw(out, "tools", string(toolsJSON))
}
} }
// Tool choice mapping (Gemini doesn't have direct equivalent, but we can handle it) // Tool choice mapping (Gemini doesn't have direct equivalent, but we can handle it)

View File

@@ -8,7 +8,6 @@ package gemini
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@@ -84,15 +83,12 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
template, _ = sjson.Set(template, "model", model.String()) template, _ = sjson.Set(template, "model", model.String())
} }
usageObj := map[string]interface{}{ template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", usage.Get("prompt_tokens").Int())
"promptTokenCount": usage.Get("prompt_tokens").Int(), template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", usage.Get("completion_tokens").Int())
"candidatesTokenCount": usage.Get("completion_tokens").Int(), template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", usage.Get("total_tokens").Int())
"totalTokenCount": usage.Get("total_tokens").Int(),
}
if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 { if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 {
usageObj["thoughtsTokenCount"] = reasoningTokens template, _ = sjson.Set(template, "usageMetadata.thoughtsTokenCount", reasoningTokens)
} }
template, _ = sjson.Set(template, "usageMetadata", usageObj)
return []string{template} return []string{template}
} }
return []string{} return []string{}
@@ -133,13 +129,8 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
continue continue
} }
reasoningTemplate := baseTemplate reasoningTemplate := baseTemplate
parts := []interface{}{ reasoningTemplate, _ = sjson.Set(reasoningTemplate, "candidates.0.content.parts.0.thought", true)
map[string]interface{}{ reasoningTemplate, _ = sjson.Set(reasoningTemplate, "candidates.0.content.parts.0.text", reasoningText)
"thought": true,
"text": reasoningText,
},
}
reasoningTemplate, _ = sjson.Set(reasoningTemplate, "candidates.0.content.parts", parts)
chunkOutputs = append(chunkOutputs, reasoningTemplate) chunkOutputs = append(chunkOutputs, reasoningTemplate)
} }
} }
@@ -150,13 +141,8 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
(*param).(*ConvertOpenAIResponseToGeminiParams).ContentAccumulator.WriteString(contentText) (*param).(*ConvertOpenAIResponseToGeminiParams).ContentAccumulator.WriteString(contentText)
// Create text part for this delta // Create text part for this delta
parts := []interface{}{
map[string]interface{}{
"text": contentText,
},
}
contentTemplate := baseTemplate contentTemplate := baseTemplate
contentTemplate, _ = sjson.Set(contentTemplate, "candidates.0.content.parts", parts) contentTemplate, _ = sjson.Set(contentTemplate, "candidates.0.content.parts.0.text", contentText)
chunkOutputs = append(chunkOutputs, contentTemplate) chunkOutputs = append(chunkOutputs, contentTemplate)
} }
@@ -225,24 +211,13 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
// If we have accumulated tool calls, output them now // If we have accumulated tool calls, output them now
if len((*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator) > 0 { if len((*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator) > 0 {
var parts []interface{} partIndex := 0
for _, accumulator := range (*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator { for _, accumulator := range (*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator {
argsStr := accumulator.Arguments.String() namePath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.name", partIndex)
var argsMap map[string]interface{} argsPath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.args", partIndex)
template, _ = sjson.Set(template, namePath, accumulator.Name)
argsMap = parseArgsToMap(argsStr) template, _ = sjson.SetRaw(template, argsPath, parseArgsToObjectRaw(accumulator.Arguments.String()))
partIndex++
functionCallPart := map[string]interface{}{
"functionCall": map[string]interface{}{
"name": accumulator.Name,
"args": argsMap,
},
}
parts = append(parts, functionCallPart)
}
if len(parts) > 0 {
template, _ = sjson.Set(template, "candidates.0.content.parts", parts)
} }
// Clear accumulators // Clear accumulators
@@ -255,15 +230,12 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
// Handle usage information // Handle usage information
if usage := root.Get("usage"); usage.Exists() { if usage := root.Get("usage"); usage.Exists() {
usageObj := map[string]interface{}{ template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", usage.Get("prompt_tokens").Int())
"promptTokenCount": usage.Get("prompt_tokens").Int(), template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", usage.Get("completion_tokens").Int())
"candidatesTokenCount": usage.Get("completion_tokens").Int(), template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", usage.Get("total_tokens").Int())
"totalTokenCount": usage.Get("total_tokens").Int(),
}
if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 { if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 {
usageObj["thoughtsTokenCount"] = reasoningTokens template, _ = sjson.Set(template, "usageMetadata.thoughtsTokenCount", reasoningTokens)
} }
template, _ = sjson.Set(template, "usageMetadata", usageObj)
results = append(results, template) results = append(results, template)
return true return true
} }
@@ -291,46 +263,54 @@ func mapOpenAIFinishReasonToGemini(openAIReason string) string {
} }
} }
// parseArgsToMap safely parses a JSON string of function arguments into a map. // parseArgsToObjectRaw safely parses a JSON string of function arguments into an object JSON string.
// It returns an empty map if the input is empty or cannot be parsed as a JSON object. // It returns "{}" if the input is empty or cannot be parsed as a JSON object.
func parseArgsToMap(argsStr string) map[string]interface{} { func parseArgsToObjectRaw(argsStr string) string {
trimmed := strings.TrimSpace(argsStr) trimmed := strings.TrimSpace(argsStr)
if trimmed == "" || trimmed == "{}" { if trimmed == "" || trimmed == "{}" {
return map[string]interface{}{} return "{}"
} }
// First try strict JSON // First try strict JSON
var out map[string]interface{} if gjson.Valid(trimmed) {
if errUnmarshal := json.Unmarshal([]byte(trimmed), &out); errUnmarshal == nil { strict := gjson.Parse(trimmed)
return out if strict.IsObject() {
return strict.Raw
}
} }
// Tolerant parse: handle streams where values are barewords (e.g., 北京, celsius) // Tolerant parse: handle streams where values are barewords (e.g., 北京, celsius)
tolerant := tolerantParseJSONMap(trimmed) tolerant := tolerantParseJSONObjectRaw(trimmed)
if len(tolerant) > 0 { if tolerant != "{}" {
return tolerant return tolerant
} }
// Fallback: return empty object when parsing fails // Fallback: return empty object when parsing fails
return map[string]interface{}{} return "{}"
} }
// tolerantParseJSONMap attempts to parse a JSON-like object string into a map, tolerating func escapeSjsonPathKey(key string) string {
key = strings.ReplaceAll(key, `\`, `\\`)
key = strings.ReplaceAll(key, `.`, `\.`)
return key
}
// tolerantParseJSONObjectRaw attempts to parse a JSON-like object string into a JSON object string, tolerating
// bareword values (unquoted strings) commonly seen during streamed tool calls. // bareword values (unquoted strings) commonly seen during streamed tool calls.
// Example input: {"location": 北京, "unit": celsius} // Example input: {"location": 北京, "unit": celsius}
func tolerantParseJSONMap(s string) map[string]interface{} { func tolerantParseJSONObjectRaw(s string) string {
// Ensure we operate within the outermost braces if present // Ensure we operate within the outermost braces if present
start := strings.Index(s, "{") start := strings.Index(s, "{")
end := strings.LastIndex(s, "}") end := strings.LastIndex(s, "}")
if start == -1 || end == -1 || start >= end { if start == -1 || end == -1 || start >= end {
return map[string]interface{}{} return "{}"
} }
content := s[start+1 : end] content := s[start+1 : end]
runes := []rune(content) runes := []rune(content)
n := len(runes) n := len(runes)
i := 0 i := 0
result := make(map[string]interface{}) result := "{}"
for i < n { for i < n {
// Skip whitespace and commas // Skip whitespace and commas
@@ -356,6 +336,7 @@ func tolerantParseJSONMap(s string) map[string]interface{} {
break break
} }
keyName := jsonStringTokenToRawString(keyToken) keyName := jsonStringTokenToRawString(keyToken)
sjsonKey := escapeSjsonPathKey(keyName)
i = nextIdx i = nextIdx
// Skip whitespace // Skip whitespace
@@ -375,17 +356,16 @@ func tolerantParseJSONMap(s string) map[string]interface{} {
} }
// Parse value (string, number, object/array, bareword) // Parse value (string, number, object/array, bareword)
var value interface{}
switch runes[i] { switch runes[i] {
case '"': case '"':
// JSON string // JSON string
valToken, ni := parseJSONStringRunes(runes, i) valToken, ni := parseJSONStringRunes(runes, i)
if ni == -1 { if ni == -1 {
// Malformed; treat as empty string // Malformed; treat as empty string
value = "" result, _ = sjson.Set(result, sjsonKey, "")
i = n i = n
} else { } else {
value = jsonStringTokenToRawString(valToken) result, _ = sjson.Set(result, sjsonKey, jsonStringTokenToRawString(valToken))
i = ni i = ni
} }
case '{', '[': case '{', '[':
@@ -394,11 +374,10 @@ func tolerantParseJSONMap(s string) map[string]interface{} {
if ni == -1 { if ni == -1 {
i = n i = n
} else { } else {
var anyVal interface{} if gjson.Valid(seg) {
if errUnmarshal := json.Unmarshal([]byte(seg), &anyVal); errUnmarshal == nil { result, _ = sjson.SetRaw(result, sjsonKey, seg)
value = anyVal
} else { } else {
value = seg result, _ = sjson.Set(result, sjsonKey, seg)
} }
i = ni i = ni
} }
@@ -411,21 +390,19 @@ func tolerantParseJSONMap(s string) map[string]interface{} {
token := strings.TrimSpace(string(runes[i:j])) token := strings.TrimSpace(string(runes[i:j]))
// Interpret common JSON atoms and numbers; otherwise treat as string // Interpret common JSON atoms and numbers; otherwise treat as string
if token == "true" { if token == "true" {
value = true result, _ = sjson.Set(result, sjsonKey, true)
} else if token == "false" { } else if token == "false" {
value = false result, _ = sjson.Set(result, sjsonKey, false)
} else if token == "null" { } else if token == "null" {
value = nil result, _ = sjson.Set(result, sjsonKey, nil)
} else if numVal, ok := tryParseNumber(token); ok { } else if numVal, ok := tryParseNumber(token); ok {
value = numVal result, _ = sjson.Set(result, sjsonKey, numVal)
} else { } else {
value = token result, _ = sjson.Set(result, sjsonKey, token)
} }
i = j i = j
} }
result[keyName] = value
// Skip trailing whitespace and optional comma before next pair // Skip trailing whitespace and optional comma before next pair
for i < n && (runes[i] == ' ' || runes[i] == '\n' || runes[i] == '\r' || runes[i] == '\t') { for i < n && (runes[i] == ' ' || runes[i] == '\n' || runes[i] == '\r' || runes[i] == '\t') {
i++ i++
@@ -463,9 +440,9 @@ func parseJSONStringRunes(runes []rune, start int) (string, int) {
// jsonStringTokenToRawString converts a JSON string token (including quotes) to a raw Go string value. // jsonStringTokenToRawString converts a JSON string token (including quotes) to a raw Go string value.
func jsonStringTokenToRawString(token string) string { func jsonStringTokenToRawString(token string) string {
var s string r := gjson.Parse(token)
if errUnmarshal := json.Unmarshal([]byte(token), &s); errUnmarshal == nil { if r.Type == gjson.String {
return s return r.String()
} }
// Fallback: strip surrounding quotes if present // Fallback: strip surrounding quotes if present
if len(token) >= 2 && token[0] == '"' && token[len(token)-1] == '"' { if len(token) >= 2 && token[0] == '"' && token[len(token)-1] == '"' {
@@ -579,7 +556,7 @@ func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, origina
} }
} }
var parts []interface{} partIndex := 0
// Handle reasoning content before visible text // Handle reasoning content before visible text
if reasoning := message.Get("reasoning_content"); reasoning.Exists() { if reasoning := message.Get("reasoning_content"); reasoning.Exists() {
@@ -587,18 +564,16 @@ func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, origina
if reasoningText == "" { if reasoningText == "" {
continue continue
} }
parts = append(parts, map[string]interface{}{ out, _ = sjson.Set(out, fmt.Sprintf("candidates.0.content.parts.%d.thought", partIndex), true)
"thought": true, out, _ = sjson.Set(out, fmt.Sprintf("candidates.0.content.parts.%d.text", partIndex), reasoningText)
"text": reasoningText, partIndex++
})
} }
} }
// Handle content first // Handle content first
if content := message.Get("content"); content.Exists() && content.String() != "" { if content := message.Get("content"); content.Exists() && content.String() != "" {
parts = append(parts, map[string]interface{}{ out, _ = sjson.Set(out, fmt.Sprintf("candidates.0.content.parts.%d.text", partIndex), content.String())
"text": content.String(), partIndex++
})
} }
// Handle tool calls // Handle tool calls
@@ -609,27 +584,16 @@ func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, origina
functionName := function.Get("name").String() functionName := function.Get("name").String()
functionArgs := function.Get("arguments").String() functionArgs := function.Get("arguments").String()
// Parse arguments namePath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.name", partIndex)
var argsMap map[string]interface{} argsPath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.args", partIndex)
argsMap = parseArgsToMap(functionArgs) out, _ = sjson.Set(out, namePath, functionName)
out, _ = sjson.SetRaw(out, argsPath, parseArgsToObjectRaw(functionArgs))
functionCallPart := map[string]interface{}{ partIndex++
"functionCall": map[string]interface{}{
"name": functionName,
"args": argsMap,
},
}
parts = append(parts, functionCallPart)
} }
return true return true
}) })
} }
// Set parts
if len(parts) > 0 {
out, _ = sjson.Set(out, "candidates.0.content.parts", parts)
}
// Handle finish reason // Handle finish reason
if finishReason := choice.Get("finish_reason"); finishReason.Exists() { if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
geminiFinishReason := mapOpenAIFinishReasonToGemini(finishReason.String()) geminiFinishReason := mapOpenAIFinishReasonToGemini(finishReason.String())
@@ -645,15 +609,12 @@ func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, origina
// Handle usage information // Handle usage information
if usage := root.Get("usage"); usage.Exists() { if usage := root.Get("usage"); usage.Exists() {
usageObj := map[string]interface{}{ out, _ = sjson.Set(out, "usageMetadata.promptTokenCount", usage.Get("prompt_tokens").Int())
"promptTokenCount": usage.Get("prompt_tokens").Int(), out, _ = sjson.Set(out, "usageMetadata.candidatesTokenCount", usage.Get("completion_tokens").Int())
"candidatesTokenCount": usage.Get("completion_tokens").Int(), out, _ = sjson.Set(out, "usageMetadata.totalTokenCount", usage.Get("total_tokens").Int())
"totalTokenCount": usage.Get("total_tokens").Int(),
}
if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 { if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 {
usageObj["thoughtsTokenCount"] = reasoningTokens out, _ = sjson.Set(out, "usageMetadata.thoughtsTokenCount", reasoningTokens)
} }
out, _ = sjson.Set(out, "usageMetadata", usageObj)
} }
return out return out

View File

@@ -484,16 +484,12 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
} }
} }
// Build response.output using aggregated buffers // Build response.output using aggregated buffers
var outputs []interface{} outputsWrapper := `{"arr":[]}`
if st.ReasoningBuf.Len() > 0 { if st.ReasoningBuf.Len() > 0 {
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
"id": st.ReasoningID, item, _ = sjson.Set(item, "id", st.ReasoningID)
"type": "reasoning", item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
"summary": []interface{}{map[string]interface{}{ outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"type": "summary_text",
"text": st.ReasoningBuf.String(),
}},
})
} }
// Append message items in ascending index order // Append message items in ascending index order
if len(st.MsgItemAdded) > 0 { if len(st.MsgItemAdded) > 0 {
@@ -513,18 +509,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
if b := st.MsgTextBuf[i]; b != nil { if b := st.MsgTextBuf[i]; b != nil {
txt = b.String() txt = b.String()
} }
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
"id": fmt.Sprintf("msg_%s_%d", st.ResponseID, i), item, _ = sjson.Set(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
"type": "message", item, _ = sjson.Set(item, "content.0.text", txt)
"status": "completed", outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": txt,
}},
"role": "assistant",
})
} }
} }
if len(st.FuncArgsBuf) > 0 { if len(st.FuncArgsBuf) > 0 {
@@ -547,18 +535,16 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
} }
callID := st.FuncCallIDs[i] callID := st.FuncCallIDs[i]
name := st.FuncNames[i] name := st.FuncNames[i]
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
"id": fmt.Sprintf("fc_%s", callID), item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
"type": "function_call", item, _ = sjson.Set(item, "arguments", args)
"status": "completed", item, _ = sjson.Set(item, "call_id", callID)
"arguments": args, item, _ = sjson.Set(item, "name", name)
"call_id": callID, outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"name": name,
})
} }
} }
if len(outputs) > 0 { if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.Set(completed, "response.output", outputs) completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
} }
if st.UsageSeen { if st.UsageSeen {
completed, _ = sjson.Set(completed, "response.usage.input_tokens", st.PromptTokens) completed, _ = sjson.Set(completed, "response.usage.input_tokens", st.PromptTokens)
@@ -681,7 +667,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
} }
// Build output list from choices[...] // Build output list from choices[...]
var outputs []interface{} outputsWrapper := `{"arr":[]}`
// Detect and capture reasoning content if present // Detect and capture reasoning content if present
rcText := gjson.GetBytes(rawJSON, "choices.0.message.reasoning_content").String() rcText := gjson.GetBytes(rawJSON, "choices.0.message.reasoning_content").String()
includeReasoning := rcText != "" includeReasoning := rcText != ""
@@ -693,21 +679,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
if strings.HasPrefix(rid, "resp_") { if strings.HasPrefix(rid, "resp_") {
rid = strings.TrimPrefix(rid, "resp_") rid = strings.TrimPrefix(rid, "resp_")
} }
reasoningItem := map[string]interface{}{
"id": fmt.Sprintf("rs_%s", rid),
"type": "reasoning",
"encrypted_content": "",
}
// Prefer summary_text from reasoning_content; encrypted_content is optional // Prefer summary_text from reasoning_content; encrypted_content is optional
var summaries []interface{} reasoningItem := `{"id":"","type":"reasoning","encrypted_content":"","summary":[]}`
reasoningItem, _ = sjson.Set(reasoningItem, "id", fmt.Sprintf("rs_%s", rid))
if rcText != "" { if rcText != "" {
summaries = append(summaries, map[string]interface{}{ reasoningItem, _ = sjson.Set(reasoningItem, "summary.0.type", "summary_text")
"type": "summary_text", reasoningItem, _ = sjson.Set(reasoningItem, "summary.0.text", rcText)
"text": rcText,
})
} }
reasoningItem["summary"] = summaries outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", reasoningItem)
outputs = append(outputs, reasoningItem)
} }
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() { if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
@@ -716,18 +695,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
if msg.Exists() { if msg.Exists() {
// Text message part // Text message part
if c := msg.Get("content"); c.Exists() && c.String() != "" { if c := msg.Get("content"); c.Exists() && c.String() != "" {
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
"id": fmt.Sprintf("msg_%s_%d", id, int(choice.Get("index").Int())), item, _ = sjson.Set(item, "id", fmt.Sprintf("msg_%s_%d", id, int(choice.Get("index").Int())))
"type": "message", item, _ = sjson.Set(item, "content.0.text", c.String())
"status": "completed", outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": c.String(),
}},
"role": "assistant",
})
} }
// Function/tool calls // Function/tool calls
@@ -736,14 +707,12 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
callID := tc.Get("id").String() callID := tc.Get("id").String()
name := tc.Get("function.name").String() name := tc.Get("function.name").String()
args := tc.Get("function.arguments").String() args := tc.Get("function.arguments").String()
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
"id": fmt.Sprintf("fc_%s", callID), item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
"type": "function_call", item, _ = sjson.Set(item, "arguments", args)
"status": "completed", item, _ = sjson.Set(item, "call_id", callID)
"arguments": args, item, _ = sjson.Set(item, "name", name)
"call_id": callID, outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"name": name,
})
return true return true
}) })
} }
@@ -751,8 +720,8 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
return true return true
}) })
} }
if len(outputs) > 0 { if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
resp, _ = sjson.Set(resp, "output", outputs) resp, _ = sjson.SetRaw(resp, "output", gjson.Get(outputsWrapper, "arr").Raw)
} }
// usage mapping // usage mapping

View File

@@ -6,6 +6,7 @@ package usage
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -90,7 +91,7 @@ type modelStats struct {
type RequestDetail struct { type RequestDetail struct {
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
Source string `json:"source"` Source string `json:"source"`
AuthIndex uint64 `json:"auth_index"` AuthIndex string `json:"auth_index"`
Tokens TokenStats `json:"tokens"` Tokens TokenStats `json:"tokens"`
Failed bool `json:"failed"` Failed bool `json:"failed"`
} }
@@ -281,6 +282,118 @@ func (s *RequestStatistics) Snapshot() StatisticsSnapshot {
return result return result
} }
type MergeResult struct {
Added int64 `json:"added"`
Skipped int64 `json:"skipped"`
}
// MergeSnapshot merges an exported statistics snapshot into the current store.
// Existing data is preserved and duplicate request details are skipped.
func (s *RequestStatistics) MergeSnapshot(snapshot StatisticsSnapshot) MergeResult {
result := MergeResult{}
if s == nil {
return result
}
s.mu.Lock()
defer s.mu.Unlock()
seen := make(map[string]struct{})
for apiName, stats := range s.apis {
if stats == nil {
continue
}
for modelName, modelStatsValue := range stats.Models {
if modelStatsValue == nil {
continue
}
for _, detail := range modelStatsValue.Details {
seen[dedupKey(apiName, modelName, detail)] = struct{}{}
}
}
}
for apiName, apiSnapshot := range snapshot.APIs {
apiName = strings.TrimSpace(apiName)
if apiName == "" {
continue
}
stats, ok := s.apis[apiName]
if !ok || stats == nil {
stats = &apiStats{Models: make(map[string]*modelStats)}
s.apis[apiName] = stats
} else if stats.Models == nil {
stats.Models = make(map[string]*modelStats)
}
for modelName, modelSnapshot := range apiSnapshot.Models {
modelName = strings.TrimSpace(modelName)
if modelName == "" {
modelName = "unknown"
}
for _, detail := range modelSnapshot.Details {
detail.Tokens = normaliseTokenStats(detail.Tokens)
if detail.Timestamp.IsZero() {
detail.Timestamp = time.Now()
}
key := dedupKey(apiName, modelName, detail)
if _, exists := seen[key]; exists {
result.Skipped++
continue
}
seen[key] = struct{}{}
s.recordImported(apiName, modelName, stats, detail)
result.Added++
}
}
}
return result
}
func (s *RequestStatistics) recordImported(apiName, modelName string, stats *apiStats, detail RequestDetail) {
totalTokens := detail.Tokens.TotalTokens
if totalTokens < 0 {
totalTokens = 0
}
s.totalRequests++
if detail.Failed {
s.failureCount++
} else {
s.successCount++
}
s.totalTokens += totalTokens
s.updateAPIStats(stats, modelName, detail)
dayKey := detail.Timestamp.Format("2006-01-02")
hourKey := detail.Timestamp.Hour()
s.requestsByDay[dayKey]++
s.requestsByHour[hourKey]++
s.tokensByDay[dayKey] += totalTokens
s.tokensByHour[hourKey] += totalTokens
}
func dedupKey(apiName, modelName string, detail RequestDetail) string {
timestamp := detail.Timestamp.UTC().Format(time.RFC3339Nano)
tokens := normaliseTokenStats(detail.Tokens)
return fmt.Sprintf(
"%s|%s|%s|%s|%s|%t|%d|%d|%d|%d|%d",
apiName,
modelName,
timestamp,
detail.Source,
detail.AuthIndex,
detail.Failed,
tokens.InputTokens,
tokens.OutputTokens,
tokens.ReasoningTokens,
tokens.CachedTokens,
tokens.TotalTokens,
)
}
func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string { func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string {
if ctx != nil { if ctx != nil {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil { if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
@@ -340,6 +453,16 @@ func normaliseDetail(detail coreusage.Detail) TokenStats {
return tokens return tokens
} }
func normaliseTokenStats(tokens TokenStats) TokenStats {
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens
}
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens
}
return tokens
}
func formatHour(hour int) string { func formatHour(hour int) string {
if hour < 0 { if hour < 0 {
hour = 0 hour = 0

View File

@@ -344,7 +344,7 @@ func cleanupRequiredFields(jsonStr string) string {
} }
// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas. // addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas.
// Claude VALIDATED mode requires at least one property in tool schemas. // Claude VALIDATED mode requires at least one required property in tool schemas.
func addEmptySchemaPlaceholder(jsonStr string) string { func addEmptySchemaPlaceholder(jsonStr string) string {
// Find all "type" fields // Find all "type" fields
paths := findPaths(jsonStr, "type") paths := findPaths(jsonStr, "type")
@@ -364,6 +364,9 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
// Check if properties exists and is empty or missing // Check if properties exists and is empty or missing
propsPath := joinPath(parentPath, "properties") propsPath := joinPath(parentPath, "properties")
propsVal := gjson.Get(jsonStr, propsPath) propsVal := gjson.Get(jsonStr, propsPath)
reqPath := joinPath(parentPath, "required")
reqVal := gjson.Get(jsonStr, reqPath)
hasRequiredProperties := reqVal.IsArray() && len(reqVal.Array()) > 0
needsPlaceholder := false needsPlaceholder := false
if !propsVal.Exists() { if !propsVal.Exists() {
@@ -381,8 +384,17 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool") jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
// Add to required array // Add to required array
reqPath := joinPath(parentPath, "required")
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"}) jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
continue
}
// If schema has properties but none are required, add a minimal placeholder.
if propsVal.IsObject() && !hasRequiredProperties {
placeholderPath := joinPath(propsPath, "_")
if !gjson.Get(jsonStr, placeholderPath).Exists() {
jsonStr, _ = sjson.Set(jsonStr, placeholderPath+".type", "boolean")
}
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"_"})
} }
} }

View File

@@ -614,71 +614,6 @@ func TestCleanJSONSchemaForAntigravity_MultipleNonNullTypes(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval(t *testing.T) {
// propertyNames is used to validate object property names (e.g., must match a pattern)
// Gemini doesn't support this keyword and will reject requests containing it
input := `{
"type": "object",
"properties": {
"metadata": {
"type": "object",
"propertyNames": {
"pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$"
},
"additionalProperties": {
"type": "string"
}
}
}
}`
expected := `{
"type": "object",
"properties": {
"metadata": {
"type": "object"
}
}
}`
result := CleanJSONSchemaForGemini(input)
compareJSON(t, expected, result)
// Verify propertyNames is completely removed
if strings.Contains(result, "propertyNames") {
t.Errorf("propertyNames keyword should be removed, got: %s", result)
}
}
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested(t *testing.T) {
// Test deeply nested propertyNames (as seen in real Claude tool schemas)
input := `{
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object",
"properties": {
"config": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
}
}
}
}
}`
result := CleanJSONSchemaForGemini(input)
if strings.Contains(result, "propertyNames") {
t.Errorf("Nested propertyNames should be removed, got: %s", result)
}
}
func compareJSON(t *testing.T, expectedJSON, actualJSON string) { func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
var expMap, actMap map[string]interface{} var expMap, actMap map[string]interface{}
errExp := json.Unmarshal([]byte(expectedJSON), &expMap) errExp := json.Unmarshal([]byte(expectedJSON), &expMap)

View File

@@ -7,10 +7,11 @@ import (
) )
const ( const (
ThinkingBudgetMetadataKey = "thinking_budget" ThinkingBudgetMetadataKey = "thinking_budget"
ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts" ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts"
ReasoningEffortMetadataKey = "reasoning_effort" ReasoningEffortMetadataKey = "reasoning_effort"
ThinkingOriginalModelMetadataKey = "thinking_original_model" ThinkingOriginalModelMetadataKey = "thinking_original_model"
ModelMappingOriginalModelMetadataKey = "model_mapping_original_model"
) )
// NormalizeThinkingModel parses dynamic thinking suffixes on model names and returns // NormalizeThinkingModel parses dynamic thinking suffixes on model names and returns
@@ -215,6 +216,13 @@ func ResolveOriginalModel(model string, metadata map[string]any) string {
} }
if metadata != nil { if metadata != nil {
if v, ok := metadata[ModelMappingOriginalModelMetadataKey]; ok {
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
if base := normalize(s); base != "" {
return base
}
}
}
if v, ok := metadata[ThinkingOriginalModelMetadataKey]; ok { if v, ok := metadata[ThinkingOriginalModelMetadataKey]; ok {
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" { if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
if base := normalize(s); base != "" { if base := normalize(s); base != "" {

View File

@@ -6,6 +6,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"os" "os"
"reflect"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -126,7 +127,7 @@ func (w *Watcher) reloadConfig() bool {
} }
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelMappings, newConfig.OAuthModelMappings))
log.Infof("config successfully reloaded, triggering client reload") log.Infof("config successfully reloaded, triggering client reload")
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh) w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)

View File

@@ -90,6 +90,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) { if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i)) changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
} }
oldModels := SummarizeGeminiModels(o.Models)
newModels := SummarizeGeminiModels(n.Models)
if oldModels.hash != newModels.hash {
changes = append(changes, fmt.Sprintf("gemini[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
}
oldExcluded := SummarizeExcludedModels(o.ExcludedModels) oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
newExcluded := SummarizeExcludedModels(n.ExcludedModels) newExcluded := SummarizeExcludedModels(n.ExcludedModels)
if oldExcluded.hash != newExcluded.hash { if oldExcluded.hash != newExcluded.hash {
@@ -120,6 +125,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) { if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i)) changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
} }
oldModels := SummarizeClaudeModels(o.Models)
newModels := SummarizeClaudeModels(n.Models)
if oldModels.hash != newModels.hash {
changes = append(changes, fmt.Sprintf("claude[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
}
oldExcluded := SummarizeExcludedModels(o.ExcludedModels) oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
newExcluded := SummarizeExcludedModels(n.ExcludedModels) newExcluded := SummarizeExcludedModels(n.ExcludedModels)
if oldExcluded.hash != newExcluded.hash { if oldExcluded.hash != newExcluded.hash {
@@ -150,6 +160,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) { if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i)) changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
} }
oldModels := SummarizeCodexModels(o.Models)
newModels := SummarizeCodexModels(n.Models)
if oldModels.hash != newModels.hash {
changes = append(changes, fmt.Sprintf("codex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
}
oldExcluded := SummarizeExcludedModels(o.ExcludedModels) oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
newExcluded := SummarizeExcludedModels(n.ExcludedModels) newExcluded := SummarizeExcludedModels(n.ExcludedModels)
if oldExcluded.hash != newExcluded.hash { if oldExcluded.hash != newExcluded.hash {
@@ -185,10 +200,18 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings { if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings {
changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings)) changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings))
} }
oldUpstreamAPIKeysCount := len(oldCfg.AmpCode.UpstreamAPIKeys)
newUpstreamAPIKeysCount := len(newCfg.AmpCode.UpstreamAPIKeys)
if !equalUpstreamAPIKeys(oldCfg.AmpCode.UpstreamAPIKeys, newCfg.AmpCode.UpstreamAPIKeys) {
changes = append(changes, fmt.Sprintf("ampcode.upstream-api-keys: updated (%d -> %d entries)", oldUpstreamAPIKeysCount, newUpstreamAPIKeysCount))
}
if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 { if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
changes = append(changes, entries...) changes = append(changes, entries...)
} }
if entries, _ := DiffOAuthModelMappingChanges(oldCfg.OAuthModelMappings, newCfg.OAuthModelMappings); len(entries) > 0 {
changes = append(changes, entries...)
}
// Remote management (never print the key) // Remote management (never print the key)
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote { if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
@@ -301,3 +324,43 @@ func formatProxyURL(raw string) string {
} }
return scheme + "://" + host return scheme + "://" + host
} }
func equalStringSet(a, b []string) bool {
if len(a) == 0 && len(b) == 0 {
return true
}
aSet := make(map[string]struct{}, len(a))
for _, k := range a {
aSet[strings.TrimSpace(k)] = struct{}{}
}
bSet := make(map[string]struct{}, len(b))
for _, k := range b {
bSet[strings.TrimSpace(k)] = struct{}{}
}
if len(aSet) != len(bSet) {
return false
}
for k := range aSet {
if _, ok := bSet[k]; !ok {
return false
}
}
return true
}
// equalUpstreamAPIKeys compares two slices of AmpUpstreamAPIKeyEntry for equality.
// Comparison is done by count and content (upstream key and client keys).
func equalUpstreamAPIKeys(a, b []config.AmpUpstreamAPIKeyEntry) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if strings.TrimSpace(a[i].UpstreamAPIKey) != strings.TrimSpace(b[i].UpstreamAPIKey) {
return false
}
if !equalStringSet(a[i].APIKeys, b[i].APIKeys) {
return false
}
}
return true
}

View File

@@ -56,6 +56,36 @@ func ComputeClaudeModelsHash(models []config.ClaudeModel) string {
return hashJoined(keys) return hashJoined(keys)
} }
// ComputeCodexModelsHash returns a stable hash for Codex model aliases.
func ComputeCodexModelsHash(models []config.CodexModel) string {
keys := normalizeModelPairs(func(out func(key string)) {
for _, model := range models {
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if name == "" && alias == "" {
continue
}
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
}
})
return hashJoined(keys)
}
// ComputeGeminiModelsHash returns a stable hash for Gemini model aliases.
func ComputeGeminiModelsHash(models []config.GeminiModel) string {
keys := normalizeModelPairs(func(out func(key string)) {
for _, model := range models {
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if name == "" && alias == "" {
continue
}
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
}
})
return hashJoined(keys)
}
// ComputeExcludedModelsHash returns a normalized hash for excluded model lists. // ComputeExcludedModelsHash returns a normalized hash for excluded model lists.
func ComputeExcludedModelsHash(excluded []string) string { func ComputeExcludedModelsHash(excluded []string) string {
if len(excluded) == 0 { if len(excluded) == 0 {

View File

@@ -81,6 +81,15 @@ func TestComputeClaudeModelsHash_Empty(t *testing.T) {
} }
} }
func TestComputeCodexModelsHash_Empty(t *testing.T) {
if got := ComputeCodexModelsHash(nil); got != "" {
t.Fatalf("expected empty hash for nil models, got %q", got)
}
if got := ComputeCodexModelsHash([]config.CodexModel{}); got != "" {
t.Fatalf("expected empty hash for empty slice, got %q", got)
}
}
func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) { func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
a := []config.ClaudeModel{ a := []config.ClaudeModel{
{Name: "m1", Alias: "a1"}, {Name: "m1", Alias: "a1"},
@@ -95,6 +104,20 @@ func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
} }
} }
func TestComputeCodexModelsHash_IgnoresBlankAndDedup(t *testing.T) {
a := []config.CodexModel{
{Name: "m1", Alias: "a1"},
{Name: " "},
{Name: "M1", Alias: "A1"},
}
b := []config.CodexModel{
{Name: "m1", Alias: "a1"},
}
if h1, h2 := ComputeCodexModelsHash(a), ComputeCodexModelsHash(b); h1 == "" || h1 != h2 {
t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2)
}
}
func TestComputeExcludedModelsHash_Normalizes(t *testing.T) { func TestComputeExcludedModelsHash_Normalizes(t *testing.T) {
hash1 := ComputeExcludedModelsHash([]string{" A ", "b", "a"}) hash1 := ComputeExcludedModelsHash([]string{" A ", "b", "a"})
hash2 := ComputeExcludedModelsHash([]string{"a", " b", "A"}) hash2 := ComputeExcludedModelsHash([]string{"a", " b", "A"})
@@ -157,3 +180,15 @@ func TestComputeClaudeModelsHash_Deterministic(t *testing.T) {
t.Fatalf("expected different hash when models change, got %s", h3) t.Fatalf("expected different hash when models change, got %s", h3)
} }
} }
func TestComputeCodexModelsHash_Deterministic(t *testing.T) {
models := []config.CodexModel{{Name: "a", Alias: "A"}, {Name: "b"}}
h1 := ComputeCodexModelsHash(models)
h2 := ComputeCodexModelsHash(models)
if h1 == "" || h1 != h2 {
t.Fatalf("expected deterministic hash, got %s / %s", h1, h2)
}
if h3 := ComputeCodexModelsHash([]config.CodexModel{{Name: "a"}}); h3 == h1 {
t.Fatalf("expected different hash when models change, got %s", h3)
}
}

View File

@@ -0,0 +1,121 @@
package diff
import (
"crypto/sha256"
"encoding/hex"
"sort"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
type GeminiModelsSummary struct {
hash string
count int
}
type ClaudeModelsSummary struct {
hash string
count int
}
type CodexModelsSummary struct {
hash string
count int
}
type VertexModelsSummary struct {
hash string
count int
}
// SummarizeGeminiModels hashes Gemini model aliases for change detection.
func SummarizeGeminiModels(models []config.GeminiModel) GeminiModelsSummary {
if len(models) == 0 {
return GeminiModelsSummary{}
}
keys := normalizeModelPairs(func(out func(key string)) {
for _, model := range models {
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if name == "" && alias == "" {
continue
}
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
}
})
return GeminiModelsSummary{
hash: hashJoined(keys),
count: len(keys),
}
}
// SummarizeClaudeModels hashes Claude model aliases for change detection.
func SummarizeClaudeModels(models []config.ClaudeModel) ClaudeModelsSummary {
if len(models) == 0 {
return ClaudeModelsSummary{}
}
keys := normalizeModelPairs(func(out func(key string)) {
for _, model := range models {
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if name == "" && alias == "" {
continue
}
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
}
})
return ClaudeModelsSummary{
hash: hashJoined(keys),
count: len(keys),
}
}
// SummarizeCodexModels hashes Codex model aliases for change detection.
func SummarizeCodexModels(models []config.CodexModel) CodexModelsSummary {
if len(models) == 0 {
return CodexModelsSummary{}
}
keys := normalizeModelPairs(func(out func(key string)) {
for _, model := range models {
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if name == "" && alias == "" {
continue
}
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
}
})
return CodexModelsSummary{
hash: hashJoined(keys),
count: len(keys),
}
}
// SummarizeVertexModels hashes Vertex-compatible model aliases for change detection.
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
if len(models) == 0 {
return VertexModelsSummary{}
}
names := make([]string, 0, len(models))
for _, model := range models {
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if name == "" && alias == "" {
continue
}
if alias != "" {
name = alias
}
names = append(names, name)
}
if len(names) == 0 {
return VertexModelsSummary{}
}
sort.Strings(names)
sum := sha256.Sum256([]byte(strings.Join(names, "|")))
return VertexModelsSummary{
hash: hex.EncodeToString(sum[:]),
count: len(names),
}
}

View File

@@ -116,36 +116,3 @@ func SummarizeAmpModelMappings(mappings []config.AmpModelMapping) AmpModelMappin
count: len(entries), count: len(entries),
} }
} }
type VertexModelsSummary struct {
hash string
count int
}
// SummarizeVertexModels hashes vertex-compatible models for change detection.
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
if len(models) == 0 {
return VertexModelsSummary{}
}
names := make([]string, 0, len(models))
for _, m := range models {
name := strings.TrimSpace(m.Name)
alias := strings.TrimSpace(m.Alias)
if name == "" && alias == "" {
continue
}
if alias != "" {
name = alias
}
names = append(names, name)
}
if len(names) == 0 {
return VertexModelsSummary{}
}
sort.Strings(names)
sum := sha256.Sum256([]byte(strings.Join(names, "|")))
return VertexModelsSummary{
hash: hex.EncodeToString(sum[:]),
count: len(names),
}
}

View File

@@ -0,0 +1,98 @@
package diff
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"sort"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
type OAuthModelMappingsSummary struct {
hash string
count int
}
// SummarizeOAuthModelMappings summarizes OAuth model mappings per channel.
func SummarizeOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string]OAuthModelMappingsSummary {
if len(entries) == 0 {
return nil
}
out := make(map[string]OAuthModelMappingsSummary, len(entries))
for k, v := range entries {
key := strings.ToLower(strings.TrimSpace(k))
if key == "" {
continue
}
out[key] = summarizeOAuthModelMappingList(v)
}
if len(out) == 0 {
return nil
}
return out
}
// DiffOAuthModelMappingChanges compares OAuth model mappings maps.
func DiffOAuthModelMappingChanges(oldMap, newMap map[string][]config.ModelNameMapping) ([]string, []string) {
oldSummary := SummarizeOAuthModelMappings(oldMap)
newSummary := SummarizeOAuthModelMappings(newMap)
keys := make(map[string]struct{}, len(oldSummary)+len(newSummary))
for k := range oldSummary {
keys[k] = struct{}{}
}
for k := range newSummary {
keys[k] = struct{}{}
}
changes := make([]string, 0, len(keys))
affected := make([]string, 0, len(keys))
for key := range keys {
oldInfo, okOld := oldSummary[key]
newInfo, okNew := newSummary[key]
switch {
case okOld && !okNew:
changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: removed", key))
affected = append(affected, key)
case !okOld && okNew:
changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: added (%d entries)", key, newInfo.count))
affected = append(affected, key)
case okOld && okNew && oldInfo.hash != newInfo.hash:
changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count))
affected = append(affected, key)
}
}
sort.Strings(changes)
sort.Strings(affected)
return changes, affected
}
func summarizeOAuthModelMappingList(list []config.ModelNameMapping) OAuthModelMappingsSummary {
if len(list) == 0 {
return OAuthModelMappingsSummary{}
}
seen := make(map[string]struct{}, len(list))
normalized := make([]string, 0, len(list))
for _, mapping := range list {
name := strings.ToLower(strings.TrimSpace(mapping.Name))
alias := strings.ToLower(strings.TrimSpace(mapping.Alias))
if name == "" || alias == "" {
continue
}
key := name + "->" + alias
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}
normalized = append(normalized, key)
}
if len(normalized) == 0 {
return OAuthModelMappingsSummary{}
}
sort.Strings(normalized)
sum := sha256.Sum256([]byte(strings.Join(normalized, "|")))
return OAuthModelMappingsSummary{
hash: hex.EncodeToString(sum[:]),
count: len(normalized),
}
}

View File

@@ -66,6 +66,9 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea
if base != "" { if base != "" {
attrs["base_url"] = base attrs["base_url"] = base
} }
if hash := diff.ComputeGeminiModelsHash(entry.Models); hash != "" {
attrs["models_hash"] = hash
}
addConfigHeadersToAttrs(entry.Headers, attrs) addConfigHeadersToAttrs(entry.Headers, attrs)
a := &coreauth.Auth{ a := &coreauth.Auth{
ID: id, ID: id,
@@ -151,6 +154,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
if ck.BaseURL != "" { if ck.BaseURL != "" {
attrs["base_url"] = ck.BaseURL attrs["base_url"] = ck.BaseURL
} }
if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
attrs["models_hash"] = hash
}
addConfigHeadersToAttrs(ck.Headers, attrs) addConfigHeadersToAttrs(ck.Headers, attrs)
proxyURL := strings.TrimSpace(ck.ProxyURL) proxyURL := strings.TrimSpace(ck.ProxyURL)
a := &coreauth.Auth{ a := &coreauth.Auth{

View File

@@ -14,7 +14,6 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
@@ -185,14 +184,6 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
// - c: The Gin context for the request. // - c: The Gin context for the request.
// - rawJSON: The raw JSON request body. // - rawJSON: The raw JSON request body.
func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) { func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
// Set up Server-Sent Events (SSE) headers for streaming response
// These headers are essential for maintaining a persistent connection
// and enabling real-time streaming of chat completions
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response. // Get the http.Flusher interface to manually flush the response.
// This is crucial for streaming as it allows immediate sending of data chunks // This is crucial for streaming as it allows immediate sending of data chunks
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
@@ -213,58 +204,82 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) setSSEHeaders := func() {
return c.Header("Content-Type", "text/event-stream")
} c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { // Peek at the first chunk to determine success or failure before setting headers
// OpenAI-style stream forwarding: write each SSE chunk and flush immediately.
// This guarantees clients see incremental output even for small responses.
for { for {
select { select {
case <-c.Request.Context().Done(): case <-c.Request.Context().Done():
cancel(c.Request.Context().Err()) cliCancel(c.Request.Context().Err())
return return
case errMsg, ok := <-errChan:
case chunk, ok := <-data:
if !ok { if !ok {
// Err channel closed cleanly; wait for data channel.
errChan = nil
continue
}
// Upstream failed immediately. Return proper error status and JSON.
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
case chunk, ok := <-dataChan:
if !ok {
// Stream closed without data? Send DONE or just headers.
setSSEHeaders()
flusher.Flush() flusher.Flush()
cancel(nil) cliCancel(nil)
return return
} }
// Success! Set headers now.
setSSEHeaders()
// Write the first chunk
if len(chunk) > 0 { if len(chunk) > 0 {
_, _ = c.Writer.Write(chunk) _, _ = c.Writer.Write(chunk)
flusher.Flush() flusher.Flush()
} }
case errMsg, ok := <-errs: // Continue streaming the rest
if !ok { h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
continue
}
if errMsg != nil {
status := http.StatusInternalServerError
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
}
c.Status(status)
// An error occurred: emit as a proper SSE error event
errorBytes, _ := json.Marshal(h.toClaudeError(errMsg))
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes)
flusher.Flush()
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cancel(execErr)
return return
case <-time.After(500 * time.Millisecond):
} }
} }
} }
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
WriteChunk: func(chunk []byte) {
if len(chunk) == 0 {
return
}
_, _ = c.Writer.Write(chunk)
},
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
if errMsg == nil {
return
}
status := http.StatusInternalServerError
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
}
c.Status(status)
errorBytes, _ := json.Marshal(h.toClaudeError(errMsg))
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes)
},
})
}
type claudeErrorDetail struct { type claudeErrorDetail struct {
Type string `json:"type"` Type string `json:"type"`
Message string `json:"message"` Message string `json:"message"`

View File

@@ -182,19 +182,18 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ
} }
func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
for { var keepAliveInterval *time.Duration
select { if alt != "" {
case <-c.Request.Context().Done(): disabled := time.Duration(0)
cancel(c.Request.Context().Err()) keepAliveInterval = &disabled
return }
case chunk, ok := <-data:
if !ok { h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
cancel(nil) KeepAliveInterval: keepAliveInterval,
return WriteChunk: func(chunk []byte) {
}
if alt == "" { if alt == "" {
if bytes.Equal(chunk, []byte("data: [DONE]")) || bytes.Equal(chunk, []byte("[DONE]")) { if bytes.Equal(chunk, []byte("data: [DONE]")) || bytes.Equal(chunk, []byte("[DONE]")) {
continue return
} }
if !bytes.HasPrefix(chunk, []byte("data:")) { if !bytes.HasPrefix(chunk, []byte("data:")) {
@@ -206,22 +205,25 @@ func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flus
} else { } else {
_, _ = c.Writer.Write(chunk) _, _ = c.Writer.Write(chunk)
} }
flusher.Flush() },
case errMsg, ok := <-errs: WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
if !ok { if errMsg == nil {
continue return
} }
if errMsg != nil { status := http.StatusInternalServerError
h.WriteErrorResponse(c, errMsg) if errMsg.StatusCode > 0 {
flusher.Flush() status = errMsg.StatusCode
} }
var execErr error errText := http.StatusText(status)
if errMsg != nil { if errMsg.Error != nil && errMsg.Error.Error() != "" {
execErr = errMsg.Error errText = errMsg.Error.Error()
} }
cancel(execErr) body := handlers.BuildErrorResponseBody(status, errText)
return if alt == "" {
case <-time.After(500 * time.Millisecond): _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
} } else {
} _, _ = c.Writer.Write(body)
}
},
})
} }

View File

@@ -226,13 +226,6 @@ func (h *GeminiAPIHandler) GeminiHandler(c *gin.Context) {
func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName string, rawJSON []byte) { func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName string, rawJSON []byte) {
alt := h.GetAlt(c) alt := h.GetAlt(c)
if alt == "" {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Get the http.Flusher interface to manually flush the response. // Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if !ok { if !ok {
@@ -247,8 +240,65 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan)
return setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Peek at the first chunk
for {
select {
case <-c.Request.Context().Done():
cliCancel(c.Request.Context().Err())
return
case errMsg, ok := <-errChan:
if !ok {
// Err channel closed cleanly; wait for data channel.
errChan = nil
continue
}
// Upstream failed immediately. Return proper error status and JSON.
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
case chunk, ok := <-dataChan:
if !ok {
// Closed without data
if alt == "" {
setSSEHeaders()
}
flusher.Flush()
cliCancel(nil)
return
}
// Success! Set headers.
if alt == "" {
setSSEHeaders()
}
// Write first chunk
if alt == "" {
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
} else {
_, _ = c.Writer.Write(chunk)
}
flusher.Flush()
// Continue
h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan)
return
}
}
} }
// handleCountTokens handles token counting requests for Gemini models. // handleCountTokens handles token counting requests for Gemini models.
@@ -297,16 +347,15 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
} }
func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
for { var keepAliveInterval *time.Duration
select { if alt != "" {
case <-c.Request.Context().Done(): disabled := time.Duration(0)
cancel(c.Request.Context().Err()) keepAliveInterval = &disabled
return }
case chunk, ok := <-data:
if !ok { h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
cancel(nil) KeepAliveInterval: keepAliveInterval,
return WriteChunk: func(chunk []byte) {
}
if alt == "" { if alt == "" {
_, _ = c.Writer.Write([]byte("data: ")) _, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk) _, _ = c.Writer.Write(chunk)
@@ -314,22 +363,25 @@ func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flus
} else { } else {
_, _ = c.Writer.Write(chunk) _, _ = c.Writer.Write(chunk)
} }
flusher.Flush() },
case errMsg, ok := <-errs: WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
if !ok { if errMsg == nil {
continue return
} }
if errMsg != nil { status := http.StatusInternalServerError
h.WriteErrorResponse(c, errMsg) if errMsg.StatusCode > 0 {
flusher.Flush() status = errMsg.StatusCode
} }
var execErr error errText := http.StatusText(status)
if errMsg != nil { if errMsg.Error != nil && errMsg.Error.Error() != "" {
execErr = errMsg.Error errText = errMsg.Error.Error()
} }
cancel(execErr) body := handlers.BuildErrorResponseBody(status, errText)
return if alt == "" {
case <-time.After(500 * time.Millisecond): _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
} } else {
} _, _ = c.Writer.Write(body)
}
},
})
} }

View File

@@ -9,9 +9,12 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -40,6 +43,117 @@ type ErrorDetail struct {
Code string `json:"code,omitempty"` Code string `json:"code,omitempty"`
} }
const idempotencyKeyMetadataKey = "idempotency_key"
const (
defaultStreamingKeepAliveSeconds = 0
defaultStreamingBootstrapRetries = 0
)
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
func BuildErrorResponseBody(status int, errText string) []byte {
if status <= 0 {
status = http.StatusInternalServerError
}
if strings.TrimSpace(errText) == "" {
errText = http.StatusText(status)
}
trimmed := strings.TrimSpace(errText)
if trimmed != "" && json.Valid([]byte(trimmed)) {
return []byte(trimmed)
}
errType := "invalid_request_error"
var code string
switch status {
case http.StatusUnauthorized:
errType = "authentication_error"
code = "invalid_api_key"
case http.StatusForbidden:
errType = "permission_error"
code = "insufficient_quota"
case http.StatusTooManyRequests:
errType = "rate_limit_error"
code = "rate_limit_exceeded"
case http.StatusNotFound:
errType = "invalid_request_error"
code = "model_not_found"
default:
if status >= http.StatusInternalServerError {
errType = "server_error"
code = "internal_server_error"
}
}
payload, err := json.Marshal(ErrorResponse{
Error: ErrorDetail{
Message: errText,
Type: errType,
Code: code,
},
})
if err != nil {
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error","code":"internal_server_error"}}`, errText))
}
return payload
}
// StreamingKeepAliveInterval returns the SSE keep-alive interval for this server.
// Returning 0 disables keep-alives (default when unset).
func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
seconds := defaultStreamingKeepAliveSeconds
if cfg != nil {
seconds = cfg.Streaming.KeepAliveSeconds
}
if seconds <= 0 {
return 0
}
return time.Duration(seconds) * time.Second
}
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
retries := defaultStreamingBootstrapRetries
if cfg != nil {
retries = cfg.Streaming.BootstrapRetries
}
if retries < 0 {
retries = 0
}
return retries
}
func requestExecutionMetadata(ctx context.Context) map[string]any {
// Idempotency-Key is an optional client-supplied header used to correlate retries.
// It is forwarded as execution metadata; when absent we generate a UUID.
key := ""
if ctx != nil {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
key = strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key"))
}
}
if key == "" {
key = uuid.NewString()
}
return map[string]any{idempotencyKeyMetadataKey: key}
}
func mergeMetadata(base, overlay map[string]any) map[string]any {
if len(base) == 0 && len(overlay) == 0 {
return nil
}
out := make(map[string]any, len(base)+len(overlay))
for k, v := range base {
out[k] = v
}
for k, v := range overlay {
out[k] = v
}
return out
}
// BaseAPIHandler contains the handlers for API endpoints. // BaseAPIHandler contains the handlers for API endpoints.
// It holds a pool of clients to interact with the backend service and manages // It holds a pool of clients to interact with the backend service and manages
// load balancing, client selection, and configuration. // load balancing, client selection, and configuration.
@@ -104,13 +218,39 @@ func (h *BaseAPIHandler) GetAlt(c *gin.Context) string {
// Parameters: // Parameters:
// - handler: The API handler associated with the request. // - handler: The API handler associated with the request.
// - c: The Gin context of the current request. // - c: The Gin context of the current request.
// - ctx: The parent context. // - ctx: The parent context (caller values/deadlines are preserved; request context adds cancellation and request ID).
// //
// Returns: // Returns:
// - context.Context: The new context with cancellation and embedded values. // - context.Context: The new context with cancellation and embedded values.
// - APIHandlerCancelFunc: A function to cancel the context and log the response. // - APIHandlerCancelFunc: A function to cancel the context and log the response.
func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *gin.Context, ctx context.Context) (context.Context, APIHandlerCancelFunc) { func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *gin.Context, ctx context.Context) (context.Context, APIHandlerCancelFunc) {
newCtx, cancel := context.WithCancel(ctx) parentCtx := ctx
if parentCtx == nil {
parentCtx = context.Background()
}
var requestCtx context.Context
if c != nil && c.Request != nil {
requestCtx = c.Request.Context()
}
if requestCtx != nil && logging.GetRequestID(parentCtx) == "" {
if requestID := logging.GetRequestID(requestCtx); requestID != "" {
parentCtx = logging.WithRequestID(parentCtx, requestID)
} else if requestID := logging.GetGinRequestID(c); requestID != "" {
parentCtx = logging.WithRequestID(parentCtx, requestID)
}
}
newCtx, cancel := context.WithCancel(parentCtx)
if requestCtx != nil && requestCtx != parentCtx {
go func() {
select {
case <-requestCtx.Done():
cancel()
case <-newCtx.Done():
}
}()
}
newCtx = context.WithValue(newCtx, "gin", c) newCtx = context.WithValue(newCtx, "gin", c)
newCtx = context.WithValue(newCtx, "handler", handler) newCtx = context.WithValue(newCtx, "handler", handler)
return newCtx, func(params ...interface{}) { return newCtx, func(params ...interface{}) {
@@ -183,6 +323,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
if errMsg != nil { if errMsg != nil {
return nil, errMsg return nil, errMsg
} }
reqMeta := requestExecutionMetadata(ctx)
req := coreexecutor.Request{ req := coreexecutor.Request{
Model: normalizedModel, Model: normalizedModel,
Payload: cloneBytes(rawJSON), Payload: cloneBytes(rawJSON),
@@ -196,9 +337,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
OriginalRequest: cloneBytes(rawJSON), OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType), SourceFormat: sdktranslator.FromString(handlerType),
} }
if cloned := cloneMetadata(metadata); cloned != nil { opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
opts.Metadata = cloned
}
resp, err := h.AuthManager.Execute(ctx, providers, req, opts) resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
if err != nil { if err != nil {
status := http.StatusInternalServerError status := http.StatusInternalServerError
@@ -225,6 +364,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
if errMsg != nil { if errMsg != nil {
return nil, errMsg return nil, errMsg
} }
reqMeta := requestExecutionMetadata(ctx)
req := coreexecutor.Request{ req := coreexecutor.Request{
Model: normalizedModel, Model: normalizedModel,
Payload: cloneBytes(rawJSON), Payload: cloneBytes(rawJSON),
@@ -238,9 +378,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
OriginalRequest: cloneBytes(rawJSON), OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType), SourceFormat: sdktranslator.FromString(handlerType),
} }
if cloned := cloneMetadata(metadata); cloned != nil { opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
opts.Metadata = cloned
}
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts) resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
if err != nil { if err != nil {
status := http.StatusInternalServerError status := http.StatusInternalServerError
@@ -270,6 +408,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
close(errChan) close(errChan)
return nil, errChan return nil, errChan
} }
reqMeta := requestExecutionMetadata(ctx)
req := coreexecutor.Request{ req := coreexecutor.Request{
Model: normalizedModel, Model: normalizedModel,
Payload: cloneBytes(rawJSON), Payload: cloneBytes(rawJSON),
@@ -283,9 +422,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
OriginalRequest: cloneBytes(rawJSON), OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType), SourceFormat: sdktranslator.FromString(handlerType),
} }
if cloned := cloneMetadata(metadata); cloned != nil { opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
opts.Metadata = cloned
}
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if err != nil { if err != nil {
errChan := make(chan *interfaces.ErrorMessage, 1) errChan := make(chan *interfaces.ErrorMessage, 1)
@@ -310,31 +447,94 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
go func() { go func() {
defer close(dataChan) defer close(dataChan)
defer close(errChan) defer close(errChan)
for chunk := range chunks { sentPayload := false
if chunk.Err != nil { bootstrapRetries := 0
status := http.StatusInternalServerError maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg)
if se, ok := chunk.Err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 { bootstrapEligible := func(err error) bool {
status = code status := statusFromError(err)
} if status == 0 {
} return true
var addon http.Header
if he, ok := chunk.Err.(interface{ Headers() http.Header }); ok && he != nil {
if hdr := he.Headers(); hdr != nil {
addon = hdr.Clone()
}
}
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: chunk.Err, Addon: addon}
return
} }
if len(chunk.Payload) > 0 { switch status {
dataChan <- cloneBytes(chunk.Payload) case http.StatusUnauthorized, http.StatusForbidden, http.StatusPaymentRequired,
http.StatusRequestTimeout, http.StatusTooManyRequests:
return true
default:
return status >= http.StatusInternalServerError
}
}
outer:
for {
for {
var chunk coreexecutor.StreamChunk
var ok bool
if ctx != nil {
select {
case <-ctx.Done():
return
case chunk, ok = <-chunks:
}
} else {
chunk, ok = <-chunks
}
if !ok {
return
}
if chunk.Err != nil {
streamErr := chunk.Err
// Safe bootstrap recovery: if the upstream fails before any payload bytes are sent,
// retry a few times (to allow auth rotation / transient recovery) and then attempt model fallback.
if !sentPayload {
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
bootstrapRetries++
retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if retryErr == nil {
chunks = retryChunks
continue outer
}
streamErr = retryErr
}
}
status := http.StatusInternalServerError
if se, ok := streamErr.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
status = code
}
}
var addon http.Header
if he, ok := streamErr.(interface{ Headers() http.Header }); ok && he != nil {
if hdr := he.Headers(); hdr != nil {
addon = hdr.Clone()
}
}
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon}
return
}
if len(chunk.Payload) > 0 {
sentPayload = true
dataChan <- cloneBytes(chunk.Payload)
}
} }
} }
}() }()
return dataChan, errChan return dataChan, errChan
} }
func statusFromError(err error) int {
if err == nil {
return 0
}
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
return code
}
}
return 0
}
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) { func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
// Resolve "auto" model to an actual available model first // Resolve "auto" model to an actual available model first
resolvedModelName := util.ResolveAutoModel(modelName) resolvedModelName := util.ResolveAutoModel(modelName)
@@ -418,39 +618,23 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
} }
} }
// Prefer preserving upstream JSON error bodies when possible. body := BuildErrorResponseBody(status, errText)
buildJSONBody := func() []byte { // Append first to preserve upstream response logs, then drop duplicate payloads if already recorded.
trimmed := strings.TrimSpace(errText) var previous []byte
if trimmed != "" && json.Valid([]byte(trimmed)) { if existing, exists := c.Get("API_RESPONSE"); exists {
return []byte(trimmed) if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
previous = bytes.Clone(existingBytes)
}
}
appendAPIResponse(c, body)
trimmedErrText := strings.TrimSpace(errText)
trimmedBody := bytes.TrimSpace(body)
if len(previous) > 0 {
if (trimmedErrText != "" && bytes.Contains(previous, []byte(trimmedErrText))) ||
(len(trimmedBody) > 0 && bytes.Contains(previous, trimmedBody)) {
c.Set("API_RESPONSE", previous)
} }
errType := "invalid_request_error"
switch status {
case http.StatusUnauthorized:
errType = "authentication_error"
case http.StatusForbidden:
errType = "permission_error"
case http.StatusTooManyRequests:
errType = "rate_limit_error"
default:
if status >= http.StatusInternalServerError {
errType = "server_error"
}
}
payload, err := json.Marshal(ErrorResponse{
Error: ErrorDetail{
Message: errText,
Type: errType,
},
})
if err != nil {
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error"}}`, errText))
}
return payload
} }
body := buildJSONBody()
c.Set("API_RESPONSE", bytes.Clone(body))
if !c.Writer.Written() { if !c.Writer.Written() {
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")

View File

@@ -0,0 +1,124 @@
package handlers
import (
"context"
"net/http"
"sync"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
type failOnceStreamExecutor struct {
mu sync.Mutex
calls int
}
func (e *failOnceStreamExecutor) Identifier() string { return "codex" }
func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
e.mu.Lock()
e.calls++
call := e.calls
e.mu.Unlock()
ch := make(chan coreexecutor.StreamChunk, 1)
if call == 1 {
ch <- coreexecutor.StreamChunk{
Err: &coreauth.Error{
Code: "unauthorized",
Message: "unauthorized",
Retryable: false,
HTTPStatus: http.StatusUnauthorized,
},
}
close(ch)
return ch, nil
}
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
close(ch)
return ch, nil
}
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *failOnceStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
}
func (e *failOnceStreamExecutor) Calls() int {
e.mu.Lock()
defer e.mu.Unlock()
return e.calls
}
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
executor := &failOnceStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth1 := &coreauth.Auth{
ID: "auth1",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test1@example.com"},
}
if _, err := manager.Register(context.Background(), auth1); err != nil {
t.Fatalf("manager.Register(auth1): %v", err)
}
auth2 := &coreauth.Auth{
ID: "auth2",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test2@example.com"},
}
if _, err := manager.Register(context.Background(), auth2); err != nil {
t.Fatalf("manager.Register(auth2): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: 1,
},
}, manager)
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
for msg := range errChan {
if msg != nil {
t.Fatalf("unexpected error: %+v", msg)
}
}
if string(got) != "ok" {
t.Fatalf("expected payload ok, got %q", string(got))
}
if executor.Calls() != 2 {
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
}
}

View File

@@ -11,12 +11,13 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"time" "sync"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
responsesconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/responses"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
@@ -109,7 +110,17 @@ func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
// Check if the client requested a streaming response. // Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJSON, "stream") streamResult := gjson.GetBytes(rawJSON, "stream")
if streamResult.Type == gjson.True { stream := streamResult.Type == gjson.True
// Some clients send OpenAI Responses-format payloads to /v1/chat/completions.
// Convert them to Chat Completions so downstream translators preserve tool metadata.
if shouldTreatAsResponsesFormat(rawJSON) {
modelName := gjson.GetBytes(rawJSON, "model").String()
rawJSON = responsesconverter.ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName, rawJSON, stream)
stream = gjson.GetBytes(rawJSON, "stream").Bool()
}
if stream {
h.handleStreamingResponse(c, rawJSON) h.handleStreamingResponse(c, rawJSON)
} else { } else {
h.handleNonStreamingResponse(c, rawJSON) h.handleNonStreamingResponse(c, rawJSON)
@@ -117,6 +128,21 @@ func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
} }
// shouldTreatAsResponsesFormat detects OpenAI Responses-style payloads that are
// accidentally sent to the Chat Completions endpoint.
func shouldTreatAsResponsesFormat(rawJSON []byte) bool {
if gjson.GetBytes(rawJSON, "messages").Exists() {
return false
}
if gjson.GetBytes(rawJSON, "input").Exists() {
return true
}
if gjson.GetBytes(rawJSON, "instructions").Exists() {
return true
}
return false
}
// Completions handles the /v1/completions endpoint. // Completions handles the /v1/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response // It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler based on the model provider. // and calls the appropriate handler based on the model provider.
@@ -417,11 +443,6 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
// - c: The Gin context containing the HTTP request and response // - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request // - rawJSON: The raw JSON bytes of the OpenAI-compatible request
func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) { func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response. // Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if !ok { if !ok {
@@ -437,7 +458,55 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
modelName := gjson.GetBytes(rawJSON, "model").String() modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c)) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Peek at the first chunk to determine success or failure before setting headers
for {
select {
case <-c.Request.Context().Done():
cliCancel(c.Request.Context().Err())
return
case errMsg, ok := <-errChan:
if !ok {
// Err channel closed cleanly; wait for data channel.
errChan = nil
continue
}
// Upstream failed immediately. Return proper error status and JSON.
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
case chunk, ok := <-dataChan:
if !ok {
// Stream closed without data? Send DONE or just headers.
setSSEHeaders()
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel(nil)
return
}
// Success! Commit to streaming headers.
setSSEHeaders()
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
flusher.Flush()
// Continue streaming the rest
h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
return
}
}
} }
// handleCompletionsNonStreamingResponse handles non-streaming completions responses. // handleCompletionsNonStreamingResponse handles non-streaming completions responses.
@@ -474,11 +543,6 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context,
// - c: The Gin context containing the HTTP request and response // - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request // - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) { func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response. // Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if !ok { if !ok {
@@ -498,71 +562,109 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "") dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Peek at the first chunk
for { for {
select { select {
case <-c.Request.Context().Done(): case <-c.Request.Context().Done():
cliCancel(c.Request.Context().Err()) cliCancel(c.Request.Context().Err())
return return
case chunk, isOk := <-dataChan: case errMsg, ok := <-errChan:
if !isOk { if !ok {
// Err channel closed cleanly; wait for data channel.
errChan = nil
continue
}
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
case chunk, ok := <-dataChan:
if !ok {
setSSEHeaders()
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush() flusher.Flush()
cliCancel() cliCancel(nil)
return return
} }
// Success! Set headers.
setSSEHeaders()
// Write the first chunk
converted := convertChatCompletionsStreamChunkToCompletions(chunk) converted := convertChatCompletionsStreamChunkToCompletions(chunk)
if converted != nil { if converted != nil {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted)) _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted))
flusher.Flush() flusher.Flush()
} }
case errMsg, isOk := <-errChan:
if !isOk { done := make(chan struct{})
continue var doneOnce sync.Once
} stop := func() { doneOnce.Do(func() { close(done) }) }
if errMsg != nil {
h.WriteErrorResponse(c, errMsg) convertedChan := make(chan []byte)
flusher.Flush() go func() {
} defer close(convertedChan)
var execErr error for {
if errMsg != nil { select {
execErr = errMsg.Error case <-done:
} return
cliCancel(execErr) case chunk, ok := <-dataChan:
if !ok {
return
}
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
if converted == nil {
continue
}
select {
case <-done:
return
case convertedChan <- converted:
}
}
}
}()
h.handleStreamResult(c, flusher, func(err error) {
stop()
cliCancel(err)
}, convertedChan, errChan)
return return
case <-time.After(500 * time.Millisecond):
} }
} }
} }
func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
for { h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
select { WriteChunk: func(chunk []byte) {
case <-c.Request.Context().Done(): _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
cancel(c.Request.Context().Err()) },
return WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
case chunk, ok := <-data: if errMsg == nil {
if !ok {
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cancel(nil)
return return
} }
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk)) status := http.StatusInternalServerError
flusher.Flush() if errMsg.StatusCode > 0 {
case errMsg, ok := <-errs: status = errMsg.StatusCode
if !ok {
continue
} }
if errMsg != nil { errText := http.StatusText(status)
h.WriteErrorResponse(c, errMsg) if errMsg.Error != nil && errMsg.Error.Error() != "" {
flusher.Flush() errText = errMsg.Error.Error()
} }
var execErr error body := handlers.BuildErrorResponseBody(status, errText)
if errMsg != nil { _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(body))
execErr = errMsg.Error },
} WriteDone: func() {
cancel(execErr) _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
return },
case <-time.After(500 * time.Millisecond): })
}
}
} }

View File

@@ -11,7 +11,6 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
@@ -128,11 +127,6 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
// - c: The Gin context containing the HTTP request and response // - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAIResponses-compatible request // - rawJSON: The raw JSON bytes of the OpenAIResponses-compatible request
func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) { func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response. // Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if !ok { if !ok {
@@ -149,46 +143,88 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
modelName := gjson.GetBytes(rawJSON, "model").String() modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
return
}
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Peek at the first chunk
for { for {
select { select {
case <-c.Request.Context().Done(): case <-c.Request.Context().Done():
cancel(c.Request.Context().Err()) cliCancel(c.Request.Context().Err())
return return
case chunk, ok := <-data: case errMsg, ok := <-errChan:
if !ok { if !ok {
// Err channel closed cleanly; wait for data channel.
errChan = nil
continue
}
// Upstream failed immediately. Return proper error status and JSON.
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
case chunk, ok := <-dataChan:
if !ok {
// Stream closed without data? Send headers and done.
setSSEHeaders()
_, _ = c.Writer.Write([]byte("\n")) _, _ = c.Writer.Write([]byte("\n"))
flusher.Flush() flusher.Flush()
cancel(nil) cliCancel(nil)
return return
} }
// Success! Set headers.
setSSEHeaders()
// Write first chunk logic (matching forwardResponsesStream)
if bytes.HasPrefix(chunk, []byte("event:")) { if bytes.HasPrefix(chunk, []byte("event:")) {
_, _ = c.Writer.Write([]byte("\n")) _, _ = c.Writer.Write([]byte("\n"))
} }
_, _ = c.Writer.Write(chunk) _, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n")) _, _ = c.Writer.Write([]byte("\n"))
flusher.Flush() flusher.Flush()
case errMsg, ok := <-errs:
if !ok { // Continue
continue h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
}
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
flusher.Flush()
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cancel(execErr)
return return
case <-time.After(500 * time.Millisecond):
} }
} }
} }
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
WriteChunk: func(chunk []byte) {
if bytes.HasPrefix(chunk, []byte("event:")) {
_, _ = c.Writer.Write([]byte("\n"))
}
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n"))
},
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
if errMsg == nil {
return
}
status := http.StatusInternalServerError
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
}
errText := http.StatusText(status)
if errMsg.Error != nil && errMsg.Error.Error() != "" {
errText = errMsg.Error.Error()
}
body := handlers.BuildErrorResponseBody(status, errText)
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
},
WriteDone: func() {
_, _ = c.Writer.Write([]byte("\n"))
},
})
}

View File

@@ -0,0 +1,121 @@
package handlers
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
)
type StreamForwardOptions struct {
// KeepAliveInterval overrides the configured streaming keep-alive interval.
// If nil, the configured default is used. If set to <= 0, keep-alives are disabled.
KeepAliveInterval *time.Duration
// WriteChunk writes a single data chunk to the response body. It should not flush.
WriteChunk func(chunk []byte)
// WriteTerminalError writes an error payload to the response body when streaming fails
// after headers have already been committed. It should not flush.
WriteTerminalError func(errMsg *interfaces.ErrorMessage)
// WriteDone optionally writes a terminal marker when the upstream data channel closes
// without an error (e.g. OpenAI's `[DONE]`). It should not flush.
WriteDone func()
// WriteKeepAlive optionally writes a keep-alive heartbeat. It should not flush.
// When nil, a standard SSE comment heartbeat is used.
WriteKeepAlive func()
}
func (h *BaseAPIHandler) ForwardStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, opts StreamForwardOptions) {
if c == nil {
return
}
if cancel == nil {
return
}
writeChunk := opts.WriteChunk
if writeChunk == nil {
writeChunk = func([]byte) {}
}
writeKeepAlive := opts.WriteKeepAlive
if writeKeepAlive == nil {
writeKeepAlive = func() {
_, _ = c.Writer.Write([]byte(": keep-alive\n\n"))
}
}
keepAliveInterval := StreamingKeepAliveInterval(h.Cfg)
if opts.KeepAliveInterval != nil {
keepAliveInterval = *opts.KeepAliveInterval
}
var keepAlive *time.Ticker
var keepAliveC <-chan time.Time
if keepAliveInterval > 0 {
keepAlive = time.NewTicker(keepAliveInterval)
defer keepAlive.Stop()
keepAliveC = keepAlive.C
}
var terminalErr *interfaces.ErrorMessage
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return
case chunk, ok := <-data:
if !ok {
// Prefer surfacing a terminal error if one is pending.
if terminalErr == nil {
select {
case errMsg, ok := <-errs:
if ok && errMsg != nil {
terminalErr = errMsg
}
default:
}
}
if terminalErr != nil {
if opts.WriteTerminalError != nil {
opts.WriteTerminalError(terminalErr)
}
flusher.Flush()
cancel(terminalErr.Error)
return
}
if opts.WriteDone != nil {
opts.WriteDone()
}
flusher.Flush()
cancel(nil)
return
}
writeChunk(chunk)
flusher.Flush()
case errMsg, ok := <-errs:
if !ok {
continue
}
if errMsg != nil {
terminalErr = errMsg
if opts.WriteTerminalError != nil {
opts.WriteTerminalError(errMsg)
flusher.Flush()
}
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cancel(execErr)
return
case <-keepAliveC:
writeKeepAlive()
flusher.Flush()
}
}
}

67
sdk/api/management.go Normal file
View File

@@ -0,0 +1,67 @@
// Package api exposes helpers for embedding CLIProxyAPI.
//
// It wraps internal management handler types so external projects can integrate
// management endpoints without importing internal packages.
package api
import (
"github.com/gin-gonic/gin"
internalmanagement "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// ManagementTokenRequester exposes a limited subset of management endpoints for requesting tokens.
type ManagementTokenRequester interface {
RequestAnthropicToken(*gin.Context)
RequestGeminiCLIToken(*gin.Context)
RequestCodexToken(*gin.Context)
RequestAntigravityToken(*gin.Context)
RequestQwenToken(*gin.Context)
RequestIFlowToken(*gin.Context)
RequestIFlowCookieToken(*gin.Context)
GetAuthStatus(c *gin.Context)
}
type managementTokenRequester struct {
handler *internalmanagement.Handler
}
// NewManagementTokenRequester creates a limited management handler exposing only token request endpoints.
func NewManagementTokenRequester(cfg *config.Config, manager *coreauth.Manager) ManagementTokenRequester {
return &managementTokenRequester{
handler: internalmanagement.NewHandlerWithoutConfigFilePath(cfg, manager),
}
}
func (m *managementTokenRequester) RequestAnthropicToken(c *gin.Context) {
m.handler.RequestAnthropicToken(c)
}
func (m *managementTokenRequester) RequestGeminiCLIToken(c *gin.Context) {
m.handler.RequestGeminiCLIToken(c)
}
func (m *managementTokenRequester) RequestCodexToken(c *gin.Context) {
m.handler.RequestCodexToken(c)
}
func (m *managementTokenRequester) RequestAntigravityToken(c *gin.Context) {
m.handler.RequestAntigravityToken(c)
}
func (m *managementTokenRequester) RequestQwenToken(c *gin.Context) {
m.handler.RequestQwenToken(c)
}
func (m *managementTokenRequester) RequestIFlowToken(c *gin.Context) {
m.handler.RequestIFlowToken(c)
}
func (m *managementTokenRequester) RequestIFlowCookieToken(c *gin.Context) {
m.handler.RequestIFlowCookieToken(c)
}
func (m *managementTokenRequester) GetAuthStatus(c *gin.Context) {
m.handler.GetAuthStatus(c)
}

Some files were not shown because too many files have changed in this diff Show More