Compare commits

...

134 Commits

Author SHA1 Message Date
Luis Pater
9116392a45 Merge branch 'router-for-me:main' into main 2026-01-02 20:49:51 +08:00
Luis Pater
c8b33a8cc3 Merge pull request #824 from router-for-me/script
feat(script): add usage statistics preservation across container rebuilds
2026-01-02 20:42:25 +08:00
Luis Pater
b9d1e70ac2 Merge pull request #830 from router-for-me/gemini
fix(util): disable default thinking for gemini-3 series
2026-01-02 10:59:24 +08:00
hkfires
fdf5720217 fix(gemini): remove default thinking for gemini 3 models 2026-01-02 10:55:59 +08:00
hkfires
f40bd0cd51 feat(script): add usage statistics preservation across container rebuilds 2026-01-02 10:01:20 +08:00
hkfires
e33676bb87 fix(util): disable default thinking for gemini-3 series 2026-01-02 09:43:40 +08:00
Luis Pater
b1f1cee1e5 feat(executor): refine payload handling by integrating original request context
Updated `applyPayloadConfig` to `applyPayloadConfigWithRoot` across payload translation logic, enabling validation against the original request payload when available. Added support for improved model normalization and translation consistency.
2026-01-02 03:28:37 +08:00
Luis Pater
a1ecc9ab00 Merge branch 'router-for-me:main' into main 2026-01-02 00:04:50 +08:00
Luis Pater
2a663d5cba feat(executor): enhance payload translation with original request context
Refactored `applyPayloadConfig` to `applyPayloadConfigWithRoot`, adding support for default rule validation against the original payload when available. Updated all executors to use `applyPayloadConfigWithRoot` and incorporate an optional original request payload for translations.
2026-01-02 00:03:26 +08:00
Luis Pater
ba486ca6b7 Merge branch 'router-for-me:main' into main 2026-01-01 21:03:16 +08:00
Luis Pater
750b930679 Merge pull request #823 from router-for-me/translator
feat(translator): enhance Claude-to-OpenAI conversion with thinking block and tool result handling
2026-01-01 20:16:10 +08:00
hkfires
3902fd7501 fix(iflow): remove thinking field from request body in thinking config handler 2026-01-01 19:40:28 +08:00
hkfires
4fc3d5e935 refactor(iflow): simplify thinking config handling for GLM and MiniMax models 2026-01-01 19:31:08 +08:00
hkfires
2d2f4572a7 fix(translator): remove unnecessary whitespace trimming in reasoning text collection 2026-01-01 12:39:09 +08:00
hkfires
8f4c46f38d fix(translator): emit tool_result messages before user content in Claude-to-OpenAI conversion 2026-01-01 11:11:43 +08:00
hkfires
b6ba51bc2a feat(translator): add thinking block and tool result handling for Claude-to-OpenAI conversion 2026-01-01 09:41:25 +08:00
Luis Pater
6a66d32d37 Merge pull request #803 from HsnSaboor/fix-invalid-function-names-sanitization-v2
feat(translator): resolve invalid function name errors by sanitizing Claude tool names
2026-01-01 01:15:50 +08:00
Luis Pater
8d15723195 feat(registry): add GetAvailableModelsByProvider method for retrieving models by provider 2025-12-31 23:37:46 +08:00
Chén Mù
736e0aae86 Merge pull request #814 from router-for-me/aistudio
Fix model alias thinking suffix
2025-12-31 03:08:05 -08:00
hkfires
8bf3305b2b fix(thinking): fallback to upstream model for thinking support when alias not in registry 2025-12-31 18:07:13 +08:00
hkfires
d00e3ea973 feat(thinking): add numeric budget to thinkingLevel conversion fallback 2025-12-31 17:14:47 +08:00
hkfires
89db4e9481 fix(thinking): use model alias for thinking config resolution in mapped models 2025-12-31 17:09:22 +08:00
hkfires
e332419081 feat(registry): add thinking support for gemini-2.5-computer-use-preview model 2025-12-31 17:09:22 +08:00
Luis Pater
19232c6388 Merge branch 'router-for-me:main' into main 2025-12-31 11:52:49 +08:00
Luis Pater
e998b1229a feat(updater): add fallback URL and logic for missing management asset 2025-12-31 11:51:20 +08:00
Luis Pater
bbed134bd1 feat(api): add GetAuthStatus method to ManagementTokenRequester interface 2025-12-31 09:40:48 +08:00
Saboor Hassan
47b9503112 chore: revert changes to internal/translator to comply with path guard
This commit reverts all modifications within internal/translator. A separate issue
will be created for the maintenance team to integrate SanitizeFunctionName into
the translators.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 02:19:26 +05:00
Saboor Hassan
3b9253c2be fix(translator): resolve invalid function name errors by sanitizing Claude tool names
This commit centralizes tool name sanitization in SanitizeFunctionName,
applying character compliance, starting character rules, and length limits.
It also fixes a regression in gemini_schema tests and preserves MCP-specific
shortening logic while ensuring compliance.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 02:14:46 +05:00
Saboor Hassan
d241359153 fix(translator): address PR feedback for tool name sanitization
- Pre-compile sanitization regex for better performance.
- Optimize SanitizeFunctionName for conciseness and correctness.
- Handle 64-char edge cases by truncating before prepending underscore.
- Fix bug in Antigravity translator (incorrect join index).
- Refactor Gemini translators to avoid redundant sanitization calls.
- Add comprehensive unit tests including 64-char edge cases.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 01:54:41 +05:00
Saboor Hassan
f4d4249ba5 feat(translator): sanitize tool/function names for upstream provider compatibility
Implemented SanitizeFunctionName utility to ensure Claude tool names meet
Gemini/Upstream strict naming conventions (alphanumeric, starts with letter/underscore, max 64 chars).
Applied sanitization to tool definitions and usage in all relevant translators.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 01:41:07 +05:00
Luis Pater
06075c0e93 Merge pull request #75 from router-for-me/plus
v6.6.71
2025-12-30 23:34:44 +08:00
Luis Pater
e85c9d9322 Merge branch 'main' into plus 2025-12-30 23:33:35 +08:00
Chén Mù
cb56cb250e Merge pull request #800 from router-for-me/modelmappings
feat(watcher): add model mappings change detection
2025-12-30 06:50:42 -08:00
hkfires
e0381a6ae0 refactor(watcher): extract model summary functions to dedicated file 2025-12-30 22:39:12 +08:00
hkfires
2c01b2ef64 feat(watcher): add Gemini models and OAuth model mappings change detection 2025-12-30 22:39:12 +08:00
Chén Mù
e947266743 Merge pull request #795 from router-for-me/modelmappings
refactor(executor): resolve upstream model at conductor level before execution
2025-12-30 05:31:19 -08:00
Luis Pater
c6b0e85b54 Fixed: #790
fix(gemini): include full text in response output events
2025-12-30 20:44:13 +08:00
hkfires
26efbed05c refactor(executor): remove redundant upstream model parameter from translateRequest 2025-12-30 20:20:42 +08:00
hkfires
96340bf136 refactor(executor): resolve upstream model at conductor level before execution 2025-12-30 19:31:54 +08:00
hkfires
b055e00c1a fix(executor): use upstream model for thinking config and payload translation 2025-12-30 17:49:44 +08:00
Chén Mù
857c880f99 Merge pull request #785 from router-for-me/gemini
feat(gemini): add per-key model alias support for Gemini provider
2025-12-29 23:32:40 -08:00
hkfires
ce7474d953 feat(cliproxy): propagate thinking support metadata to aliased models 2025-12-30 15:16:54 +08:00
hkfires
70fdd70b84 refactor(cliproxy): extract generic buildConfigModels function for model info generation 2025-12-30 13:35:22 +08:00
hkfires
08ab6a7d77 feat(gemini): add per-key model alias support for Gemini provider 2025-12-30 13:27:57 +08:00
Luis Pater
5c95129884 Merge branch 'router-for-me:main' into main 2025-12-30 11:48:29 +08:00
Luis Pater
9fa2a7e9df Merge pull request #782 from router-for-me/modelmappings
refactor(config): rename model-name-mappings to oauth-model-mappings
2025-12-30 11:40:12 +08:00
hkfires
d443c86620 refactor(config): rename model mapping fields from from/to to name/alias 2025-12-30 11:07:59 +08:00
hkfires
7be3f1c36c refactor(config): rename model-name-mappings to oauth-model-mappings 2025-12-30 11:07:58 +08:00
Luis Pater
f6ab6d97b9 fix(logging): add isDirWritable utility to enhance log dir validation in ConfigureLogOutput 2025-12-30 10:48:25 +08:00
Luis Pater
bc866bac49 fix(logging): refactor ConfigureLogOutput to accept config object and adjust log directory handling 2025-12-30 10:28:25 +08:00
Luis Pater
50e6d845f4 feat(cliproxy): introduce global model name mappings for improved aliasing and routing 2025-12-30 08:13:06 +08:00
Luis Pater
a8cb01819d Merge pull request #772 from soffchen/main
fix: Implement fallback log directory for file logging on read-only system
2025-12-30 02:24:49 +08:00
Luis Pater
1a99cfded4 Merge branch 'router-for-me:main' into main 2025-12-30 00:38:41 +08:00
Luis Pater
530273906b Merge pull request #776 from router-for-me/fix-ag-claude
fix(antigravity): inject required placeholder when properties exist w…
2025-12-30 00:37:01 +08:00
Supra4E8C
06ddf575d9 fix(antigravity): inject required placeholder when properties exist without required 2025-12-29 23:55:59 +08:00
Luis Pater
cf369d4684 Merge branch 'router-for-me:main' into main 2025-12-29 22:41:44 +08:00
hkfires
3099114cbb refactor(api): simplify codex id token claims extraction 2025-12-29 19:48:02 +08:00
Soff
44b63f0767 fix: Return an error if the user home directory cannot be determined for the fallback log path. 2025-12-29 18:46:15 +08:00
Soff Chen
6705d20194 fix: Implement fallback log directory for file logging on read-only systems. 2025-12-29 18:35:48 +08:00
Chén Mù
a38a9c0b0f Merge pull request #770 from router-for-me/api
feat(api): add id token claims extraction for codex auth entries
2025-12-29 00:44:41 -08:00
hkfires
8286caa366 feat(api): add id token claims extraction for codex auth entries 2025-12-29 16:34:16 +08:00
Chén Mù
bd1ec8424d Merge pull request #767 from router-for-me/amp
feat(amp): add per-client upstream API key mapping support
2025-12-28 22:10:11 -08:00
hkfires
225e2c6797 feat(amp): add per-client upstream API key mapping support 2025-12-29 12:26:25 +08:00
Luis Pater
d8fc485513 fix(translators): correct key path for system_instruction.parts in Claude request logic 2025-12-29 11:54:26 +08:00
hkfires
f137eb0ac4 chore: add codex, agents, and opencode dirs to ignore files 2025-12-29 08:42:29 +08:00
Chén Mù
f39a460487 Merge pull request #761 from router-for-me/log
fix(logging): improve request/response capture
2025-12-28 16:13:10 -08:00
Luis Pater
ee171bc563 feat(api): add ManagementTokenRequester interface for management token request endpoints 2025-12-29 02:42:29 +08:00
hkfires
a95428f204 fix(handlers): preserve upstream response logs before duplicate detection 2025-12-28 22:35:36 +08:00
hkfires
3ca5fb1046 fix(handlers): match raw error text before JSON body for duplicate detection 2025-12-28 19:35:36 +08:00
hkfires
a091d12f4e fix(logging): improve request/response capture 2025-12-28 19:04:31 +08:00
Luis Pater
e3d8d726e6 Merge branch 'router-for-me:main' into main 2025-12-28 15:09:33 +08:00
Luis Pater
457924828a Merge pull request #757 from ben-vargas/fix-thinking-toolchoice-conflict
Fix: disable thinking when tool_choice forces tool use
2025-12-28 14:04:30 +08:00
Ben Vargas
aca2ef6359 Fix: disable thinking when tool_choice forces tool use
Anthropic API does not allow extended thinking when tool_choice is set
to "any" or a specific tool. This was causing 400 errors when using
features like Amp's /handoff command which forces tool_choice.

Added disableThinkingIfToolChoiceForced() that removes thinking config
when incompatible tool_choice is detected, applied to both streaming
and non-streaming paths.

Fixes router-for-me/CLIProxyAPI#630
2025-12-27 16:31:37 -07:00
Luis Pater
ade7194792 feat(management): add generic API call handler to management endpoints 2025-12-28 04:40:32 +08:00
Luis Pater
0f51e73baa Merge branch 'router-for-me:main' into main 2025-12-28 03:07:58 +08:00
Luis Pater
3a436e116a feat(cliproxy): implement model aliasing and hashing for Codex configurations, enhance request routing logic, and normalize Codex model entries 2025-12-28 03:06:51 +08:00
Luis Pater
d06e2dc83c Merge branch 'router-for-me:main' into main 2025-12-28 02:10:16 +08:00
Luis Pater
336867853b Merge pull request #756 from leaph/check-ai-thinking-settings
feat(iflow): add model-specific thinking configs for GLM-4.7 and Mini…
2025-12-28 02:08:27 +08:00
leaph
6403ff4ec4 feat(iflow): add model-specific thinking configs for GLM-4.7 and MiniMax-M2.1
- GLM-4.7: Uses extra_body={"thinking": {"type": "enabled"}, "clear_thinking": false}
- MiniMax-M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
- Added preserveReasoningContentInMessages() to support re-injection of reasoning
  content in assistant message history for multi-turn conversations
- Added ThinkingSupport to MiniMax-M2.1 model definition
2025-12-27 18:39:15 +01:00
Luis Pater
d222469b44 Update issue templates 2025-12-28 01:22:42 +08:00
Luis Pater
790a17ce98 Merge pull request #70 from router-for-me/plus
v6.6.60
2025-12-28 00:57:14 +08:00
Luis Pater
d473c952fb Merge branch 'main' into plus 2025-12-28 00:56:04 +08:00
Luis Pater
7646a2b877 Fixed: #749
fix(translators): ensure `gjson.String` content is non-empty before setting `parts` in OpenAI request logic
2025-12-28 00:54:26 +08:00
Luis Pater
62090f2568 Merge pull request #750 from router-for-me/config
fix(config): preserve original config structure and avoid default value pollution
2025-12-27 22:10:01 +08:00
Luis Pater
d35152bbef Merge branch 'router-for-me:main' into main 2025-12-27 22:03:50 +08:00
Luis Pater
c281f4cbaf Fixed: #747
fix(translators): rename and integrate `usageMetadata` as `cpaUsageMetadata` in Claude processing logic
2025-12-27 22:02:11 +08:00
hkfires
09455f9e85 fix(config): make streaming keepalive and retries ints 2025-12-27 20:56:47 +08:00
hkfires
c8e72ba0dc fix(config): smart merge writes non-default new keys only 2025-12-27 20:28:54 +08:00
hkfires
375ef252ab docs(config): clarify merge mapping behavior 2025-12-27 19:30:21 +08:00
hkfires
ee552f8720 chore(config): update ignore patterns 2025-12-27 19:13:14 +08:00
hkfires
2e88c4858e fix(config): avoid adding new keys when merging 2025-12-27 19:00:47 +08:00
Luis Pater
3f50da85c1 Merge pull request #745 from router-for-me/auth
fix(auth): make provider rotation atomic
2025-12-27 13:01:22 +08:00
hkfires
8be06255f7 fix(auth): make provider rotation atomic 2025-12-27 12:56:48 +08:00
Luis Pater
60936b5185 Merge branch 'router-for-me:main' into main 2025-12-27 03:57:03 +08:00
Luis Pater
72274099aa Fixed: #738
fix(translators): refine prompt token calculation by incorporating cached tokens in Claude response handling
2025-12-27 03:56:11 +08:00
Luis Pater
b7f7b3a1d8 Merge branch 'router-for-me:main' into main 2025-12-27 01:26:33 +08:00
Luis Pater
dcae098e23 Fixed: #736
fix(translators): handle gjson string types in Claude request processing to ensure consistent content parsing
2025-12-27 01:25:47 +08:00
Luis Pater
618606966f Merge pull request #68 from router-for-me/plus
v6.6.56
2025-12-26 12:14:42 +08:00
Luis Pater
05f249d77f Merge branch 'main' into plus 2025-12-26 12:14:35 +08:00
Luis Pater
2eb05ec640 Merge pull request #727 from nguyenphutrong/main
docs: add Quotio to community projects
2025-12-26 11:53:09 +08:00
Luis Pater
3ce0d76aa4 feat(usage): add import/export functionality for usage statistics and enhance deduplication logic 2025-12-26 11:49:51 +08:00
Trong Nguyen
a00b79d9be docs(readme): add Quotio to community projects section 2025-12-26 10:46:05 +07:00
Luis Pater
9fe6a215e6 Merge branch 'router-for-me:main' into main 2025-12-26 05:10:19 +08:00
Luis Pater
33e53a2a56 fix(translators): ensure correct handling and output of multimodal assistant content across request handlers 2025-12-26 05:08:04 +08:00
Luis Pater
cd5b80785f Merge pull request #722 from hungthai1401/bugfix/remove-extra-args
Fixed incorrect function signature call to `NewBaseAPIHandlers`
2025-12-26 02:56:56 +08:00
Thai Nguyen Hung
54f71aa273 fix(test): remove extra argument from ExecuteStreamWithAuthManager call 2025-12-25 21:55:35 +07:00
Luis Pater
3f949b7f84 Merge pull request #704 from tinyc0der/add-index
fix(openai): add index field to image response for LiteLLM compatibility
2025-12-25 21:35:12 +08:00
Luis Pater
cf8b2dcc85 Merge pull request #67 from router-for-me/plus
v6.6.54
2025-12-25 21:07:45 +08:00
Luis Pater
8e24d9dc34 Merge branch 'main' into plus 2025-12-25 21:07:37 +08:00
Luis Pater
443c4538bb feat(config): add commercial-mode to optimize HTTP middleware for lower memory usage 2025-12-25 21:05:01 +08:00
TinyCoder
a7fc2ee4cf refactor(image): avoid using json.Marshal 2025-12-25 14:21:01 +07:00
Luis Pater
8e749ac22d docs(readme): update GLM model version from 4.6 to 4.7 in README and README_CN 2025-12-24 23:59:48 +08:00
Luis Pater
69e09d9bc7 docs(readme): update GLM model version from 4.6 to 4.7 in README and README_CN 2025-12-24 23:46:27 +08:00
Luis Pater
ed57d82bc1 Merge branch 'router-for-me:main' into main 2025-12-24 23:31:09 +08:00
Luis Pater
06ad527e8c Fixed: #696
fix(translators): adjust prompt token calculation by subtracting cached tokens across Gemini, OpenAI, and Claude handlers
2025-12-24 23:29:18 +08:00
Luis Pater
7af5a90a0b Merge pull request #65 from router-for-me/plus
v6.6.52
2025-12-24 22:28:45 +08:00
Luis Pater
7551faff79 Merge branch 'main' into plus 2025-12-24 22:27:12 +08:00
Luis Pater
b7409dd2de Merge pull request #706 from router-for-me/log
Log
2025-12-24 22:24:39 +08:00
hkfires
5ba325a8fc refactor(logging): standardize request id formatting and layout 2025-12-24 22:03:07 +08:00
Luis Pater
d502840f91 Merge pull request #695 from NguyenSiTrung/main
feat: add cached token parsing for Gemini , Antigravity API responses
2025-12-24 21:58:55 +08:00
hkfires
99238a4b59 fix(logging): normalize warning level to warn 2025-12-24 21:11:37 +08:00
hkfires
6d43a2ff9a refactor(logging): inline request id in log output 2025-12-24 21:07:18 +08:00
Luis Pater
cdb9c2e6e8 Merge remote-tracking branch 'origin/main' into router-for-me/main 2025-12-24 20:18:53 +08:00
Luis Pater
3faa1ca9af Merge pull request #700 from router-for-me/log
refactor(sdk/auth): rename manager.go to conductor.go
2025-12-24 19:36:24 +08:00
Luis Pater
9d975e0375 feat(models): add support for GLM-4.7 and MiniMax-M2.1 2025-12-24 19:30:57 +08:00
hkfires
2a6d8b78d4 feat(api): add endpoint to retrieve request logs by ID 2025-12-24 19:24:51 +08:00
TinyCoder
671558a822 fix(openai): add index field to image response for LiteLLM compatibility
LiteLLM's Pydantic model requires an index field in each image object.
Without it, responses fail validation with "images.0.index Field required".
2025-12-24 17:43:31 +07:00
Luis Pater
6b80ec79a0 Merge pull request #61 from tinyc0der/fix/kiro-tool-results
fix(kiro): Handle tool results correctly in OpenAI format translation
2025-12-24 17:23:41 +08:00
TinyCoder
c169b32570 refactor(kiro): Remove unused variables in OpenAI translator
Remove dead code that was never used:
- toolCallIDToName map: built but never read from
- seenToolCallIDs: declared but never populated, only suppressed with _
2025-12-24 15:10:35 +07:00
TinyCoder
36a512fdf2 fix(kiro): Handle tool results correctly in OpenAI format translation
Fix three issues in Kiro OpenAI translator that caused "Improperly formed request"
errors when processing LiteLLM-translated requests with tool_use/tool_result:

1. Skip merging tool role messages in MergeAdjacentMessages() to preserve
   individual tool_call_id fields

2. Track pendingToolResults and attach to the next user message instead of
   only the last message. Create synthetic user message when conversation
   ends with tool results.

3. Insert synthetic user message with tool results before assistant messages
   to maintain proper alternating user/assistant structure. This fixes the case
   where LiteLLM translates Anthropic user messages containing only tool_result
   blocks into tool role messages followed by assistant.

Adds unit tests covering all tool result handling scenarios.
2025-12-24 15:10:35 +07:00
hkfires
26fbb77901 refactor(sdk/auth): rename manager.go to conductor.go 2025-12-24 15:21:03 +08:00
NguyenSiTrung
a277302262 Merge remote-tracking branch 'upstream/main' 2025-12-24 10:54:09 +07:00
NguyenSiTrung
969c1a5b72 refactor: extract parseGeminiFamilyUsageDetail helper to reduce duplication 2025-12-24 10:22:31 +07:00
NguyenSiTrung
872339bceb feat: add cached token parsing for Gemini API responses 2025-12-24 10:20:11 +07:00
87 changed files with 5429 additions and 729 deletions

View File

@@ -13,8 +13,6 @@ Dockerfile
docs/*
README.md
README_CN.md
MANAGEMENT_API.md
MANAGEMENT_API_CN.md
LICENSE
# Runtime data folders (should be mounted as volumes)
@@ -25,10 +23,14 @@ config.yaml
# Development/editor
bin/*
.claude/*
.vscode/*
.claude/*
.codex/*
.gemini/*
.serena/*
.agent/*
.agents/*
.opencode/*
.bmad/*
_bmad/*
_bmad-output/*

View File

@@ -7,6 +7,13 @@ assignees: ''
---
**Is it a request payload issue?**
[ ] Yes, this is a request payload issue. I am using a client/cURL to send a request payload, but I received an unexpected error.
[ ] No, it's another issue.
**If it's a request payload issue, you MUST know**
Our team doesn't have any GODs or ORACLEs or MIND READERs. Please make sure to attach the request log or curl payload.
**Describe the bug**
A clear and concise description of what the bug is.

11
.gitignore vendored
View File

@@ -12,11 +12,15 @@ bin/*
logs/*
conv/*
temp/*
refs/*
# Storage backends
pgstore/*
gitstore/*
objectstore/*
# Static assets
static/*
refs/*
# Authentication data
auths/*
@@ -30,12 +34,17 @@ GEMINI.md
# Tooling metadata
.vscode/*
.codex/*
.claude/*
.gemini/*
.serena/*
.agent/*
.agents/*
.agents/*
.opencode/*
.bmad/*
_bmad/*
_bmad-output/*
.mcp/cache/
# macOS

View File

@@ -434,7 +434,7 @@ func main() {
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
if err = logging.ConfigureLogOutput(cfg); err != nil {
log.Errorf("failed to configure log output: %v", err)
return
}

View File

@@ -35,10 +35,14 @@ auth-dir: "~/.cli-proxy-api"
api-keys:
- "your-api-key-1"
- "your-api-key-2"
- "your-api-key-3"
# Enable debug logging
debug: false
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
commercial-mode: false
# Open OAuth URLs in incognito/private browser mode.
# Useful when you want to login with a different account without logging out from your current session.
# Default: false (but Kiro auth defaults to true for multi-account support)
@@ -91,6 +95,9 @@ ws-auth: false
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080"
# models:
# - name: "gemini-2.5-flash" # upstream model name
# alias: "gemini-flash" # client alias mapped to the upstream model
# excluded-models:
# - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
@@ -106,6 +113,9 @@ ws-auth: false
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# models:
# - name: "gpt-5-codex" # upstream model name
# alias: "codex-latest" # client alias mapped to the upstream model
# excluded-models:
# - "gpt-5.1" # exclude specific models (exact match)
# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
@@ -123,7 +133,7 @@ ws-auth: false
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# models:
# - name: "claude-3-5-sonnet-20241022" # upstream model name
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
# excluded-models:
# - "claude-opus-4-5-20251101" # exclude specific models (exact match)
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
@@ -164,9 +174,9 @@ ws-auth: false
# headers:
# X-Custom-Header: "custom-value"
# models: # optional: map aliases to upstream model names
# - name: "gemini-2.0-flash" # upstream model name
# - name: "gemini-2.5-flash" # upstream model name
# alias: "vertex-flash" # client-visible alias
# - name: "gemini-1.5-pro"
# - name: "gemini-2.5-pro"
# alias: "vertex-pro"
# Amp Integration
@@ -175,6 +185,18 @@ ws-auth: false
# upstream-url: "https://ampcode.com"
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
# upstream-api-key: ""
# # Per-client upstream API key mapping
# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
# # Useful when different clients need to use different Amp accounts/quotas.
# # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
# upstream-api-keys:
# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
# api-keys: # Client keys that use this upstream key
# - "your-api-key-1"
# - "your-api-key-2"
# - upstream-api-key: "amp_key_for_team_b"
# api-keys:
# - "your-api-key-3"
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
# restrict-management-to-localhost: false
# # Force model mappings to run before checking local API keys (default: false)
@@ -184,12 +206,42 @@ ws-auth: false
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
# # but you have a similar model available (e.g., Claude Sonnet 4).
# model-mappings:
# - from: "claude-opus-4.5" # Model requested by Amp CLI
# to: "claude-sonnet-4" # Route to this available model instead
# - from: "gpt-5"
# to: "gemini-2.5-pro"
# - from: "claude-3-opus-20240229"
# to: "claude-3-5-sonnet-20241022"
# - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
# to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
# - from: "claude-sonnet-4-5-20250929"
# to: "gemini-claude-sonnet-4-5-thinking"
# - from: "claude-haiku-4-5-20251001"
# to: "gemini-2.5-flash"
# Global OAuth model name mappings (per channel)
# These mappings rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
# NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# oauth-model-mappings:
# gemini-cli:
# - name: "gemini-2.5-pro" # original model name under this channel
# alias: "g2.5p" # client-visible alias
# vertex:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# aistudio:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# antigravity:
# - name: "gemini-3-pro-preview"
# alias: "g3p"
# claude:
# - name: "claude-sonnet-4-5-20250929"
# alias: "cs4.5"
# codex:
# - name: "gpt-5"
# alias: "g5"
# qwen:
# - name: "qwen3-coder-plus"
# alias: "qwen-plus"
# iflow:
# - name: "glm-4.7"
# alias: "glm-god"
# OAuth provider excluded models
# oauth-excluded-models:

View File

@@ -5,9 +5,115 @@
# This script automates the process of building and running the Docker container
# with version information dynamically injected at build time.
# Exit immediately if a command exits with a non-zero status.
# Hidden feature: Preserve usage statistics across rebuilds
# Usage: ./docker-build.sh --with-usage
# First run prompts for management API key, saved to temp/stats/.api_secret
set -euo pipefail
STATS_DIR="temp/stats"
STATS_FILE="${STATS_DIR}/.usage_backup.json"
SECRET_FILE="${STATS_DIR}/.api_secret"
WITH_USAGE=false
get_port() {
if [[ -f "config.yaml" ]]; then
grep -E "^port:" config.yaml | sed -E 's/^port: *["'"'"']?([0-9]+)["'"'"']?.*$/\1/'
else
echo "8317"
fi
}
export_stats_api_secret() {
if [[ -f "${SECRET_FILE}" ]]; then
API_SECRET=$(cat "${SECRET_FILE}")
else
if [[ ! -d "${STATS_DIR}" ]]; then
mkdir -p "${STATS_DIR}"
fi
echo "First time using --with-usage. Management API key required."
read -r -p "Enter management key: " -s API_SECRET
echo
echo "${API_SECRET}" > "${SECRET_FILE}"
chmod 600 "${SECRET_FILE}"
fi
}
check_container_running() {
local port
port=$(get_port)
if ! curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
echo "Error: cli-proxy-api service is not responding at localhost:${port}"
echo "Please start the container first or use without --with-usage flag."
exit 1
fi
}
export_stats() {
local port
port=$(get_port)
if [[ ! -d "${STATS_DIR}" ]]; then
mkdir -p "${STATS_DIR}"
fi
check_container_running
echo "Exporting usage statistics..."
EXPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -H "X-Management-Key: ${API_SECRET}" \
"http://localhost:${port}/v0/management/usage/export")
HTTP_CODE=$(echo "${EXPORT_RESPONSE}" | tail -n1)
RESPONSE_BODY=$(echo "${EXPORT_RESPONSE}" | sed '$d')
if [[ "${HTTP_CODE}" != "200" ]]; then
echo "Export failed (HTTP ${HTTP_CODE}): ${RESPONSE_BODY}"
exit 1
fi
echo "${RESPONSE_BODY}" > "${STATS_FILE}"
echo "Statistics exported to ${STATS_FILE}"
}
import_stats() {
local port
port=$(get_port)
echo "Importing usage statistics..."
IMPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST \
-H "X-Management-Key: ${API_SECRET}" \
-H "Content-Type: application/json" \
-d @"${STATS_FILE}" \
"http://localhost:${port}/v0/management/usage/import")
IMPORT_CODE=$(echo "${IMPORT_RESPONSE}" | tail -n1)
IMPORT_BODY=$(echo "${IMPORT_RESPONSE}" | sed '$d')
if [[ "${IMPORT_CODE}" == "200" ]]; then
echo "Statistics imported successfully"
else
echo "Import failed (HTTP ${IMPORT_CODE}): ${IMPORT_BODY}"
fi
rm -f "${STATS_FILE}"
}
wait_for_service() {
local port
port=$(get_port)
echo "Waiting for service to be ready..."
for i in {1..30}; do
if curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
break
fi
sleep 1
done
sleep 2
}
if [[ "${1:-}" == "--with-usage" ]]; then
WITH_USAGE=true
export_stats_api_secret
fi
# --- Step 1: Choose Environment ---
echo "Please select an option:"
echo "1) Run using Pre-built Image (Recommended)"
@@ -18,7 +124,14 @@ read -r -p "Enter choice [1-2]: " choice
case "$choice" in
1)
echo "--- Running with Pre-built Image ---"
if [[ "${WITH_USAGE}" == "true" ]]; then
export_stats
fi
docker compose up -d --remove-orphans --no-build
if [[ "${WITH_USAGE}" == "true" ]]; then
wait_for_service
import_stats
fi
echo "Services are starting from remote image."
echo "Run 'docker compose logs -f' to see the logs."
;;
@@ -38,7 +151,11 @@ case "$choice" in
# Build and start the services with a local-only image tag
export CLI_PROXY_IMAGE="cli-proxy-api:local"
if [[ "${WITH_USAGE}" == "true" ]]; then
export_stats
fi
echo "Building the Docker image..."
docker compose build \
--build-arg VERSION="${VERSION}" \
@@ -48,6 +165,11 @@ case "$choice" in
echo "Starting the services..."
docker compose up -d --remove-orphans --pull never
if [[ "${WITH_USAGE}" == "true" ]]; then
wait_for_service
import_stats
fi
echo "Build complete. Services are starting."
echo "Run 'docker compose logs -f' to see the logs."
;;
@@ -55,4 +177,4 @@ case "$choice" in
echo "Invalid choice. Please enter 1 or 2."
exit 1
;;
esac
esac

View File

@@ -0,0 +1,538 @@
package management
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const defaultAPICallTimeout = 60 * time.Second
const (
geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOAuthClientSecret = "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"
)
var geminiOAuthScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
type apiCallRequest struct {
AuthIndexSnake *string `json:"auth_index"`
AuthIndexCamel *string `json:"authIndex"`
AuthIndexPascal *string `json:"AuthIndex"`
Method string `json:"method"`
URL string `json:"url"`
Header map[string]string `json:"header"`
Data string `json:"data"`
}
type apiCallResponse struct {
StatusCode int `json:"status_code"`
Header map[string][]string `json:"header"`
Body string `json:"body"`
}
// APICall makes a generic HTTP request on behalf of the management API caller.
// It is protected by the management middleware.
//
// Endpoint:
//
// POST /v0/management/api-call
//
// Authentication:
//
// Same as other management APIs (requires a management key and remote-management rules).
// You can provide the key via:
// - Authorization: Bearer <key>
// - X-Management-Key: <key>
//
// Request JSON:
// - auth_index / authIndex / AuthIndex (optional):
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
// If omitted or not found, credential-specific proxy/token substitution is skipped.
// - method (required): HTTP method, e.g. GET, POST, PUT, PATCH, DELETE.
// - url (required): Absolute URL including scheme and host, e.g. "https://api.example.com/v1/ping".
// - header (optional): Request headers map.
// Supports magic variable "$TOKEN$" which is replaced using the selected credential:
// 1) metadata.access_token
// 2) attributes.api_key
// 3) metadata.token / metadata.id_token / metadata.cookie
// Example: {"Authorization":"Bearer $TOKEN$"}.
// Note: if you need to override the HTTP Host header, set header["Host"].
// - data (optional): Raw request body as string (useful for POST/PUT/PATCH).
//
// Proxy selection (highest priority first):
// 1. Selected credential proxy_url
// 2. Global config proxy-url
// 3. Direct connect (environment proxies are not used)
//
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
// - status_code: Upstream HTTP status code.
// - header: Upstream response headers.
// - body: Upstream response body as string.
//
// Example:
//
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
// -H "Authorization: Bearer <MANAGEMENT_KEY>" \
// -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"GET","url":"https://api.example.com/v1/ping","header":{"Authorization":"Bearer $TOKEN$"}}'
//
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
// -H "Authorization: Bearer 831227" \
// -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
func (h *Handler) APICall(c *gin.Context) {
var body apiCallRequest
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
method := strings.ToUpper(strings.TrimSpace(body.Method))
if method == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing method"})
return
}
urlStr := strings.TrimSpace(body.URL)
if urlStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"})
return
}
parsedURL, errParseURL := url.Parse(urlStr)
if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
return
}
authIndex := firstNonEmptyString(body.AuthIndexSnake, body.AuthIndexCamel, body.AuthIndexPascal)
auth := h.authByIndex(authIndex)
reqHeaders := body.Header
if reqHeaders == nil {
reqHeaders = map[string]string{}
}
var hostOverride string
var token string
var tokenResolved bool
var tokenErr error
for key, value := range reqHeaders {
if !strings.Contains(value, "$TOKEN$") {
continue
}
if !tokenResolved {
token, tokenErr = h.resolveTokenForAuth(c.Request.Context(), auth)
tokenResolved = true
}
if auth != nil && token == "" {
if tokenErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token refresh failed"})
return
}
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token not found"})
return
}
if token == "" {
continue
}
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
}
var requestBody io.Reader
if body.Data != "" {
requestBody = strings.NewReader(body.Data)
}
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
if errNewRequest != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"})
return
}
for key, value := range reqHeaders {
if strings.EqualFold(key, "host") {
hostOverride = strings.TrimSpace(value)
continue
}
req.Header.Set(key, value)
}
if hostOverride != "" {
req.Host = hostOverride
}
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
}
httpClient.Transport = h.apiCallTransport(auth)
resp, errDo := httpClient.Do(req)
if errDo != nil {
log.WithError(errDo).Debug("management APICall request failed")
c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
}()
respBody, errReadAll := io.ReadAll(resp.Body)
if errReadAll != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
return
}
c.JSON(http.StatusOK, apiCallResponse{
StatusCode: resp.StatusCode,
Header: resp.Header,
Body: string(respBody),
})
}
func firstNonEmptyString(values ...*string) string {
for _, v := range values {
if v == nil {
continue
}
if out := strings.TrimSpace(*v); out != "" {
return out
}
}
return ""
}
func tokenValueForAuth(auth *coreauth.Auth) string {
if auth == nil {
return ""
}
if v := tokenValueFromMetadata(auth.Metadata); v != "" {
return v
}
if auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
return v
}
}
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
if v := tokenValueFromMetadata(shared.MetadataSnapshot()); v != "" {
return v
}
}
return ""
}
func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) (string, error) {
if auth == nil {
return "", nil
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if provider == "gemini-cli" {
token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
return token, errToken
}
return tokenValueForAuth(auth), nil
}
func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
if ctx == nil {
ctx = context.Background()
}
if auth == nil {
return "", nil
}
metadata, updater := geminiOAuthMetadata(auth)
if len(metadata) == 0 {
return "", fmt.Errorf("gemini oauth metadata missing")
}
base := make(map[string]any)
if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
base = cloneMap(tokenRaw)
}
var token oauth2.Token
if len(base) > 0 {
if raw, errMarshal := json.Marshal(base); errMarshal == nil {
_ = json.Unmarshal(raw, &token)
}
}
if token.AccessToken == "" {
token.AccessToken = stringValue(metadata, "access_token")
}
if token.RefreshToken == "" {
token.RefreshToken = stringValue(metadata, "refresh_token")
}
if token.TokenType == "" {
token.TokenType = stringValue(metadata, "token_type")
}
if token.Expiry.IsZero() {
if expiry := stringValue(metadata, "expiry"); expiry != "" {
if ts, errParseTime := time.Parse(time.RFC3339, expiry); errParseTime == nil {
token.Expiry = ts
}
}
}
conf := &oauth2.Config{
ClientID: geminiOAuthClientID,
ClientSecret: geminiOAuthClientSecret,
Scopes: geminiOAuthScopes,
Endpoint: google.Endpoint,
}
ctxToken := ctx
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
Transport: h.apiCallTransport(auth),
}
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
src := conf.TokenSource(ctxToken, &token)
currentToken, errToken := src.Token()
if errToken != nil {
return "", errToken
}
merged := buildOAuthTokenMap(base, currentToken)
fields := buildOAuthTokenFields(currentToken, merged)
if updater != nil {
updater(fields)
}
return strings.TrimSpace(currentToken.AccessToken), nil
}
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
if auth == nil {
return nil, nil
}
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
snapshot := shared.MetadataSnapshot()
return snapshot, func(fields map[string]any) { shared.MergeMetadata(fields) }
}
return auth.Metadata, func(fields map[string]any) {
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
for k, v := range fields {
auth.Metadata[k] = v
}
}
}
func stringValue(metadata map[string]any, key string) string {
if len(metadata) == 0 || key == "" {
return ""
}
if v, ok := metadata[key].(string); ok {
return strings.TrimSpace(v)
}
return ""
}
func cloneMap(in map[string]any) map[string]any {
if len(in) == 0 {
return nil
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func buildOAuthTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
merged := cloneMap(base)
if merged == nil {
merged = make(map[string]any)
}
if tok == nil {
return merged
}
if raw, errMarshal := json.Marshal(tok); errMarshal == nil {
var tokenMap map[string]any
if errUnmarshal := json.Unmarshal(raw, &tokenMap); errUnmarshal == nil {
for k, v := range tokenMap {
merged[k] = v
}
}
}
return merged
}
func buildOAuthTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
fields := make(map[string]any, 5)
if tok != nil && tok.AccessToken != "" {
fields["access_token"] = tok.AccessToken
}
if tok != nil && tok.TokenType != "" {
fields["token_type"] = tok.TokenType
}
if tok != nil && tok.RefreshToken != "" {
fields["refresh_token"] = tok.RefreshToken
}
if tok != nil && !tok.Expiry.IsZero() {
fields["expiry"] = tok.Expiry.Format(time.RFC3339)
}
if len(merged) > 0 {
fields["token"] = cloneMap(merged)
}
return fields
}
func tokenValueFromMetadata(metadata map[string]any) string {
if len(metadata) == 0 {
return ""
}
if v, ok := metadata["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if tokenRaw, ok := metadata["token"]; ok && tokenRaw != nil {
switch typed := tokenRaw.(type) {
case string:
if v := strings.TrimSpace(typed); v != "" {
return v
}
case map[string]any:
if v, ok := typed["access_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := typed["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
case map[string]string:
if v := strings.TrimSpace(typed["access_token"]); v != "" {
return v
}
if v := strings.TrimSpace(typed["accessToken"]); v != "" {
return v
}
}
}
if v, ok := metadata["token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["id_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["cookie"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
return ""
}
func (h *Handler) authByIndex(authIndex string) *coreauth.Auth {
authIndex = strings.TrimSpace(authIndex)
if authIndex == "" || h == nil || h.authManager == nil {
return nil
}
auths := h.authManager.List()
for _, auth := range auths {
if auth == nil {
continue
}
auth.EnsureIndex()
if auth.Index == authIndex {
return auth
}
}
return nil
}
func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
var proxyCandidates []string
if auth != nil {
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
if h != nil && h.cfg != nil {
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
for _, proxyStr := range proxyCandidates {
if transport := buildProxyTransport(proxyStr); transport != nil {
return transport
}
}
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok || transport == nil {
return &http.Transport{Proxy: nil}
}
clone := transport.Clone()
clone.Proxy = nil
return clone
}
func buildProxyTransport(proxyStr string) *http.Transport {
proxyStr = strings.TrimSpace(proxyStr)
if proxyStr == "" {
return nil
}
proxyURL, errParse := url.Parse(proxyStr)
if errParse != nil {
log.WithError(errParse).Debug("parse proxy URL failed")
return nil
}
if proxyURL.Scheme == "" || proxyURL.Host == "" {
log.Debug("proxy URL missing scheme/host")
return nil
}
if proxyURL.Scheme == "socks5" {
var proxyAuth *proxy.Auth
if proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
return nil
}
return &http.Transport{
Proxy: nil,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
}
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil
}

View File

@@ -431,9 +431,46 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
log.WithError(err).Warnf("failed to stat auth file %s", path)
}
}
if claims := extractCodexIDTokenClaims(auth); claims != nil {
entry["id_token"] = claims
}
return entry
}
func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
if auth == nil || auth.Metadata == nil {
return nil
}
if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
return nil
}
idTokenRaw, ok := auth.Metadata["id_token"].(string)
if !ok {
return nil
}
idToken := strings.TrimSpace(idTokenRaw)
if idToken == "" {
return nil
}
claims, err := codex.ParseJWTToken(idToken)
if err != nil || claims == nil {
return nil
}
result := gin.H{}
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID); v != "" {
result["chatgpt_account_id"] = v
}
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" {
result["plan_type"] = v
}
if len(result) == 0 {
return nil
}
return result
}
func authEmail(auth *coreauth.Auth) string {
if auth == nil {
return ""

View File

@@ -597,11 +597,7 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
filtered := make([]config.CodexKey, 0, len(arr))
for i := range arr {
entry := arr[i]
entry.APIKey = strings.TrimSpace(entry.APIKey)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
normalizeCodexKey(&entry)
if entry.BaseURL == "" {
continue
}
@@ -613,12 +609,13 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
}
func (h *Handler) PatchCodexKey(c *gin.Context) {
type codexKeyPatch struct {
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Models *[]config.CodexModel `json:"models"`
Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
}
var body struct {
Index *int `json:"index"`
@@ -667,12 +664,16 @@ func (h *Handler) PatchCodexKey(c *gin.Context) {
if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
}
if body.Value.Models != nil {
entry.Models = append([]config.CodexModel(nil), (*body.Value.Models)...)
}
if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
if body.Value.ExcludedModels != nil {
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
}
normalizeCodexKey(&entry)
h.cfg.CodexKey[targetIndex] = entry
h.cfg.SanitizeCodexKeys()
h.persist(c)
@@ -762,6 +763,32 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
entry.Models = normalized
}
func normalizeCodexKey(entry *config.CodexKey) {
if entry == nil {
return
}
entry.APIKey = strings.TrimSpace(entry.APIKey)
entry.Prefix = strings.TrimSpace(entry.Prefix)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
if len(entry.Models) == 0 {
return
}
normalized := make([]config.CodexModel, 0, len(entry.Models))
for i := range entry.Models {
model := entry.Models[i]
model.Name = strings.TrimSpace(model.Name)
model.Alias = strings.TrimSpace(model.Alias)
if model.Name == "" && model.Alias == "" {
continue
}
normalized = append(normalized, model)
}
entry.Models = normalized
}
// GetAmpCode returns the complete ampcode configuration.
func (h *Handler) GetAmpCode(c *gin.Context) {
if h == nil || h.cfg == nil {
@@ -913,3 +940,151 @@ func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
}
// GetAmpUpstreamAPIKeys returns the ampcode upstream API keys mapping.
func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}})
return
}
c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys})
}
// PutAmpUpstreamAPIKeys replaces all ampcode upstream API keys mappings.
func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
// Normalize entries: trim whitespace, filter empty
normalized := normalizeAmpUpstreamAPIKeyEntries(body.Value)
h.cfg.AmpCode.UpstreamAPIKeys = normalized
h.persist(c)
}
// PatchAmpUpstreamAPIKeys adds or updates upstream API keys entries.
// Matching is done by upstream-api-key value.
func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
existing := make(map[string]int)
for i, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i
}
for _, newEntry := range body.Value {
upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
normalizedEntry := config.AmpUpstreamAPIKeyEntry{
UpstreamAPIKey: upstreamKey,
APIKeys: normalizeAPIKeysList(newEntry.APIKeys),
}
if idx, ok := existing[upstreamKey]; ok {
h.cfg.AmpCode.UpstreamAPIKeys[idx] = normalizedEntry
} else {
h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, normalizedEntry)
existing[upstreamKey] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1
}
}
h.persist(c)
}
// DeleteAmpUpstreamAPIKeys removes specified upstream API keys entries.
// Body must be JSON: {"value": ["<upstream-api-key>", ...]}.
// If "value" is an empty array, clears all entries.
// If JSON is invalid or "value" is missing/null, returns 400 and does not persist any change.
func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []string `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
if body.Value == nil {
c.JSON(400, gin.H{"error": "missing value"})
return
}
// Empty array means clear all
if len(body.Value) == 0 {
h.cfg.AmpCode.UpstreamAPIKeys = nil
h.persist(c)
return
}
toRemove := make(map[string]bool)
for _, key := range body.Value {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
continue
}
toRemove[trimmed] = true
}
if len(toRemove) == 0 {
c.JSON(400, gin.H{"error": "empty value"})
return
}
newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys))
for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] {
newEntries = append(newEntries, entry)
}
}
h.cfg.AmpCode.UpstreamAPIKeys = newEntries
h.persist(c)
}
// normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries.
func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry {
if len(entries) == 0 {
return nil
}
out := make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries))
for _, entry := range entries {
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
apiKeys := normalizeAPIKeysList(entry.APIKeys)
out = append(out, config.AmpUpstreamAPIKeyEntry{
UpstreamAPIKey: upstreamKey,
APIKeys: apiKeys,
})
}
if len(out) == 0 {
return nil
}
return out
}
// normalizeAPIKeysList trims and filters empty strings from a list of API keys.
func normalizeAPIKeysList(keys []string) []string {
if len(keys) == 0 {
return nil
}
out := make([]string, 0, len(keys))
for _, k := range keys {
trimmed := strings.TrimSpace(k)
if trimmed != "" {
out = append(out, trimmed)
}
}
if len(out) == 0 {
return nil
}
return out
}

View File

@@ -59,6 +59,11 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
}
}
// NewHandler creates a new management handler instance.
func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manager) *Handler {
return NewHandler(cfg, "", manager)
}
// SetConfig updates the in-memory config reference when the server hot-reloads.
func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg }

View File

@@ -209,6 +209,94 @@ func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"files": files})
}
// GetRequestLogByID finds and downloads a request log file by its request ID.
// The ID is matched against the suffix of log file names (format: *-{requestID}.log).
func (h *Handler) GetRequestLogByID(c *gin.Context) {
if h == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
return
}
if h.cfg == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
return
}
dir := h.logDirectory()
if strings.TrimSpace(dir) == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
return
}
requestID := strings.TrimSpace(c.Param("id"))
if requestID == "" {
requestID = strings.TrimSpace(c.Query("id"))
}
if requestID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing request ID"})
return
}
if strings.ContainsAny(requestID, "/\\") {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request ID"})
return
}
entries, err := os.ReadDir(dir)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
return
}
suffix := "-" + requestID + ".log"
var matchedFile string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if strings.HasSuffix(name, suffix) {
matchedFile = name
break
}
}
if matchedFile == "" {
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found for the given request ID"})
return
}
dirAbs, errAbs := filepath.Abs(dir)
if errAbs != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
return
}
fullPath := filepath.Clean(filepath.Join(dirAbs, matchedFile))
prefix := dirAbs + string(os.PathSeparator)
if !strings.HasPrefix(fullPath, prefix) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
return
}
info, errStat := os.Stat(fullPath)
if errStat != nil {
if os.IsNotExist(errStat) {
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
return
}
if info.IsDir() {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
return
}
c.FileAttachment(fullPath, matchedFile)
}
// DownloadRequestErrorLog downloads a specific error request log file by name.
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
if h == nil {

View File

@@ -1,12 +1,25 @@
package management
import (
"encoding/json"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
)
type usageExportPayload struct {
Version int `json:"version"`
ExportedAt time.Time `json:"exported_at"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
type usageImportPayload struct {
Version int `json:"version"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
// GetUsageStatistics returns the in-memory request statistics snapshot.
func (h *Handler) GetUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
@@ -18,3 +31,49 @@ func (h *Handler) GetUsageStatistics(c *gin.Context) {
"failed_requests": snapshot.FailureCount,
})
}
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
if h != nil && h.usageStats != nil {
snapshot = h.usageStats.Snapshot()
}
c.JSON(http.StatusOK, usageExportPayload{
Version: 1,
ExportedAt: time.Now().UTC(),
Usage: snapshot,
})
}
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
if h == nil || h.usageStats == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
return
}
data, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
return
}
var payload usageImportPayload
if err := json.Unmarshal(data, &payload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
return
}
if payload.Version != 0 && payload.Version != 1 {
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
return
}
result := h.usageStats.MergeSnapshot(payload.Usage)
snapshot := h.usageStats.Snapshot()
c.JSON(http.StatusOK, gin.H{
"added": result.Added,
"skipped": result.Skipped,
"total_requests": snapshot.TotalRequests,
"failed_requests": snapshot.FailureCount,
})
}

View File

@@ -227,11 +227,20 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
}
}
// Check API key change
// Check API key change (both default and per-client mappings)
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
if apiKeyChanged {
upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
if apiKeyChanged || upstreamAPIKeysChanged {
if m.secretSource != nil {
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
if ms, ok := m.secretSource.(*MappedSecretSource); ok {
if apiKeyChanged {
ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
ms.InvalidateCache()
}
if upstreamAPIKeysChanged {
ms.UpdateMappings(newSettings.UpstreamAPIKeys)
}
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
ms.InvalidateCache()
}
@@ -251,10 +260,22 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
if m.secretSource == nil {
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
// Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
mappedSource := NewMappedSecretSource(defaultSource)
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
m.secretSource = mappedSource
} else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
ms.InvalidateCache()
ms.UpdateMappings(settings.UpstreamAPIKeys)
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
// Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
ms.InvalidateCache()
mappedSource := NewMappedSecretSource(ms)
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
m.secretSource = mappedSource
}
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
@@ -313,6 +334,66 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b
return oldKey != newKey
}
// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings.
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
if old == nil {
return len(new.UpstreamAPIKeys) > 0
}
if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
return true
}
// Build map for comparison: upstreamKey -> set of clientKeys
type entryInfo struct {
upstreamKey string
clientKeys map[string]struct{}
}
oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys))
for i, entry := range old.UpstreamAPIKeys {
clientKeys := make(map[string]struct{}, len(entry.APIKeys))
for _, k := range entry.APIKeys {
trimmed := strings.TrimSpace(k)
if trimmed == "" {
continue
}
clientKeys[trimmed] = struct{}{}
}
oldEntries[i] = entryInfo{
upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey),
clientKeys: clientKeys,
}
}
for i, newEntry := range new.UpstreamAPIKeys {
if i >= len(oldEntries) {
return true
}
oldE := oldEntries[i]
if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey {
return true
}
newKeys := make(map[string]struct{}, len(newEntry.APIKeys))
for _, k := range newEntry.APIKeys {
trimmed := strings.TrimSpace(k)
if trimmed == "" {
continue
}
newKeys[trimmed] = struct{}{}
}
if len(newKeys) != len(oldE.clientKeys) {
return true
}
for k := range newKeys {
if _, ok := oldE.clientKeys[k]; !ok {
return true
}
}
}
return false
}
// GetModelMapper returns the model mapper instance (for testing/debugging).
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
return m.modelMapper

View File

@@ -312,3 +312,41 @@ func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
})
}
}
func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
m := &AmpModule{}
oldCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
},
}
newCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
},
}
if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
}
}
func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
m := &AmpModule{}
oldCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
},
}
newCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
},
}
if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
t.Fatal("expected no change when only whitespace/empty entries differ")
}
}

View File

@@ -18,6 +18,33 @@ import (
log "github.com/sirupsen/logrus"
)
func removeQueryValuesMatching(req *http.Request, key string, match string) {
if req == nil || req.URL == nil || match == "" {
return
}
q := req.URL.Query()
values, ok := q[key]
if !ok || len(values) == 0 {
return
}
kept := make([]string, 0, len(values))
for _, v := range values {
if v == match {
continue
}
kept = append(kept, v)
}
if len(kept) == 0 {
q.Del(key)
} else {
q[key] = kept
}
req.URL.RawQuery = q.Encode()
}
// readCloser wraps a reader and forwards Close to a separate closer.
// Used to restore peeked bytes while preserving upstream body Close behavior.
type readCloser struct {
@@ -48,6 +75,14 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
// We will set our own Authorization using the configured upstream-api-key
req.Header.Del("Authorization")
req.Header.Del("X-Api-Key")
req.Header.Del("X-Goog-Api-Key")
// Remove query-based credentials if they match the authenticated client API key.
// This prevents leaking client auth material to the Amp upstream while avoiding
// breaking unrelated upstream query parameters.
clientKey := getClientAPIKeyFromContext(req.Context())
removeQueryValuesMatching(req, "key", clientKey)
removeQueryValuesMatching(req, "auth_token", clientKey)
// Preserve correlation headers for debugging
if req.Header.Get("X-Request-ID") == "" {

View File

@@ -3,11 +3,15 @@ package amp
import (
"bytes"
"compress/gzip"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// Helper: compress data with gzip
@@ -306,6 +310,159 @@ func TestReverseProxy_EmptySecret(t *testing.T) {
}
}
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
type captured struct {
headers http.Header
query string
}
got := make(chan captured, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate clientAPIKeyMiddleware injection (per-request)
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Authorization", "Bearer client-key")
req.Header.Set("X-Api-Key", "client-key")
req.Header.Set("X-Goog-Api-Key", "client-key")
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
c := <-got
// These are client-provided credentials and must not reach the upstream.
if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
}
// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
t.Fatalf("Authorization should be upstream-injected, got: %q", v)
}
if v := c.headers.Get("X-Api-Key"); v != "upstream" {
t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
}
// Query-based credentials should be stripped only when they match the authenticated client key.
// Should keep unrelated values and parameters.
if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
}
if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
}
}
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate clientAPIKeyMiddleware injection (per-request)
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "u1" {
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer u1" {
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "default" {
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer default" {
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_ErrorHandler(t *testing.T) {
// Point proxy to a non-routable address to trigger error
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))

View File

@@ -1,6 +1,7 @@
package amp
import (
"context"
"errors"
"net"
"net/http"
@@ -16,6 +17,37 @@ import (
log "github.com/sirupsen/logrus"
)
// clientAPIKeyContextKey is the context key used to pass the client API key
// from gin.Context to the request context for SecretSource lookup.
type clientAPIKeyContextKey struct{}
// clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["apiKey"]
// into the request context so that SecretSource can look it up for per-client upstream routing.
func clientAPIKeyMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Extract the client API key from gin context (set by AuthMiddleware)
if apiKey, exists := c.Get("apiKey"); exists {
if keyStr, ok := apiKey.(string); ok && keyStr != "" {
// Inject into request context for SecretSource.Get(ctx) to read
ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr)
c.Request = c.Request.WithContext(ctx)
}
}
c.Next()
}
}
// getClientAPIKeyFromContext retrieves the client API key from request context.
// Returns empty string if not present.
func getClientAPIKeyFromContext(ctx context.Context) string {
if val := ctx.Value(clientAPIKeyContextKey{}); val != nil {
if keyStr, ok := val.(string); ok {
return keyStr
}
}
return ""
}
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
@@ -129,6 +161,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
}
// Inject client API key into request context for per-client upstream routing
ampAPI.Use(clientAPIKeyMiddleware())
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
proxyHandler := func(c *gin.Context) {
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
@@ -175,6 +210,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
if authWithBypass != nil {
rootMiddleware = append(rootMiddleware, authWithBypass)
}
// Add clientAPIKeyMiddleware after auth for per-client upstream routing
rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
@@ -244,6 +281,8 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
if auth != nil {
ampProviders.Use(auth)
}
// Inject client API key into request context for per-client upstream routing
ampProviders.Use(clientAPIKeyMiddleware())
provider := ampProviders.Group("/:provider")

View File

@@ -9,6 +9,9 @@ import (
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
)
// SecretSource provides Amp API keys with configurable precedence and caching
@@ -164,3 +167,82 @@ func NewStaticSecretSource(key string) *StaticSecretSource {
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
return s.key, nil
}
// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping.
// When a request context contains a client API key that matches a configured mapping,
// the corresponding upstream key is returned. Otherwise, falls back to the default source.
type MappedSecretSource struct {
defaultSource SecretSource
mu sync.RWMutex
lookup map[string]string // clientKey -> upstreamKey
}
// NewMappedSecretSource creates a MappedSecretSource wrapping the given default source.
func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
return &MappedSecretSource{
defaultSource: defaultSource,
lookup: make(map[string]string),
}
}
// Get retrieves the Amp API key, checking per-client mappings first.
// If the request context contains a client API key that matches a configured mapping,
// returns the corresponding upstream key. Otherwise, falls back to the default source.
func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
// Try to get client API key from request context
clientKey := getClientAPIKeyFromContext(ctx)
if clientKey != "" {
s.mu.RLock()
if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" {
s.mu.RUnlock()
return upstreamKey, nil
}
s.mu.RUnlock()
}
// Fall back to default source
return s.defaultSource.Get(ctx)
}
// UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries.
// If the same client key appears in multiple entries, logs a warning and uses the first one.
func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
newLookup := make(map[string]string)
for _, entry := range entries {
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
for _, clientKey := range entry.APIKeys {
trimmedKey := strings.TrimSpace(clientKey)
if trimmedKey == "" {
continue
}
if _, exists := newLookup[trimmedKey]; exists {
// Log warning for duplicate client key, first one wins
log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
continue
}
newLookup[trimmedKey] = upstreamKey
}
}
s.mu.Lock()
s.lookup = newLookup
s.mu.Unlock()
}
// UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable).
func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(key)
}
}
// InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable).
func (s *MappedSecretSource) InvalidateCache() {
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
ms.InvalidateCache()
}
}

View File

@@ -8,6 +8,10 @@ import (
"sync"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
)
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
@@ -278,3 +282,85 @@ func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
}
}
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
got, err := s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "u1" {
t.Fatalf("want u1, got %q", got)
}
ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2")
got, err = s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "default" {
t.Fatalf("want default fallback, got %q", got)
}
}
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
{
UpstreamAPIKey: "u2",
APIKeys: []string{"k1"},
},
})
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
got, err := s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "u1" {
t.Fatalf("want u1 (first wins), got %q", got)
}
}
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
hook := test.NewLocal(log.StandardLogger())
defer hook.Reset()
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
{
UpstreamAPIKey: "u2",
APIKeys: []string{"k1"},
},
})
foundWarning := false
for _, entry := range hook.AllEntries() {
if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
foundWarning = true
break
}
}
if !foundWarning {
t.Fatal("expected warning log for duplicate client key, but none was found")
}
}

View File

@@ -209,13 +209,15 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
// Resolve logs directory relative to the configuration file directory.
var requestLogger logging.RequestLogger
var toggle func(bool)
if optionState.requestLoggerFactory != nil {
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
}
if requestLogger != nil {
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
toggle = setter.SetEnabled
if !cfg.CommercialMode {
if optionState.requestLoggerFactory != nil {
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
}
if requestLogger != nil {
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
toggle = setter.SetEnabled
}
}
}
@@ -494,6 +496,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
{
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
mgmt.GET("/config", s.mgmt.GetConfig)
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
@@ -516,6 +520,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
mgmt.POST("/api-call", s.mgmt.APICall)
mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
@@ -538,6 +544,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
@@ -564,6 +571,10 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
@@ -866,7 +877,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
}
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
if err := logging.ConfigureLogOutput(cfg); err != nil {
log.Errorf("failed to reconfigure log output: %v", err)
} else {
if oldCfg == nil {

View File

@@ -39,6 +39,9 @@ type Config struct {
// Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug" json:"debug"`
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
// LoggingToFile controls whether application logs are written to rotating files or stdout.
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
@@ -95,6 +98,14 @@ type Config struct {
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
// OAuthModelMappings defines global model name mappings for OAuth/file-backed auth channels.
// These mappings affect both model listing and model routing for supported channels:
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
//
// NOTE: This does not apply to existing per-credential model alias features under:
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
OAuthModelMappings map[string][]ModelNameMapping `yaml:"oauth-model-mappings,omitempty" json:"oauth-model-mappings,omitempty"`
// Payload defines default and override rules for provider payload parameters.
Payload PayloadConfig `yaml:"payload" json:"payload"`
@@ -146,6 +157,13 @@ type RoutingConfig struct {
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
}
// ModelNameMapping defines a model ID rename mapping for a specific channel.
// It maps the original model name (Name) to the client-visible alias (Alias).
type ModelNameMapping struct {
Name string `yaml:"name" json:"name"`
Alias string `yaml:"alias" json:"alias"`
}
// AmpModelMapping defines a model name mapping for Amp CLI requests.
// When Amp requests a model that isn't available locally, this mapping
// allows routing to an alternative model that IS available.
@@ -172,6 +190,11 @@ type AmpCode struct {
// UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls.
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
// When a client authenticates with a key that matches an entry, that upstream key is used.
// If no match is found, falls back to UpstreamAPIKey (default behavior).
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
@@ -187,6 +210,17 @@ type AmpCode struct {
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
}
// AmpUpstreamAPIKeyEntry maps a set of client API keys to a specific upstream API key.
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
// is used for the upstream Amp request.
type AmpUpstreamAPIKeyEntry struct {
// UpstreamAPIKey is the API key to use when proxying to the Amp upstream.
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// APIKeys are the client API keys (from top-level api-keys) that map to this upstream key.
APIKeys []string `yaml:"api-keys" json:"api-keys"`
}
// PayloadConfig defines default and override parameter rules applied to provider payloads.
type PayloadConfig struct {
// Default defines rules that only set parameters when they are missing in the payload.
@@ -246,6 +280,9 @@ type ClaudeModel struct {
Alias string `yaml:"alias" json:"alias"`
}
func (m ClaudeModel) GetName() string { return m.Name }
func (m ClaudeModel) GetAlias() string { return m.Alias }
// CodexKey represents the configuration for a Codex API key,
// including the API key itself and an optional base URL for the API endpoint.
type CodexKey struct {
@@ -262,6 +299,9 @@ type CodexKey struct {
// ProxyURL overrides the global proxy setting for this API key if provided.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
// Models defines upstream model names and aliases for request routing.
Models []CodexModel `yaml:"models" json:"models"`
// Headers optionally adds extra HTTP headers for requests sent with this key.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
@@ -269,6 +309,18 @@ type CodexKey struct {
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}
// CodexModel describes a mapping between an alias and the actual upstream model name.
type CodexModel struct {
// Name is the upstream model identifier used when issuing requests.
Name string `yaml:"name" json:"name"`
// Alias is the client-facing model name that maps to Name.
Alias string `yaml:"alias" json:"alias"`
}
func (m CodexModel) GetName() string { return m.Name }
func (m CodexModel) GetAlias() string { return m.Alias }
// GeminiKey represents the configuration for a Gemini API key,
// including optional overrides for upstream base URL, proxy routing, and headers.
type GeminiKey struct {
@@ -284,6 +336,9 @@ type GeminiKey struct {
// ProxyURL optionally overrides the global proxy for this API key.
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
// Models defines upstream model names and aliases for request routing.
Models []GeminiModel `yaml:"models,omitempty" json:"models,omitempty"`
// Headers optionally adds extra HTTP headers for requests sent with this key.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
@@ -291,6 +346,18 @@ type GeminiKey struct {
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}
// GeminiModel describes a mapping between an alias and the actual upstream model name.
type GeminiModel struct {
// Name is the upstream model identifier used when issuing requests.
Name string `yaml:"name" json:"name"`
// Alias is the client-facing model name that maps to Name.
Alias string `yaml:"alias" json:"alias"`
}
func (m GeminiModel) GetName() string { return m.Name }
func (m GeminiModel) GetAlias() string { return m.Alias }
// KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication.
type KiroKey struct {
// TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json)
@@ -475,6 +542,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Normalize OAuth provider model exclusion map.
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
// Normalize global OAuth model name mappings.
cfg.SanitizeOAuthModelMappings()
if cfg.legacyMigrationPending {
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
if !optional && configFile != "" {
@@ -491,6 +561,50 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
return &cfg, nil
}
// SanitizeOAuthModelMappings normalizes and deduplicates global OAuth model name mappings.
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
// and ensures (From, To) pairs are unique within each channel.
func (cfg *Config) SanitizeOAuthModelMappings() {
if cfg == nil || len(cfg.OAuthModelMappings) == 0 {
return
}
out := make(map[string][]ModelNameMapping, len(cfg.OAuthModelMappings))
for rawChannel, mappings := range cfg.OAuthModelMappings {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(mappings) == 0 {
continue
}
seenName := make(map[string]struct{}, len(mappings))
seenAlias := make(map[string]struct{}, len(mappings))
clean := make([]ModelNameMapping, 0, len(mappings))
for _, mapping := range mappings {
name := strings.TrimSpace(mapping.Name)
alias := strings.TrimSpace(mapping.Alias)
if name == "" || alias == "" {
continue
}
if strings.EqualFold(name, alias) {
continue
}
nameKey := strings.ToLower(name)
aliasKey := strings.ToLower(alias)
if _, ok := seenName[nameKey]; ok {
continue
}
if _, ok := seenAlias[aliasKey]; ok {
continue
}
seenName[nameKey] = struct{}{}
seenAlias[aliasKey] = struct{}{}
clean = append(clean, ModelNameMapping{Name: name, Alias: alias})
}
if len(clean) > 0 {
out[channel] = clean
}
}
cfg.OAuthModelMappings = out
}
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are
// not actionable, specifically those missing a BaseURL. It trims whitespace before
// evaluation and preserves the relative order of remaining entries.
@@ -876,8 +990,8 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
}
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
// key order and comments of existing keys in dst. Unknown keys from src are appended
// to dst at the end, copying their node structure from src.
// key order and comments of existing keys in dst. New keys are only added if their
// value is non-zero to avoid polluting the config with defaults.
func mergeMappingPreserve(dst, src *yaml.Node) {
if dst == nil || src == nil {
return
@@ -888,20 +1002,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
copyNodeShallow(dst, src)
return
}
// Build a lookup of existing keys in dst
for i := 0; i+1 < len(src.Content); i += 2 {
sk := src.Content[i]
sv := src.Content[i+1]
idx := findMapKeyIndex(dst, sk.Value)
if idx >= 0 {
// Merge into existing value node
// Merge into existing value node (always update, even to zero values)
dv := dst.Content[idx+1]
mergeNodePreserve(dv, sv)
} else {
if shouldSkipEmptyCollectionOnPersist(sk.Value, sv) {
// New key: only add if value is non-zero to avoid polluting config with defaults
if isZeroValueNode(sv) {
continue
}
// Append new key/value pair by deep-copying from src
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
}
}
@@ -984,32 +1097,49 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
return -1
}
func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool {
switch key {
case "generative-language-api-key",
"gemini-api-key",
"vertex-api-key",
"claude-api-key",
"codex-api-key",
"openai-compatibility":
return isEmptyCollectionNode(node)
default:
return false
}
}
func isEmptyCollectionNode(node *yaml.Node) bool {
// isZeroValueNode returns true if the YAML node represents a zero/default value
// that should not be written as a new key to preserve config cleanliness.
// For mappings and sequences, recursively checks if all children are zero values.
func isZeroValueNode(node *yaml.Node) bool {
if node == nil {
return true
}
switch node.Kind {
case yaml.SequenceNode:
return len(node.Content) == 0
case yaml.ScalarNode:
return node.Tag == "!!null"
default:
return false
switch node.Tag {
case "!!bool":
return node.Value == "false"
case "!!int", "!!float":
return node.Value == "0" || node.Value == "0.0"
case "!!str":
return node.Value == ""
case "!!null":
return true
}
case yaml.SequenceNode:
if len(node.Content) == 0 {
return true
}
// Check if all elements are zero values
for _, child := range node.Content {
if !isZeroValueNode(child) {
return false
}
}
return true
case yaml.MappingNode:
if len(node.Content) == 0 {
return true
}
// Check if all values are zero values (values are at odd indices)
for i := 1; i < len(node.Content); i += 2 {
if !isZeroValueNode(node.Content[i]) {
return false
}
}
return true
}
return false
}
// deepCopyNode creates a deep copy of a yaml.Node graph.

View File

@@ -30,13 +30,13 @@ type SDKConfig struct {
// StreamingConfig holds server streaming behavior configuration.
type StreamingConfig struct {
// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
// nil means default (15 seconds). <= 0 disables keep-alives.
KeepAliveSeconds *int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
// <= 0 disables keep-alives. Default is 0.
KeepAliveSeconds int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
// to allow auth rotation / transient recovery.
// nil means default (2). 0 disables bootstrap retries.
BootstrapRetries *int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
// <= 0 disables bootstrap retries. Default is 0.
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
}
// AccessConfig groups request authentication providers.

View File

@@ -42,6 +42,9 @@ type VertexCompatModel struct {
Alias string `yaml:"alias" json:"alias"`
}
func (m VertexCompatModel) GetName() string { return m.Name }
func (m VertexCompatModel) GetAlias() string { return m.Alias }
// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials.
func (cfg *Config) SanitizeVertexCompatKeys() {
if cfg == nil {

View File

@@ -73,17 +73,15 @@ func GinLogrusLogger() gin.HandlerFunc {
method := c.Request.Method
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
if requestID == "" {
requestID = "--------"
}
logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path)
if errorMessage != "" {
logLine = logLine + " | " + errorMessage
}
var entry *log.Entry
if requestID != "" {
entry = log.WithField("request_id", requestID)
} else {
entry = log.WithField("request_id", "--------")
}
entry := log.WithField("request_id", requestID)
switch {
case statusCode >= http.StatusInternalServerError:

View File

@@ -10,6 +10,7 @@ import (
"sync"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
@@ -40,25 +41,22 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
timestamp := entry.Time.Format("2006-01-02 15:04:05")
message := strings.TrimRight(entry.Message, "\r\n")
reqID := ""
reqID := "--------"
if id, ok := entry.Data["request_id"].(string); ok && id != "" {
reqID = id
}
callerFile := "unknown"
callerLine := 0
if entry.Caller != nil {
callerFile = filepath.Base(entry.Caller.File)
callerLine = entry.Caller.Line
level := entry.Level.String()
if level == "warning" {
level = "warn"
}
levelStr := fmt.Sprintf("%-5s", entry.Level.String())
levelStr := fmt.Sprintf("%-5s", level)
var formatted string
if reqID != "" {
formatted = fmt.Sprintf("[%s] [%s] [%s:%d] | %s | %s\n", timestamp, levelStr, callerFile, callerLine, reqID, message)
if entry.Caller != nil {
formatted = fmt.Sprintf("[%s] [%s] [%s] [%s:%d] %s\n", timestamp, reqID, levelStr, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
} else {
formatted = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, levelStr, callerFile, callerLine, message)
formatted = fmt.Sprintf("[%s] [%s] [%s] %s\n", timestamp, reqID, levelStr, message)
}
buffer.WriteString(formatted)
@@ -87,10 +85,30 @@ func SetupBaseLogger() {
})
}
// isDirWritable checks if the specified directory exists and is writable by attempting to create and remove a test file.
func isDirWritable(dir string) bool {
info, err := os.Stat(dir)
if err != nil || !info.IsDir() {
return false
}
testFile := filepath.Join(dir, ".perm_test")
f, err := os.Create(testFile)
if err != nil {
return false
}
defer func() {
_ = f.Close()
_ = os.Remove(testFile)
}()
return true
}
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
// until the total size is within the limit.
func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
func ConfigureLogOutput(cfg *config.Config) error {
SetupBaseLogger()
writerMu.Lock()
@@ -99,10 +117,12 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
logDir := "logs"
if base := util.WritablePath(); base != "" {
logDir = filepath.Join(base, "logs")
} else if !isDirWritable(logDir) {
logDir = filepath.Join(cfg.AuthDir, "logs")
}
protectedPath := ""
if loggingToFile {
if cfg.LoggingToFile {
if err := os.MkdirAll(logDir, 0o755); err != nil {
return fmt.Errorf("logging: failed to create log directory: %w", err)
}
@@ -126,7 +146,7 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
log.SetOutput(os.Stdout)
}
configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath)
configureLogDirCleanerLocked(logDir, cfg.LogsMaxTotalSizeMB, protectedPath)
return nil
}

View File

@@ -24,10 +24,11 @@ import (
)
const (
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
managementAssetName = "management.html"
httpUserAgent = "CLIProxyAPI-management-updater"
updateCheckInterval = 3 * time.Hour
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
managementAssetName = "management.html"
httpUserAgent = "CLIProxyAPI-management-updater"
updateCheckInterval = 3 * time.Hour
)
// ManagementFileName exposes the control panel asset filename.
@@ -198,6 +199,16 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
return
}
localPath := filepath.Join(staticDir, managementAssetName)
localFileMissing := false
if _, errStat := os.Stat(localPath); errStat != nil {
if errors.Is(errStat, os.ErrNotExist) {
localFileMissing = true
} else {
log.WithError(errStat).Debug("failed to stat local management asset")
}
}
// Rate limiting: check only once every 3 hours
lastUpdateCheckMu.Lock()
now := time.Now()
@@ -210,15 +221,14 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
lastUpdateCheckTime = now
lastUpdateCheckMu.Unlock()
if err := os.MkdirAll(staticDir, 0o755); err != nil {
log.WithError(err).Warn("failed to prepare static directory for management asset")
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
return
}
releaseURL := resolveReleaseURL(panelRepository)
client := newHTTPClient(proxyURL)
localPath := filepath.Join(staticDir, managementAssetName)
localHash, err := fileSHA256(localPath)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
@@ -229,6 +239,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return
}
return
}
log.WithError(err).Warn("failed to fetch latest management release information")
return
}
@@ -240,6 +257,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to download management asset, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return
}
return
}
log.WithError(err).Warn("failed to download management asset")
return
}
@@ -256,6 +280,22 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
}
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
data, downloadedHash, err := downloadAsset(ctx, client, defaultManagementFallbackURL)
if err != nil {
log.WithError(err).Warn("failed to download fallback management control panel page")
return false
}
if err = atomicWriteFile(localPath, data); err != nil {
log.WithError(err).Warn("failed to persist fallback management control panel page")
return false
}
log.Infof("management asset updated from fallback page successfully (hash=%s)", downloadedHash)
return true
}
func resolveReleaseURL(repo string) string {
repo = strings.TrimSpace(repo)
if repo == "" {

View File

@@ -727,6 +727,7 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
@@ -739,7 +740,8 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
}
models := make([]*ModelInfo, 0, len(entries))
for _, entry := range entries {
@@ -771,7 +773,7 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
return map[string]*AntigravityModelConfig{
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
"gemini-2.5-computer-use-preview-10-2025": {Name: "models/gemini-2.5-computer-use-preview-10-2025"},
"gemini-2.5-computer-use-preview-10-2025": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-2.5-computer-use-preview-10-2025"},
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-preview"},
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-image-preview"},
"gemini-3-flash-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, Name: "models/gemini-3-flash-preview"},
@@ -780,6 +782,32 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
}
}
// LookupStaticModelInfo searches all static model definitions for a model by ID.
// Returns nil if no matching model is found.
func LookupStaticModelInfo(modelID string) *ModelInfo {
if modelID == "" {
return nil
}
allModels := [][]*ModelInfo{
GetClaudeModels(),
GetGeminiModels(),
GetGeminiVertexModels(),
GetGeminiCLIModels(),
GetAIStudioModels(),
GetOpenAIModels(),
GetQwenModels(),
GetIFlowModels(),
}
for _, models := range allModels {
for _, m := range models {
if m != nil && m.ID == modelID {
return m
}
}
}
return nil
}
// GetGitHubCopilotModels returns the available models for GitHub Copilot.
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
func GetGitHubCopilotModels() []*ModelInfo {

View File

@@ -625,6 +625,131 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
return models
}
// GetAvailableModelsByProvider returns models available for the given provider identifier.
// Parameters:
// - provider: Provider identifier (e.g., "codex", "gemini", "antigravity")
//
// Returns:
// - []*ModelInfo: List of available models for the provider
func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelInfo {
provider = strings.ToLower(strings.TrimSpace(provider))
if provider == "" {
return nil
}
r.mutex.RLock()
defer r.mutex.RUnlock()
type providerModel struct {
count int
info *ModelInfo
}
providerModels := make(map[string]*providerModel)
for clientID, clientProvider := range r.clientProviders {
if clientProvider != provider {
continue
}
modelIDs := r.clientModels[clientID]
if len(modelIDs) == 0 {
continue
}
clientInfos := r.clientModelInfos[clientID]
for _, modelID := range modelIDs {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
continue
}
entry := providerModels[modelID]
if entry == nil {
entry = &providerModel{}
providerModels[modelID] = entry
}
entry.count++
if entry.info == nil {
if clientInfos != nil {
if info := clientInfos[modelID]; info != nil {
entry.info = info
}
}
if entry.info == nil {
if reg, ok := r.models[modelID]; ok && reg != nil && reg.Info != nil {
entry.info = reg.Info
}
}
}
}
}
if len(providerModels) == 0 {
return nil
}
quotaExpiredDuration := 5 * time.Minute
now := time.Now()
result := make([]*ModelInfo, 0, len(providerModels))
for modelID, entry := range providerModels {
if entry == nil || entry.count <= 0 {
continue
}
registration, ok := r.models[modelID]
expiredClients := 0
cooldownSuspended := 0
otherSuspended := 0
if ok && registration != nil {
if registration.QuotaExceededClients != nil {
for clientID, quotaTime := range registration.QuotaExceededClients {
if clientID == "" {
continue
}
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
continue
}
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
expiredClients++
}
}
}
if registration.SuspendedClients != nil {
for clientID, reason := range registration.SuspendedClients {
if clientID == "" {
continue
}
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
continue
}
if strings.EqualFold(reason, "quota") {
cooldownSuspended++
continue
}
otherSuspended++
}
}
}
availableClients := entry.count
effectiveClients := availableClients - expiredClients - otherSuspended
if effectiveClients < 0 {
effectiveClients = 0
}
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
if entry.info != nil {
result = append(result, entry.info)
continue
}
if ok && registration != nil && registration.Info != nil {
result = append(result, registration.Info)
}
}
}
return result
}
// GetModelCount returns the number of available clients for a specific model
// Parameters:
// - modelID: The model ID to check

View File

@@ -59,6 +59,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
if err != nil {
return resp, err
}
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
@@ -113,6 +114,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
if err != nil {
return nil, err
}
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
@@ -321,6 +323,11 @@ type translatedPayload struct {
func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, stream)
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload)
@@ -329,7 +336,7 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload, true)
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiImageAspectRatio(req.Model, payload)
payload = applyPayloadConfig(e.cfg, req.Model, payload)
payload = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", payload, originalTranslated)
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")

View File

@@ -76,7 +76,8 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au
// Execute performs a non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if strings.Contains(req.Model, "claude") {
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
if isClaude {
return e.executeClaudeNonStream(ctx, auth, req, opts)
}
@@ -93,13 +94,18 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -188,13 +194,18 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, true)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -520,15 +531,22 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -676,6 +694,8 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
to := sdktranslator.FromString("antigravity")
respCtx := context.WithValue(ctx, "alt", opts.Alt)
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -692,9 +712,9 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
for idx, baseURL := range baseURLs {
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
payload = normalizeAntigravityThinking(req.Model, payload)
payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, payload)
payload = normalizeAntigravityThinking(req.Model, payload, isClaude)
payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings")
@@ -1308,7 +1328,7 @@ func alias2ModelName(modelName string) string {
// normalizeAntigravityThinking clamps or removes thinking config based on model support.
// For Claude models, it additionally ensures thinking budget < max_tokens.
func normalizeAntigravityThinking(model string, payload []byte) []byte {
func normalizeAntigravityThinking(model string, payload []byte, isClaude bool) []byte {
payload = util.StripThinkingConfigIfUnsupported(model, payload)
if !util.ModelSupportsThinking(model) {
return payload
@@ -1320,7 +1340,6 @@ func normalizeAntigravityThinking(model string, payload []byte) []byte {
raw := int(budget.Int())
normalized := util.NormalizeThinkingBudget(model, raw)
isClaude := strings.Contains(strings.ToLower(model), "claude")
if isClaude {
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
if effectiveMax > 0 && normalized >= effectiveMax {

View File

@@ -49,33 +49,34 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude.
stream := from != to
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, stream)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
body, _ = sjson.SetBytes(body, "model", model)
// Inject thinking config based on model metadata for thinking variants
body = e.injectThinkingConfig(req.Model, req.Metadata, body)
body = e.injectThinkingConfig(model, req.Metadata, body)
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") {
if !strings.HasPrefix(model, "claude-3-5-haiku") {
body = checkSystemInstructions(body)
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(req.Model, body)
body = ensureMaxTokensForThinking(model, body)
// Extract betas from body and convert to header
var extraBetas []string
@@ -167,26 +168,27 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body, _ = sjson.SetBytes(body, "model", model)
// Inject thinking config based on model metadata for thinking variants
body = e.injectThinkingConfig(req.Model, req.Metadata, body)
body = e.injectThinkingConfig(model, req.Metadata, body)
body = checkSystemInstructions(body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(req.Model, body)
body = ensureMaxTokensForThinking(model, body)
// Extract betas from body and convert to header
var extraBetas []string
@@ -310,21 +312,14 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude.
stream := from != to
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
body, _ = sjson.SetBytes(body, "model", model)
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") {
if !strings.HasPrefix(model, "claude-3-5-haiku") {
body = checkSystemInstructions(body)
}
@@ -461,6 +456,19 @@ func (e *ClaudeExecutor) injectThinkingConfig(modelName string, metadata map[str
return util.ApplyClaudeThinkingConfig(body, budget)
}
// disableThinkingIfToolChoiceForced checks if tool_choice forces tool use and disables thinking.
// Anthropic API does not allow thinking when tool_choice is set to "any" or a specific tool.
// See: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations
func disableThinkingIfToolChoiceForced(body []byte) []byte {
toolChoiceType := gjson.GetBytes(body, "tool_choice.type").String()
// "auto" is allowed with thinking, but "any" or "tool" (specific tool) are not
if toolChoiceType == "any" || toolChoiceType == "tool" {
// Remove thinking configuration entirely to avoid API error
body, _ = sjson.DeleteBytes(body, "thinking")
}
return body
}
// ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled.
// Anthropic API requires this constraint; violating it returns a 400 error.
// This function should be called after all thinking configuration is finalized.

View File

@@ -49,18 +49,26 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, model, false)
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
return resp, errValidate
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
@@ -146,20 +154,28 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, model, false)
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
return nil, errValidate
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body, _ = sjson.SetBytes(body, "model", model)
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -246,20 +262,21 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
modelForCounting := req.Model
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body, _ = sjson.SetBytes(body, "model", model)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.SetBytes(body, "stream", false)
enc, err := tokenizerForCodexModel(modelForCounting)
enc, err := tokenizerForCodexModel(model)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
}
@@ -520,3 +537,87 @@ func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
}
return
}
func (e *CodexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveCodexConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
func (e *CodexExecutor) resolveCodexConfig(auth *cliproxyauth.Auth) *config.CodexKey {
if auth == nil || e.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range e.cfg.CodexKey {
entry := &e.cfg.CodexKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range e.cfg.CodexKey {
entry := &e.cfg.CodexKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}

View File

@@ -77,14 +77,19 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload, originalTranslated)
action := "generateContent"
if req.Metadata != nil {
@@ -216,14 +221,19 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload, originalTranslated)
projectID := resolveGeminiProjectID(auth)
@@ -318,7 +328,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(resp *http.Response, reqBody []byte, attempt string) {
go func(resp *http.Response, reqBody []byte, attemptModel string) {
defer close(out)
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
@@ -336,14 +346,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
reporter.publish(ctx, detail)
}
if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), &param)
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
}
}
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
@@ -365,12 +375,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiCLIUsage(data))
var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, data, &param)
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, data, &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
segments = sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
@@ -417,15 +427,17 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
var lastStatus int
var lastBody []byte
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
// Gemini CLI endpoint when iterating fallback variants.
for _, attemptModel := range models {
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload)
payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings")
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiCLIImageAspectRatio(attemptModel, payload)
payload = fixGeminiCLIImageAspectRatio(req.Model, payload)
tok, errTok := tokenSource.Token()
if errTok != nil {

View File

@@ -77,19 +77,27 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
// Official Gemini API via API key or OAuth bearer
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
body = ApplyThinkingMetadata(body, req.Metadata, model)
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
action := "generateContent"
if req.Metadata != nil {
@@ -98,7 +106,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
}
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, action)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
@@ -173,21 +181,29 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body = ApplyThinkingMetadata(body, req.Metadata, model)
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "streamGenerateContent")
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
@@ -287,19 +303,25 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
apiKey, bearer := geminiCreds(auth)
model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model)
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model)
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, req.Model, "countTokens")
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens")
requestBody := bytes.NewReader(translatedReq)
@@ -398,6 +420,90 @@ func resolveGeminiBaseURL(auth *cliproxyauth.Auth) string {
return base
}
func (e *GeminiExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveGeminiConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
func (e *GeminiExecutor) resolveGeminiConfig(auth *cliproxyauth.Auth) *config.GeminiKey {
if auth == nil || e.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range e.cfg.GeminiKey {
entry := &e.cfg.GeminiKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range e.cfg.GeminiKey {
entry := &e.cfg.GeminiKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}
func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) {
var attrs map[string]string
if auth != nil {

View File

@@ -120,10 +120,13 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
@@ -136,8 +139,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", req.Model)
action := "generateContent"
if req.Metadata != nil {
@@ -146,7 +149,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
}
}
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, action)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
@@ -220,24 +223,32 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
action := "generateContent"
if req.Metadata != nil {
@@ -250,7 +261,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, action)
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
@@ -321,10 +332,13 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
@@ -337,11 +351,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", req.Model)
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "streamGenerateContent")
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
@@ -438,30 +452,38 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "streamGenerateContent")
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
@@ -552,8 +574,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
// countTokensWithServiceAccount counts tokens using service account credentials.
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
@@ -566,14 +586,14 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens")
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
@@ -641,21 +661,24 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
// countTokensWithAPIKey handles token counting using API key credentials.
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm
}
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
@@ -665,7 +688,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens")
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
@@ -808,3 +831,90 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau
}
return tok.AccessToken, nil
}
// resolveUpstreamModel resolves the upstream model name from vertex-api-key configuration.
// It matches the requested model alias against configured models and returns the actual upstream name.
func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveVertexConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth.
func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey {
if auth == nil || e.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range e.cfg.VertexCompatAPIKey {
entry := &e.cfg.VertexCompatAPIKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range e.cfg.VertexCompatAPIKey {
entry := &e.cfg.VertexCompatAPIKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}

View File

@@ -79,9 +79,14 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = e.normalizeModel(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "stream", false)
url := githubCopilotBaseURL + githubCopilotChatPath
@@ -162,9 +167,14 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = e.normalizeModel(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "stream", true)
// Enable stream options for usage stats in stream
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)

View File

@@ -56,18 +56,21 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" {
body, _ = sjson.SetBytes(body, "model", upstreamModel)
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return resp, errValidate
}
body = applyIFlowThinkingConfig(body)
body = applyPayloadConfig(e.cfg, req.Model, body)
body = preserveReasoningContentInMessages(body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -147,24 +150,27 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" {
body, _ = sjson.SetBytes(body, "model", upstreamModel)
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return nil, errValidate
}
body = applyIFlowThinkingConfig(body)
body = preserveReasoningContentInMessages(body)
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
toolsResult := gjson.GetBytes(body, "tools")
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
body = ensureToolsArray(body)
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -445,20 +451,85 @@ func ensureToolsArray(body []byte) []byte {
return updated
}
// applyIFlowThinkingConfig converts normalized reasoning_effort to iFlow chat_template_kwargs.enable_thinking.
// preserveReasoningContentInMessages checks if reasoning_content from assistant messages
// is preserved in conversation history for iFlow models that support thinking.
// This is helpful for multi-turn conversations where the model may benefit from seeing
// its previous reasoning to maintain coherent thought chains.
//
// For GLM-4.6/4.7 and MiniMax M2/M2.1, it is recommended to include the full assistant
// response (including reasoning_content) in message history for better context continuity.
func preserveReasoningContentInMessages(body []byte) []byte {
model := strings.ToLower(gjson.GetBytes(body, "model").String())
// Only apply to models that support thinking with history preservation
needsPreservation := strings.HasPrefix(model, "glm-4") || strings.HasPrefix(model, "minimax-m2")
if !needsPreservation {
return body
}
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
// Check if any assistant message already has reasoning_content preserved
hasReasoningContent := false
messages.ForEach(func(_, msg gjson.Result) bool {
role := msg.Get("role").String()
if role == "assistant" {
rc := msg.Get("reasoning_content")
if rc.Exists() && rc.String() != "" {
hasReasoningContent = true
return false // stop iteration
}
}
return true
})
// If reasoning content is already present, the messages are properly formatted
// No need to modify - the client has correctly preserved reasoning in history
if hasReasoningContent {
log.Debugf("iflow executor: reasoning_content found in message history for %s", model)
}
return body
}
// applyIFlowThinkingConfig converts normalized reasoning_effort to model-specific thinking configurations.
// This should be called after NormalizeThinkingConfig has processed the payload.
// iFlow only supports boolean enable_thinking, so any non-"none" effort enables thinking.
//
// Model-specific handling:
// - GLM-4.6/4.7: Uses chat_template_kwargs.enable_thinking (boolean) and chat_template_kwargs.clear_thinking=false
// - MiniMax M2/M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
func applyIFlowThinkingConfig(body []byte) []byte {
effort := gjson.GetBytes(body, "reasoning_effort")
if !effort.Exists() {
return body
}
model := strings.ToLower(gjson.GetBytes(body, "model").String())
val := strings.ToLower(strings.TrimSpace(effort.String()))
enableThinking := val != "none" && val != ""
// Remove reasoning_effort as we'll convert to model-specific format
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
body, _ = sjson.DeleteBytes(body, "thinking")
// GLM-4.6/4.7: Use chat_template_kwargs
if strings.HasPrefix(model, "glm-4") {
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
if enableThinking {
body, _ = sjson.SetBytes(body, "chat_template_kwargs.clear_thinking", false)
}
return body
}
// MiniMax M2/M2.1: Use reasoning_split
if strings.HasPrefix(model, "minimax-m2") {
body, _ = sjson.SetBytes(body, "reasoning_split", enableThinking)
return body
}
return body
}

View File

@@ -53,20 +53,21 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
// Translate inbound request to OpenAI format
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream)
modelOverride := e.resolveUpstreamModel(req.Model, auth)
if modelOverride != "" {
translated = e.overrideModel(translated, modelOverride)
}
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated, originalTranslated)
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" && modelOverride == "" {
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
}
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
return resp, errValidate
}
@@ -149,20 +150,21 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
modelOverride := e.resolveUpstreamModel(req.Model, auth)
if modelOverride != "" {
translated = e.overrideModel(translated, modelOverride)
}
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated, originalTranslated)
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" && modelOverride == "" {
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
}
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
return nil, errValidate
}

View File

@@ -14,32 +14,54 @@ import (
// ApplyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
// for standard Gemini format payloads. It normalizes the budget when the model supports thinking.
func ApplyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, metadata)
// Use the alias from metadata if available, as it's registered in the global registry
// with thinking metadata; the upstream model name may not be registered.
lookupModel := util.ResolveOriginalModel(model, metadata)
// Determine which model to use for thinking support check.
// If the alias (lookupModel) is not in the registry, fall back to the upstream model.
thinkingModel := lookupModel
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
thinkingModel = model
}
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
if !ok || (budgetOverride == nil && includeOverride == nil) {
return payload
}
if !util.ModelSupportsThinking(model) {
if !util.ModelSupportsThinking(thinkingModel) {
return payload
}
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
budgetOverride = &norm
}
return util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
}
// applyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., (high), (8192))
// ApplyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., (high), (8192))
// for Gemini CLI format payloads (nested under "request"). It normalizes the budget when the model supports thinking.
func applyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, metadata)
func ApplyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
// Use the alias from metadata if available, as it's registered in the global registry
// with thinking metadata; the upstream model name may not be registered.
lookupModel := util.ResolveOriginalModel(model, metadata)
// Determine which model to use for thinking support check.
// If the alias (lookupModel) is not in the registry, fall back to the upstream model.
thinkingModel := lookupModel
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
thinkingModel = model
}
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
if !ok || (budgetOverride == nil && includeOverride == nil) {
return payload
}
if !util.ModelSupportsThinking(model) {
if !util.ModelSupportsThinking(thinkingModel) {
return payload
}
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
budgetOverride = &norm
}
return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
@@ -82,17 +104,11 @@ func ApplyReasoningEffortMetadata(payload []byte, metadata map[string]any, model
return payload
}
// applyPayloadConfig applies payload default and override rules from configuration
// to the given JSON payload for the specified model.
// Defaults only fill missing fields, while overrides always overwrite existing values.
func applyPayloadConfig(cfg *config.Config, model string, payload []byte) []byte {
return applyPayloadConfigWithRoot(cfg, model, "", "", payload)
}
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
// and restricts matches to the given protocol when supplied.
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload []byte) []byte {
// and restricts matches to the given protocol when supplied. Defaults are checked
// against the original payload when provided.
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte) []byte {
if cfg == nil || len(payload) == 0 {
return payload
}
@@ -105,6 +121,11 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
return payload
}
out := payload
source := original
if len(source) == 0 {
source = payload
}
appliedDefaults := make(map[string]struct{})
// Apply default rules: first write wins per field across all matching rules.
for i := range rules.Default {
rule := &rules.Default[i]
@@ -116,7 +137,10 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
if fullPath == "" {
continue
}
if gjson.GetBytes(out, fullPath).Exists() {
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
@@ -124,6 +148,7 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
// Apply override rules: last write wins per field across all matching rules.

View File

@@ -12,7 +12,6 @@ import (
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -50,17 +49,19 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" {
body, _ = sjson.SetBytes(body, "model", upstreamModel)
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return resp, errValidate
}
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -129,15 +130,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel != "" {
body, _ = sjson.SetBytes(body, "model", upstreamModel)
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return nil, errValidate
}
toolsResult := gjson.GetBytes(body, "tools")
@@ -147,7 +150,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
}
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
body = applyPayloadConfig(e.cfg, req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))

View File

@@ -19,7 +19,7 @@ type usageReporter struct {
provider string
model string
authID string
authIndex uint64
authIndex string
apiKey string
source string
requestedAt time.Time
@@ -275,6 +275,20 @@ func parseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
return detail, true
}
func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
CachedTokens: node.Get("cachedContentTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
}
func parseGeminiCLIUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data)
node := usageNode.Get("response.usageMetadata")
@@ -284,16 +298,7 @@ func parseGeminiCLIUsage(data []byte) usage.Detail {
if !node.Exists() {
return usage.Detail{}
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
return parseGeminiFamilyUsageDetail(node)
}
func parseGeminiUsage(data []byte) usage.Detail {
@@ -305,16 +310,7 @@ func parseGeminiUsage(data []byte) usage.Detail {
if !node.Exists() {
return usage.Detail{}
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
return parseGeminiFamilyUsageDetail(node)
}
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
@@ -329,16 +325,7 @@ func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
if !node.Exists() {
return usage.Detail{}, false
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail, true
return parseGeminiFamilyUsageDetail(node), true
}
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
@@ -353,16 +340,7 @@ func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
if !node.Exists() {
return usage.Detail{}, false
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail, true
return parseGeminiFamilyUsageDetail(node), true
}
func parseAntigravityUsage(data []byte) usage.Detail {
@@ -377,16 +355,7 @@ func parseAntigravityUsage(data []byte) usage.Detail {
if !node.Exists() {
return usage.Detail{}
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
return parseGeminiFamilyUsageDetail(node)
}
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
@@ -404,16 +373,7 @@ func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
if !node.Exists() {
return usage.Detail{}, false
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail, true
return parseGeminiFamilyUsageDetail(node), true
}
var stopChunkWithoutUsage sync.Map
@@ -522,12 +482,16 @@ func StripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
cleaned := jsonBytes
var changed bool
if gjson.GetBytes(cleaned, "usageMetadata").Exists() {
if usageMetadata = gjson.GetBytes(cleaned, "usageMetadata"); usageMetadata.Exists() {
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
cleaned, _ = sjson.SetRawBytes(cleaned, "cpaUsageMetadata", []byte(usageMetadata.Raw))
cleaned, _ = sjson.DeleteBytes(cleaned, "usageMetadata")
changed = true
}
if gjson.GetBytes(cleaned, "response.usageMetadata").Exists() {
if usageMetadata = gjson.GetBytes(cleaned, "response.usageMetadata"); usageMetadata.Exists() {
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
cleaned, _ = sjson.SetRawBytes(cleaned, "response.cpaUsageMetadata", []byte(usageMetadata.Raw))
cleaned, _ = sjson.DeleteBytes(cleaned, "response.usageMetadata")
changed = true
}

View File

@@ -99,6 +99,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// This follows the Claude Code API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
// Use cpaUsageMetadata within the message_start event for Claude.
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
}
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
}
// Override default values with actual response metadata if available from the Gemini CLI response
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
@@ -271,11 +279,11 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
params.HasUsageMetadata = true
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int()
params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int()
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int() - params.CachedTokenCount
params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int()
params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int()
params.TotalTokenCount = usageResult.Get("totalTokenCount").Int()
params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int()
if params.CandidatesTokenCount == 0 && params.TotalTokenCount > 0 {
params.CandidatesTokenCount = params.TotalTokenCount - params.PromptTokenCount - params.ThoughtsTokenCount
if params.CandidatesTokenCount < 0 {

View File

@@ -247,10 +247,30 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else if role == "assistant" {
node := []byte(`{"role":"model","parts":[]}`)
p := 0
if content.Type == gjson.String {
if content.Type == gjson.String && content.String() != "" {
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
p++
} else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.
imageURL := item.Get("image_url.url").String()
if len(imageURL) > 5 { // expect data:...
pieces := strings.SplitN(imageURL[5:], ";", 2)
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
p++
}
}
}
}
}
// Tool calls -> single model content with functionCall parts
@@ -305,6 +325,8 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if pp > 0 {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
}
} else {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
}
}
}

View File

@@ -87,15 +87,15 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
@@ -181,12 +181,14 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
}
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
}

View File

@@ -209,9 +209,12 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
if usage := root.Get("usage"); usage.Exists() {
inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens)
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
}
return []string{template}
@@ -285,8 +288,6 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
var messageID string
var model string
var createdAt int64
var inputTokens, outputTokens int64
var reasoningTokens int64
var stopReason string
var contentParts []string
var reasoningParts []string
@@ -303,9 +304,6 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
messageID = message.Get("id").String()
model = message.Get("model").String()
createdAt = time.Now().Unix()
if usage := message.Get("usage"); usage.Exists() {
inputTokens = usage.Get("input_tokens").Int()
}
}
case "content_block_start":
@@ -368,11 +366,14 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
}
}
if usage := root.Get("usage"); usage.Exists() {
outputTokens = usage.Get("output_tokens").Int()
// Estimate reasoning tokens from accumulated thinking content
if len(reasoningParts) > 0 {
reasoningTokens = int64(len(strings.Join(reasoningParts, "")) / 4) // Rough estimation
}
inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int()
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", inputTokens+outputTokens)
out, _ = sjson.Set(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
}
}
}
@@ -431,16 +432,5 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
}
// Set usage information including prompt tokens, completion tokens, and total tokens
totalTokens := inputTokens + outputTokens
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", totalTokens)
// Add reasoning tokens to usage details if any reasoning content was processed
if reasoningTokens > 0 {
out, _ = sjson.Set(out, "usage.completion_tokens_details.reasoning_tokens", reasoningTokens)
}
return out
}

View File

@@ -114,13 +114,16 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
var builder strings.Builder
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
text := part.Get("text").String()
textResult := part.Get("text")
text := textResult.String()
if builder.Len() > 0 && text != "" {
builder.WriteByte('\n')
}
builder.WriteString(text)
return true
})
} else if parts.Type == gjson.String {
builder.WriteString(parts.String())
}
instructionsText = builder.String()
if instructionsText != "" {
@@ -207,6 +210,8 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
}
return true
})
} else if parts.Type == gjson.String {
textAggregate.WriteString(parts.String())
}
// Fallback to given role if content types not decisive

View File

@@ -218,8 +218,29 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if content.Type == gjson.String {
// Assistant text -> single model content
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
p++
} else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.
imageURL := item.Get("image_url.url").String()
if len(imageURL) > 5 { // expect data:...
pieces := strings.SplitN(imageURL[5:], ";", 2)
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
p++
}
}
}
}
}
// Tool calls -> single model content with functionCall parts
@@ -260,6 +281,8 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if pp > 0 {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
}
} else {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
}
}
}

View File

@@ -170,12 +170,14 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
}
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
}

View File

@@ -56,7 +56,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction)
}
} else if systemResult.Type == gjson.String {
out, _ = sjson.Set(out, "request.system_instruction.parts.-1.text", systemResult.String())
out, _ = sjson.Set(out, "system_instruction.parts.-1.text", systemResult.String())
}
// contents

View File

@@ -233,18 +233,15 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
} else if role == "assistant" {
node := []byte(`{"role":"model","parts":[]}`)
p := 0
if content.Type == gjson.String {
// Assistant text -> single model content
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
p++
} else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.
@@ -261,7 +258,6 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
}
}
}
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
}
// Tool calls -> single model content with functionCall parts
@@ -302,6 +298,8 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
if pp > 0 {
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
}
} else {
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
}
}
}

View File

@@ -89,15 +89,15 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
@@ -182,12 +182,14 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
}
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
}
@@ -316,12 +318,14 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
imagesResult := gjson.Get(template, "choices.0.message.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.message.images", `[]`)
}
imageIndex := len(gjson.Get(template, "choices.0.message.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", imagePayload)
}

View File

@@ -23,6 +23,7 @@ type geminiToResponsesState struct {
MsgIndex int
CurrentMsgID string
TextBuf strings.Builder
ItemTextBuf strings.Builder
// reasoning aggregation
ReasoningOpened bool
@@ -189,6 +190,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
partAdded, _ = sjson.Set(partAdded, "item_id", st.CurrentMsgID)
partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex)
out = append(out, emitEvent("response.content_part.added", partAdded))
st.ItemTextBuf.Reset()
st.ItemTextBuf.WriteString(t.String())
}
st.TextBuf.WriteString(t.String())
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
@@ -250,20 +253,24 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
finalizeReasoning()
// Close message output if opened
if st.MsgOpened {
fullText := st.ItemTextBuf.String()
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
done, _ = sjson.Set(done, "sequence_number", nextSeq())
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
done, _ = sjson.Set(done, "output_index", st.MsgIndex)
done, _ = sjson.Set(done, "text", fullText)
out = append(out, emitEvent("response.output_text.done", done))
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
partDone, _ = sjson.Set(partDone, "part.text", fullText)
out = append(out, emitEvent("response.content_part.done", partDone))
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
final, _ = sjson.Set(final, "sequence_number", nextSeq())
final, _ = sjson.Set(final, "output_index", st.MsgIndex)
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
final, _ = sjson.Set(final, "item.content.0.text", fullText)
out = append(out, emitEvent("response.output_item.done", final))
}

View File

@@ -10,6 +10,7 @@ import (
// MergeAdjacentMessages merges adjacent messages with the same role.
// This reduces API call complexity and improves compatibility.
// Based on AIClient-2-API implementation.
// NOTE: Tool messages are NOT merged because each has a unique tool_call_id that must be preserved.
func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
if len(messages) <= 1 {
return messages
@@ -26,6 +27,12 @@ func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
currentRole := msg.Get("role").String()
lastRole := lastMsg.Get("role").String()
// Don't merge tool messages - each has a unique tool_call_id
if currentRole == "tool" || lastRole == "tool" {
merged = append(merged, msg)
continue
}
if currentRole == lastRole {
// Merge content from current message into last message
mergedContent := mergeMessageContent(lastMsg, msg)

View File

@@ -450,24 +450,10 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
// Merge adjacent messages with the same role
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
// Build tool_call_id to name mapping from assistant messages
toolCallIDToName := make(map[string]string)
for _, msg := range messagesArray {
if msg.Get("role").String() == "assistant" {
toolCalls := msg.Get("tool_calls")
if toolCalls.IsArray() {
for _, tc := range toolCalls.Array() {
if tc.Get("type").String() == "function" {
id := tc.Get("id").String()
name := tc.Get("function.name").String()
if id != "" && name != "" {
toolCallIDToName[id] = name
}
}
}
}
}
}
// Track pending tool results that should be attached to the next user message
// This is critical for LiteLLM-translated requests where tool results appear
// as separate "tool" role messages between assistant and user messages
var pendingToolResults []KiroToolResult
for i, msg := range messagesArray {
role := msg.Get("role").String()
@@ -480,6 +466,10 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
case "user":
userMsg, toolResults := buildUserMessageFromOpenAI(msg, modelID, origin)
// Merge any pending tool results from preceding "tool" role messages
toolResults = append(pendingToolResults, toolResults...)
pendingToolResults = nil // Reset pending tool results
if isLastMessage {
currentUserMsg = &userMsg
currentToolResults = toolResults
@@ -505,6 +495,24 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
case "assistant":
assistantMsg := buildAssistantMessageFromOpenAI(msg)
// If there are pending tool results, we need to insert a synthetic user message
// before this assistant message to maintain proper conversation structure
if len(pendingToolResults) > 0 {
syntheticUserMsg := KiroUserInputMessage{
Content: "Tool results provided.",
ModelID: modelID,
Origin: origin,
UserInputMessageContext: &KiroUserInputMessageContext{
ToolResults: pendingToolResults,
},
}
history = append(history, KiroHistoryMessage{
UserInputMessage: &syntheticUserMsg,
})
pendingToolResults = nil
}
if isLastMessage {
history = append(history, KiroHistoryMessage{
AssistantResponseMessage: &assistantMsg,
@@ -524,7 +532,7 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
case "tool":
// Tool messages in OpenAI format provide results for tool_calls
// These are typically followed by user or assistant messages
// Process them and merge into the next user message's tool results
// Collect them as pending and attach to the next user message
toolCallID := msg.Get("tool_call_id").String()
content := msg.Get("content").String()
@@ -534,9 +542,21 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
Content: []KiroTextContent{{Text: content}},
Status: "success",
}
// Tool results should be included in the next user message
// For now, collect them and they'll be handled when we build the current message
currentToolResults = append(currentToolResults, toolResult)
// Collect pending tool results to attach to the next user message
pendingToolResults = append(pendingToolResults, toolResult)
}
}
}
// Handle case where tool results are at the end with no following user message
if len(pendingToolResults) > 0 {
currentToolResults = append(currentToolResults, pendingToolResults...)
// If there's no current user message, create a synthetic one for the tool results
if currentUserMsg == nil {
currentUserMsg = &KiroUserInputMessage{
Content: "Tool results provided.",
ModelID: modelID,
Origin: origin,
}
}
}
@@ -551,9 +571,6 @@ func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroU
var toolResults []KiroToolResult
var images []KiroImage
// Track seen toolCallIds to deduplicate
seenToolCallIDs := make(map[string]bool)
if content.IsArray() {
for _, part := range content.Array() {
partType := part.Get("type").String()
@@ -589,9 +606,6 @@ func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroU
contentBuilder.WriteString(content.String())
}
// Check for tool_calls in the message (shouldn't be in user messages, but handle edge cases)
_ = seenToolCallIDs // Used for deduplication if needed
userMsg := KiroUserInputMessage{
Content: contentBuilder.String(),
ModelID: modelID,

View File

@@ -0,0 +1,386 @@
package openai
import (
"encoding/json"
"testing"
)
// TestToolResultsAttachedToCurrentMessage verifies that tool results from "tool" role messages
// are properly attached to the current user message (the last message in the conversation).
// This is critical for LiteLLM-translated requests where tool results appear as separate messages.
func TestToolResultsAttachedToCurrentMessage(t *testing.T) {
// OpenAI format request simulating LiteLLM's translation from Anthropic format
// Sequence: user -> assistant (with tool_calls) -> tool (result) -> user
// The last user message should have the tool results attached
input := []byte(`{
"model": "kiro-claude-opus-4-5-agentic",
"messages": [
{"role": "user", "content": "Hello, can you read a file for me?"},
{
"role": "assistant",
"content": "I'll read that file for you.",
"tool_calls": [
{
"id": "call_abc123",
"type": "function",
"function": {
"name": "Read",
"arguments": "{\"file_path\": \"/tmp/test.txt\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_abc123",
"content": "File contents: Hello World!"
},
{"role": "user", "content": "What did the file say?"}
]
}`)
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
var payload KiroPayload
if err := json.Unmarshal(result, &payload); err != nil {
t.Fatalf("Failed to unmarshal result: %v", err)
}
// The last user message becomes currentMessage
// History should have: user (first), assistant (with tool_calls)
t.Logf("History count: %d", len(payload.ConversationState.History))
if len(payload.ConversationState.History) != 2 {
t.Errorf("Expected 2 history entries (user + assistant), got %d", len(payload.ConversationState.History))
}
// Tool results should be attached to currentMessage (the last user message)
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
if ctx == nil {
t.Fatal("Expected currentMessage to have UserInputMessageContext with tool results")
}
if len(ctx.ToolResults) != 1 {
t.Fatalf("Expected 1 tool result in currentMessage, got %d", len(ctx.ToolResults))
}
tr := ctx.ToolResults[0]
if tr.ToolUseID != "call_abc123" {
t.Errorf("Expected toolUseId 'call_abc123', got '%s'", tr.ToolUseID)
}
if len(tr.Content) == 0 || tr.Content[0].Text != "File contents: Hello World!" {
t.Errorf("Tool result content mismatch, got: %+v", tr.Content)
}
}
// TestToolResultsInHistoryUserMessage verifies that when there are multiple user messages
// after tool results, the tool results are attached to the correct user message in history.
func TestToolResultsInHistoryUserMessage(t *testing.T) {
// Sequence: user -> assistant (with tool_calls) -> tool (result) -> user -> assistant -> user
// The first user after tool should have tool results in history
input := []byte(`{
"model": "kiro-claude-opus-4-5-agentic",
"messages": [
{"role": "user", "content": "Hello"},
{
"role": "assistant",
"content": "I'll read the file.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "Read",
"arguments": "{}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": "File result"
},
{"role": "user", "content": "Thanks for the file"},
{"role": "assistant", "content": "You're welcome"},
{"role": "user", "content": "Bye"}
]
}`)
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
var payload KiroPayload
if err := json.Unmarshal(result, &payload); err != nil {
t.Fatalf("Failed to unmarshal result: %v", err)
}
// History should have: user, assistant, user (with tool results), assistant
// CurrentMessage should be: last user "Bye"
t.Logf("History count: %d", len(payload.ConversationState.History))
// Find the user message in history with tool results
foundToolResults := false
for i, h := range payload.ConversationState.History {
if h.UserInputMessage != nil {
t.Logf("History[%d]: user message content=%q", i, h.UserInputMessage.Content)
if h.UserInputMessage.UserInputMessageContext != nil {
if len(h.UserInputMessage.UserInputMessageContext.ToolResults) > 0 {
foundToolResults = true
t.Logf(" Found %d tool results", len(h.UserInputMessage.UserInputMessageContext.ToolResults))
tr := h.UserInputMessage.UserInputMessageContext.ToolResults[0]
if tr.ToolUseID != "call_1" {
t.Errorf("Expected toolUseId 'call_1', got '%s'", tr.ToolUseID)
}
}
}
}
if h.AssistantResponseMessage != nil {
t.Logf("History[%d]: assistant message content=%q", i, h.AssistantResponseMessage.Content)
}
}
if !foundToolResults {
t.Error("Tool results were not attached to any user message in history")
}
}
// TestToolResultsWithMultipleToolCalls verifies handling of multiple tool calls
func TestToolResultsWithMultipleToolCalls(t *testing.T) {
input := []byte(`{
"model": "kiro-claude-opus-4-5-agentic",
"messages": [
{"role": "user", "content": "Read two files for me"},
{
"role": "assistant",
"content": "I'll read both files.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "Read",
"arguments": "{\"file_path\": \"/tmp/file1.txt\"}"
}
},
{
"id": "call_2",
"type": "function",
"function": {
"name": "Read",
"arguments": "{\"file_path\": \"/tmp/file2.txt\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": "Content of file 1"
},
{
"role": "tool",
"tool_call_id": "call_2",
"content": "Content of file 2"
},
{"role": "user", "content": "What do they say?"}
]
}`)
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
var payload KiroPayload
if err := json.Unmarshal(result, &payload); err != nil {
t.Fatalf("Failed to unmarshal result: %v", err)
}
t.Logf("History count: %d", len(payload.ConversationState.History))
t.Logf("CurrentMessage content: %q", payload.ConversationState.CurrentMessage.UserInputMessage.Content)
// Check if there are any tool results anywhere
var totalToolResults int
for i, h := range payload.ConversationState.History {
if h.UserInputMessage != nil && h.UserInputMessage.UserInputMessageContext != nil {
count := len(h.UserInputMessage.UserInputMessageContext.ToolResults)
t.Logf("History[%d] user message has %d tool results", i, count)
totalToolResults += count
}
}
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
if ctx != nil {
t.Logf("CurrentMessage has %d tool results", len(ctx.ToolResults))
totalToolResults += len(ctx.ToolResults)
} else {
t.Logf("CurrentMessage has no UserInputMessageContext")
}
if totalToolResults != 2 {
t.Errorf("Expected 2 tool results total, got %d", totalToolResults)
}
}
// TestToolResultsAtEndOfConversation verifies tool results are handled when
// the conversation ends with tool results (no following user message)
func TestToolResultsAtEndOfConversation(t *testing.T) {
input := []byte(`{
"model": "kiro-claude-opus-4-5-agentic",
"messages": [
{"role": "user", "content": "Read a file"},
{
"role": "assistant",
"content": "Reading the file.",
"tool_calls": [
{
"id": "call_end",
"type": "function",
"function": {
"name": "Read",
"arguments": "{\"file_path\": \"/tmp/test.txt\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_end",
"content": "File contents here"
}
]
}`)
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
var payload KiroPayload
if err := json.Unmarshal(result, &payload); err != nil {
t.Fatalf("Failed to unmarshal result: %v", err)
}
// When the last message is a tool result, a synthetic user message is created
// and tool results should be attached to it
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
if ctx == nil || len(ctx.ToolResults) == 0 {
t.Error("Expected tool results to be attached to current message when conversation ends with tool result")
} else {
if ctx.ToolResults[0].ToolUseID != "call_end" {
t.Errorf("Expected toolUseId 'call_end', got '%s'", ctx.ToolResults[0].ToolUseID)
}
}
}
// TestToolResultsFollowedByAssistant verifies handling when tool results are followed
// by an assistant message (no intermediate user message).
// This is the pattern from LiteLLM translation of Anthropic format where:
// user message has ONLY tool_result blocks -> LiteLLM creates tool messages
// then the next message is assistant
func TestToolResultsFollowedByAssistant(t *testing.T) {
// Sequence: user -> assistant (with tool_calls) -> tool -> tool -> assistant -> user
// This simulates LiteLLM's translation of:
// user: "Read files"
// assistant: [tool_use, tool_use]
// user: [tool_result, tool_result] <- becomes multiple "tool" role messages
// assistant: "I've read them"
// user: "What did they say?"
input := []byte(`{
"model": "kiro-claude-opus-4-5-agentic",
"messages": [
{"role": "user", "content": "Read two files for me"},
{
"role": "assistant",
"content": "I'll read both files.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "Read",
"arguments": "{\"file_path\": \"/tmp/a.txt\"}"
}
},
{
"id": "call_2",
"type": "function",
"function": {
"name": "Read",
"arguments": "{\"file_path\": \"/tmp/b.txt\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": "Contents of file A"
},
{
"role": "tool",
"tool_call_id": "call_2",
"content": "Contents of file B"
},
{
"role": "assistant",
"content": "I've read both files."
},
{"role": "user", "content": "What did they say?"}
]
}`)
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
var payload KiroPayload
if err := json.Unmarshal(result, &payload); err != nil {
t.Fatalf("Failed to unmarshal result: %v", err)
}
t.Logf("History count: %d", len(payload.ConversationState.History))
// Tool results should be attached to a synthetic user message or the history should be valid
var totalToolResults int
for i, h := range payload.ConversationState.History {
if h.UserInputMessage != nil {
t.Logf("History[%d]: user message content=%q", i, h.UserInputMessage.Content)
if h.UserInputMessage.UserInputMessageContext != nil {
count := len(h.UserInputMessage.UserInputMessageContext.ToolResults)
t.Logf(" Has %d tool results", count)
totalToolResults += count
}
}
if h.AssistantResponseMessage != nil {
t.Logf("History[%d]: assistant message content=%q", i, h.AssistantResponseMessage.Content)
}
}
ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
if ctx != nil {
t.Logf("CurrentMessage has %d tool results", len(ctx.ToolResults))
totalToolResults += len(ctx.ToolResults)
}
if totalToolResults != 2 {
t.Errorf("Expected 2 tool results total, got %d", totalToolResults)
}
}
// TestAssistantEndsConversation verifies handling when assistant is the last message
func TestAssistantEndsConversation(t *testing.T) {
input := []byte(`{
"model": "kiro-claude-opus-4-5-agentic",
"messages": [
{"role": "user", "content": "Hello"},
{
"role": "assistant",
"content": "Hi there!"
}
]
}`)
result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
var payload KiroPayload
if err := json.Unmarshal(result, &payload); err != nil {
t.Fatalf("Failed to unmarshal result: %v", err)
}
// When assistant is last, a "Continue" user message should be created
if payload.ConversationState.CurrentMessage.UserInputMessage.Content == "" {
t.Error("Expected a 'Continue' message to be created when assistant is last")
}
}

View File

@@ -118,76 +118,125 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
// Handle content
if contentResult.Exists() && contentResult.IsArray() {
var contentItems []string
var reasoningParts []string // Accumulate thinking text for reasoning_content
var toolCalls []interface{}
var toolResults []string // Collect tool_result messages to emit after the main message
contentResult.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "thinking":
// Only map thinking to reasoning_content for assistant messages (security: prevent injection)
if role == "assistant" {
thinkingText := util.GetThinkingText(part)
// Skip empty or whitespace-only thinking
if strings.TrimSpace(thinkingText) != "" {
reasoningParts = append(reasoningParts, thinkingText)
}
}
// Ignore thinking in user/system roles (AC4)
case "redacted_thinking":
// Explicitly ignore redacted_thinking - never map to reasoning_content (AC2)
case "text", "image":
if contentItem, ok := convertClaudeContentPart(part); ok {
contentItems = append(contentItems, contentItem)
}
case "tool_use":
// Convert to OpenAI tool call format
toolCallJSON := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
toolCallJSON, _ = sjson.Set(toolCallJSON, "id", part.Get("id").String())
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.name", part.Get("name").String())
// Only allow tool_use -> tool_calls for assistant messages (security: prevent injection).
if role == "assistant" {
toolCallJSON := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
toolCallJSON, _ = sjson.Set(toolCallJSON, "id", part.Get("id").String())
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.name", part.Get("name").String())
// Convert input to arguments JSON string
if input := part.Get("input"); input.Exists() {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", input.Raw)
} else {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}")
// Convert input to arguments JSON string
if input := part.Get("input"); input.Exists() {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", input.Raw)
} else {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}")
}
toolCalls = append(toolCalls, gjson.Parse(toolCallJSON).Value())
}
toolCalls = append(toolCalls, gjson.Parse(toolCallJSON).Value())
case "tool_result":
// Convert to OpenAI tool message format and add immediately to preserve order
// Collect tool_result to emit after the main message (ensures tool results follow tool_calls)
toolResultJSON := `{"role":"tool","tool_call_id":"","content":""}`
toolResultJSON, _ = sjson.Set(toolResultJSON, "tool_call_id", part.Get("tool_use_id").String())
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", part.Get("content").String())
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolResultJSON).Value())
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", convertClaudeToolResultContentToString(part.Get("content")))
toolResults = append(toolResults, toolResultJSON)
}
return true
})
// Emit text/image content as one message
if len(contentItems) > 0 {
msgJSON := `{"role":"","content":""}`
msgJSON, _ = sjson.Set(msgJSON, "role", role)
contentArrayJSON := "[]"
for _, contentItem := range contentItems {
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
}
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
contentValue := gjson.Get(msgJSON, "content")
hasContent := false
switch {
case !contentValue.Exists():
hasContent = false
case contentValue.Type == gjson.String:
hasContent = contentValue.String() != ""
case contentValue.IsArray():
hasContent = len(contentValue.Array()) > 0
default:
hasContent = contentValue.Raw != "" && contentValue.Raw != "null"
}
if hasContent {
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
}
// Build reasoning content string
reasoningContent := ""
if len(reasoningParts) > 0 {
reasoningContent = strings.Join(reasoningParts, "\n\n")
}
// Emit tool calls in a separate assistant message
if role == "assistant" && len(toolCalls) > 0 {
toolCallMsgJSON := `{"role":"assistant","tool_calls":[]}`
toolCallMsgJSON, _ = sjson.Set(toolCallMsgJSON, "tool_calls", toolCalls)
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolCallMsgJSON).Value())
hasContent := len(contentItems) > 0
hasReasoning := reasoningContent != ""
hasToolCalls := len(toolCalls) > 0
hasToolResults := len(toolResults) > 0
// OpenAI requires: tool messages MUST immediately follow the assistant message with tool_calls.
// Therefore, we emit tool_result messages FIRST (they respond to the previous assistant's tool_calls),
// then emit the current message's content.
for _, toolResultJSON := range toolResults {
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolResultJSON).Value())
}
// For assistant messages: emit a single unified message with content, tool_calls, and reasoning_content
// This avoids splitting into multiple assistant messages which breaks OpenAI tool-call adjacency
if role == "assistant" {
if hasContent || hasReasoning || hasToolCalls {
msgJSON := `{"role":"assistant"}`
// Add content (as array if we have items, empty string if reasoning-only)
if hasContent {
contentArrayJSON := "[]"
for _, contentItem := range contentItems {
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
}
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
} else {
// Ensure content field exists for OpenAI compatibility
msgJSON, _ = sjson.Set(msgJSON, "content", "")
}
// Add reasoning_content if present
if hasReasoning {
msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", reasoningContent)
}
// Add tool_calls if present (in same message as content)
if hasToolCalls {
msgJSON, _ = sjson.Set(msgJSON, "tool_calls", toolCalls)
}
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
}
} else {
// For non-assistant roles: emit content message if we have content
// If the message only contains tool_results (no text/image), we still processed them above
if hasContent {
msgJSON := `{"role":""}`
msgJSON, _ = sjson.Set(msgJSON, "role", role)
contentArrayJSON := "[]"
for _, contentItem := range contentItems {
contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
}
msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(msgJSON).Value())
} else if hasToolResults && !hasContent {
// tool_results already emitted above, no additional user message needed
}
}
} else if contentResult.Exists() && contentResult.Type == gjson.String {
@@ -307,3 +356,43 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
return "", false
}
}
func convertClaudeToolResultContentToString(content gjson.Result) string {
if !content.Exists() {
return ""
}
if content.Type == gjson.String {
return content.String()
}
if content.IsArray() {
var parts []string
content.ForEach(func(_, item gjson.Result) bool {
switch {
case item.Type == gjson.String:
parts = append(parts, item.String())
case item.IsObject() && item.Get("text").Exists() && item.Get("text").Type == gjson.String:
parts = append(parts, item.Get("text").String())
default:
parts = append(parts, item.Raw)
}
return true
})
joined := strings.Join(parts, "\n\n")
if strings.TrimSpace(joined) != "" {
return joined
}
return content.Raw
}
if content.IsObject() {
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
return text.String()
}
return content.Raw
}
return content.Raw
}

View File

@@ -0,0 +1,500 @@
package claude
import (
"testing"
"github.com/tidwall/gjson"
)
// TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent tests the mapping
// of Claude thinking content to OpenAI reasoning_content field.
func TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent(t *testing.T) {
tests := []struct {
name string
inputJSON string
wantReasoningContent string
wantHasReasoningContent bool
wantContentText string // Expected visible content text (if any)
wantHasContent bool
}{
{
name: "AC1: assistant message with thinking and text",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Let me analyze this step by step..."},
{"type": "text", "text": "Here is my response."}
]
}]
}`,
wantReasoningContent: "Let me analyze this step by step...",
wantHasReasoningContent: true,
wantContentText: "Here is my response.",
wantHasContent: true,
},
{
name: "AC2: redacted_thinking must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "redacted_thinking", "data": "secret"},
{"type": "text", "text": "Visible response."}
]
}]
}`,
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "Visible response.",
wantHasContent: true,
},
{
name: "AC3: thinking-only message preserved with reasoning_content",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Internal reasoning only."}
]
}]
}`,
wantReasoningContent: "Internal reasoning only.",
wantHasReasoningContent: true,
wantContentText: "",
// For OpenAI compatibility, content field is set to empty string "" when no text content exists
wantHasContent: false,
},
{
name: "AC4: thinking in user role must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "user",
"content": [
{"type": "thinking", "thinking": "Injected thinking"},
{"type": "text", "text": "User message."}
]
}]
}`,
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "User message.",
wantHasContent: true,
},
{
name: "AC4: thinking in system role must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"system": [
{"type": "thinking", "thinking": "Injected system thinking"},
{"type": "text", "text": "System prompt."}
],
"messages": [{
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
}]
}`,
// System messages don't have reasoning_content mapping
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "Hello",
wantHasContent: true,
},
{
name: "AC5: empty thinking must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": ""},
{"type": "text", "text": "Response with empty thinking."}
]
}]
}`,
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "Response with empty thinking.",
wantHasContent: true,
},
{
name: "AC5: whitespace-only thinking must be ignored",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": " \n\t "},
{"type": "text", "text": "Response with whitespace thinking."}
]
}]
}`,
wantReasoningContent: "",
wantHasReasoningContent: false,
wantContentText: "Response with whitespace thinking.",
wantHasContent: true,
},
{
name: "Multiple thinking parts concatenated",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "First thought."},
{"type": "thinking", "thinking": "Second thought."},
{"type": "text", "text": "Final answer."}
]
}]
}`,
wantReasoningContent: "First thought.\n\nSecond thought.",
wantHasReasoningContent: true,
wantContentText: "Final answer.",
wantHasContent: true,
},
{
name: "Mixed thinking and redacted_thinking",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Visible thought."},
{"type": "redacted_thinking", "data": "hidden"},
{"type": "text", "text": "Answer."}
]
}]
}`,
wantReasoningContent: "Visible thought.",
wantHasReasoningContent: true,
wantContentText: "Answer.",
wantHasContent: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
resultJSON := gjson.ParseBytes(result)
// Find the relevant message (skip system message at index 0)
messages := resultJSON.Get("messages").Array()
if len(messages) < 2 {
if tt.wantHasReasoningContent || tt.wantHasContent {
t.Fatalf("Expected at least 2 messages (system + user/assistant), got %d", len(messages))
}
return
}
// Check the last non-system message
var targetMsg gjson.Result
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Get("role").String() != "system" {
targetMsg = messages[i]
break
}
}
// Check reasoning_content
gotReasoningContent := targetMsg.Get("reasoning_content").String()
gotHasReasoningContent := targetMsg.Get("reasoning_content").Exists()
if gotHasReasoningContent != tt.wantHasReasoningContent {
t.Errorf("reasoning_content existence = %v, want %v", gotHasReasoningContent, tt.wantHasReasoningContent)
}
if gotReasoningContent != tt.wantReasoningContent {
t.Errorf("reasoning_content = %q, want %q", gotReasoningContent, tt.wantReasoningContent)
}
// Check content
content := targetMsg.Get("content")
// content has meaningful content if it's a non-empty array, or a non-empty string
var gotHasContent bool
switch {
case content.IsArray():
gotHasContent = len(content.Array()) > 0
case content.Type == gjson.String:
gotHasContent = content.String() != ""
default:
gotHasContent = false
}
if gotHasContent != tt.wantHasContent {
t.Errorf("content existence = %v, want %v", gotHasContent, tt.wantHasContent)
}
if tt.wantHasContent && tt.wantContentText != "" {
// Find text content
var foundText string
content.ForEach(func(_, v gjson.Result) bool {
if v.Get("type").String() == "text" {
foundText = v.Get("text").String()
return false
}
return true
})
if foundText != tt.wantContentText {
t.Errorf("content text = %q, want %q", foundText, tt.wantContentText)
}
}
})
}
}
// TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved tests AC3:
// that a message with only thinking content is preserved (not dropped).
func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "What is 2+2?"}]
},
{
"role": "assistant",
"content": [{"type": "thinking", "thinking": "Let me calculate: 2+2=4"}]
},
{
"role": "user",
"content": [{"type": "text", "text": "Thanks"}]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// Should have: system (auto-added) + user + assistant (thinking-only) + user = 4 messages
if len(messages) != 4 {
t.Fatalf("Expected 4 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw)
}
// Check the assistant message (index 2) has reasoning_content
assistantMsg := messages[2]
if assistantMsg.Get("role").String() != "assistant" {
t.Errorf("Expected message[2] to be assistant, got %s", assistantMsg.Get("role").String())
}
if !assistantMsg.Get("reasoning_content").Exists() {
t.Error("Expected assistant message to have reasoning_content")
}
if assistantMsg.Get("reasoning_content").String() != "Let me calculate: 2+2=4" {
t.Errorf("Unexpected reasoning_content: %s", assistantMsg.Get("reasoning_content").String())
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{"type": "text", "text": "before"},
{"type": "tool_result", "tool_use_id": "call_1", "content": [{"type":"text","text":"tool ok"}]},
{"type": "text", "text": "after"}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// OpenAI requires: tool messages MUST immediately follow assistant(tool_calls).
// Correct order: system + assistant(tool_calls) + tool(result) + user(before+after)
if len(messages) != 4 {
t.Fatalf("Expected 4 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if messages[0].Get("role").String() != "system" {
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
}
if messages[1].Get("role").String() != "assistant" || !messages[1].Get("tool_calls").Exists() {
t.Fatalf("Expected messages[1] to be assistant tool_calls, got %s: %s", messages[1].Get("role").String(), messages[1].Raw)
}
// tool message MUST immediately follow assistant(tool_calls) per OpenAI spec
if messages[2].Get("role").String() != "tool" {
t.Fatalf("Expected messages[2] to be tool (must follow tool_calls), got %s", messages[2].Get("role").String())
}
if got := messages[2].Get("tool_call_id").String(); got != "call_1" {
t.Fatalf("Expected tool_call_id %q, got %q", "call_1", got)
}
if got := messages[2].Get("content").String(); got != "tool ok" {
t.Fatalf("Expected tool content %q, got %q", "tool ok", got)
}
// User message comes after tool message
if messages[3].Get("role").String() != "user" {
t.Fatalf("Expected messages[3] to be user, got %s", messages[3].Get("role").String())
}
// User message should contain both "before" and "after" text
if got := messages[3].Get("content.0.text").String(); got != "before" {
t.Fatalf("Expected user text[0] %q, got %q", "before", got)
}
if got := messages[3].Get("content.1.text").String(); got != "after" {
t.Fatalf("Expected user text[1] %q, got %q", "after", got)
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{"type": "tool_result", "tool_use_id": "call_1", "content": {"foo": "bar"}}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// system + assistant(tool_calls) + tool(result)
if len(messages) != 3 {
t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if messages[2].Get("role").String() != "tool" {
t.Fatalf("Expected messages[2] to be tool, got %s", messages[2].Get("role").String())
}
toolContent := messages[2].Get("content").String()
parsed := gjson.Parse(toolContent)
if parsed.Get("foo").String() != "bar" {
t.Fatalf("Expected tool content JSON foo=bar, got %q", toolContent)
}
}
func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": "pre"},
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}},
{"type": "text", "text": "post"}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// New behavior: content + tool_calls unified in single assistant message
// Expect: system + assistant(content[pre,post] + tool_calls)
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if messages[0].Get("role").String() != "system" {
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
}
assistantMsg := messages[1]
if assistantMsg.Get("role").String() != "assistant" {
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
}
// Should have both content and tool_calls in same message
if !assistantMsg.Get("tool_calls").Exists() {
t.Fatalf("Expected assistant message to have tool_calls")
}
if got := assistantMsg.Get("tool_calls.0.id").String(); got != "call_1" {
t.Fatalf("Expected tool_call id %q, got %q", "call_1", got)
}
if got := assistantMsg.Get("tool_calls.0.function.name").String(); got != "do_work" {
t.Fatalf("Expected tool_call name %q, got %q", "do_work", got)
}
// Content should have both pre and post text
if got := assistantMsg.Get("content.0.text").String(); got != "pre" {
t.Fatalf("Expected content[0] text %q, got %q", "pre", got)
}
if got := assistantMsg.Get("content.1.text").String(); got != "post" {
t.Fatalf("Expected content[1] text %q, got %q", "post", got)
}
}
func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "t1"},
{"type": "text", "text": "pre"},
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}},
{"type": "thinking", "thinking": "t2"},
{"type": "text", "text": "post"}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// New behavior: all content, thinking, and tool_calls unified in single assistant message
// Expect: system + assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2])
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
assistantMsg := messages[1]
if assistantMsg.Get("role").String() != "assistant" {
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
}
// Should have content with both pre and post
if got := assistantMsg.Get("content.0.text").String(); got != "pre" {
t.Fatalf("Expected content[0] text %q, got %q", "pre", got)
}
if got := assistantMsg.Get("content.1.text").String(); got != "post" {
t.Fatalf("Expected content[1] text %q, got %q", "post", got)
}
// Should have tool_calls
if !assistantMsg.Get("tool_calls").Exists() {
t.Fatalf("Expected assistant message to have tool_calls")
}
// Should have combined reasoning_content from both thinking blocks
if got := assistantMsg.Get("reasoning_content").String(); got != "t1\n\nt2" {
t.Fatalf("Expected reasoning_content %q, got %q", "t1\n\nt2", got)
}
}

View File

@@ -480,15 +480,15 @@ func collectOpenAIReasoningTexts(node gjson.Result) []string {
switch node.Type {
case gjson.String:
if text := strings.TrimSpace(node.String()); text != "" {
if text := node.String(); text != "" {
texts = append(texts, text)
}
case gjson.JSON:
if text := node.Get("text"); text.Exists() {
if trimmed := strings.TrimSpace(text.String()); trimmed != "" {
texts = append(texts, trimmed)
if textStr := text.String(); textStr != "" {
texts = append(texts, textStr)
}
} else if raw := strings.TrimSpace(node.Raw); raw != "" && !strings.HasPrefix(raw, "{") && !strings.HasPrefix(raw, "[") {
} else if raw := node.Raw; raw != "" && !strings.HasPrefix(raw, "{") && !strings.HasPrefix(raw, "[") {
texts = append(texts, raw)
}
}

View File

@@ -6,6 +6,7 @@ package usage
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
@@ -90,7 +91,7 @@ type modelStats struct {
type RequestDetail struct {
Timestamp time.Time `json:"timestamp"`
Source string `json:"source"`
AuthIndex uint64 `json:"auth_index"`
AuthIndex string `json:"auth_index"`
Tokens TokenStats `json:"tokens"`
Failed bool `json:"failed"`
}
@@ -281,6 +282,118 @@ func (s *RequestStatistics) Snapshot() StatisticsSnapshot {
return result
}
type MergeResult struct {
Added int64 `json:"added"`
Skipped int64 `json:"skipped"`
}
// MergeSnapshot merges an exported statistics snapshot into the current store.
// Existing data is preserved and duplicate request details are skipped.
func (s *RequestStatistics) MergeSnapshot(snapshot StatisticsSnapshot) MergeResult {
result := MergeResult{}
if s == nil {
return result
}
s.mu.Lock()
defer s.mu.Unlock()
seen := make(map[string]struct{})
for apiName, stats := range s.apis {
if stats == nil {
continue
}
for modelName, modelStatsValue := range stats.Models {
if modelStatsValue == nil {
continue
}
for _, detail := range modelStatsValue.Details {
seen[dedupKey(apiName, modelName, detail)] = struct{}{}
}
}
}
for apiName, apiSnapshot := range snapshot.APIs {
apiName = strings.TrimSpace(apiName)
if apiName == "" {
continue
}
stats, ok := s.apis[apiName]
if !ok || stats == nil {
stats = &apiStats{Models: make(map[string]*modelStats)}
s.apis[apiName] = stats
} else if stats.Models == nil {
stats.Models = make(map[string]*modelStats)
}
for modelName, modelSnapshot := range apiSnapshot.Models {
modelName = strings.TrimSpace(modelName)
if modelName == "" {
modelName = "unknown"
}
for _, detail := range modelSnapshot.Details {
detail.Tokens = normaliseTokenStats(detail.Tokens)
if detail.Timestamp.IsZero() {
detail.Timestamp = time.Now()
}
key := dedupKey(apiName, modelName, detail)
if _, exists := seen[key]; exists {
result.Skipped++
continue
}
seen[key] = struct{}{}
s.recordImported(apiName, modelName, stats, detail)
result.Added++
}
}
}
return result
}
func (s *RequestStatistics) recordImported(apiName, modelName string, stats *apiStats, detail RequestDetail) {
totalTokens := detail.Tokens.TotalTokens
if totalTokens < 0 {
totalTokens = 0
}
s.totalRequests++
if detail.Failed {
s.failureCount++
} else {
s.successCount++
}
s.totalTokens += totalTokens
s.updateAPIStats(stats, modelName, detail)
dayKey := detail.Timestamp.Format("2006-01-02")
hourKey := detail.Timestamp.Hour()
s.requestsByDay[dayKey]++
s.requestsByHour[hourKey]++
s.tokensByDay[dayKey] += totalTokens
s.tokensByHour[hourKey] += totalTokens
}
func dedupKey(apiName, modelName string, detail RequestDetail) string {
timestamp := detail.Timestamp.UTC().Format(time.RFC3339Nano)
tokens := normaliseTokenStats(detail.Tokens)
return fmt.Sprintf(
"%s|%s|%s|%s|%s|%t|%d|%d|%d|%d|%d",
apiName,
modelName,
timestamp,
detail.Source,
detail.AuthIndex,
detail.Failed,
tokens.InputTokens,
tokens.OutputTokens,
tokens.ReasoningTokens,
tokens.CachedTokens,
tokens.TotalTokens,
)
}
func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string {
if ctx != nil {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
@@ -340,6 +453,16 @@ func normaliseDetail(detail coreusage.Detail) TokenStats {
return tokens
}
func normaliseTokenStats(tokens TokenStats) TokenStats {
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens
}
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens
}
return tokens
}
func formatHour(hour int) string {
if hour < 0 {
hour = 0

View File

@@ -344,7 +344,7 @@ func cleanupRequiredFields(jsonStr string) string {
}
// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas.
// Claude VALIDATED mode requires at least one property in tool schemas.
// Claude VALIDATED mode requires at least one required property in tool schemas.
func addEmptySchemaPlaceholder(jsonStr string) string {
// Find all "type" fields
paths := findPaths(jsonStr, "type")
@@ -364,6 +364,9 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
// Check if properties exists and is empty or missing
propsPath := joinPath(parentPath, "properties")
propsVal := gjson.Get(jsonStr, propsPath)
reqPath := joinPath(parentPath, "required")
reqVal := gjson.Get(jsonStr, reqPath)
hasRequiredProperties := reqVal.IsArray() && len(reqVal.Array()) > 0
needsPlaceholder := false
if !propsVal.Exists() {
@@ -381,8 +384,22 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
// Add to required array
reqPath := joinPath(parentPath, "required")
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
continue
}
// If schema has properties but none are required, add a minimal placeholder.
if propsVal.IsObject() && !hasRequiredProperties {
// DO NOT add placeholder if it's a top-level schema (parentPath is empty)
// or if we've already added a placeholder reason above.
if parentPath == "" {
continue
}
placeholderPath := joinPath(propsPath, "_")
if !gjson.Get(jsonStr, placeholderPath).Exists() {
jsonStr, _ = sjson.Set(jsonStr, placeholderPath+".type", "boolean")
}
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"_"})
}
}

View File

@@ -127,8 +127,10 @@ func TestCleanJSONSchemaForAntigravity_AnyOfFlattening_SmartSelection(t *testing
"type": "object",
"description": "Accepts: null | object",
"properties": {
"_": { "type": "boolean" },
"kind": { "type": "string" }
}
},
"required": ["_"]
}
}
}`
@@ -614,71 +616,6 @@ func TestCleanJSONSchemaForAntigravity_MultipleNonNullTypes(t *testing.T) {
}
}
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval(t *testing.T) {
// propertyNames is used to validate object property names (e.g., must match a pattern)
// Gemini doesn't support this keyword and will reject requests containing it
input := `{
"type": "object",
"properties": {
"metadata": {
"type": "object",
"propertyNames": {
"pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$"
},
"additionalProperties": {
"type": "string"
}
}
}
}`
expected := `{
"type": "object",
"properties": {
"metadata": {
"type": "object"
}
}
}`
result := CleanJSONSchemaForGemini(input)
compareJSON(t, expected, result)
// Verify propertyNames is completely removed
if strings.Contains(result, "propertyNames") {
t.Errorf("propertyNames keyword should be removed, got: %s", result)
}
}
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested(t *testing.T) {
// Test deeply nested propertyNames (as seen in real Claude tool schemas)
input := `{
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object",
"properties": {
"config": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
}
}
}
}
}`
result := CleanJSONSchemaForGemini(input)
if strings.Contains(result, "propertyNames") {
t.Errorf("Nested propertyNames should be removed, got: %s", result)
}
}
func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
var expMap, actMap map[string]interface{}
errExp := json.Unmarshal([]byte(expectedJSON), &expMap)

View File

@@ -251,9 +251,14 @@ func ThinkingBudgetToGemini3Level(model string, budget int) (string, bool) {
// modelsWithDefaultThinking lists models that should have thinking enabled by default
// when no explicit thinkingConfig is provided.
// Note: Gemini 3 models are NOT included here because per Google's official documentation:
// - thinkingLevel defaults to "high" (dynamic thinking)
// - includeThoughts defaults to false
//
// We should not override these API defaults; let users explicitly configure if needed.
var modelsWithDefaultThinking = map[string]bool{
"gemini-3-pro-preview": true,
"gemini-3-pro-image-preview": true,
// "gemini-3-pro-preview": true,
// "gemini-3-pro-image-preview": true,
// "gemini-3-flash-preview": true,
}
@@ -288,37 +293,73 @@ func ApplyDefaultThinkingIfNeeded(model string, body []byte) []byte {
// ApplyGemini3ThinkingLevelFromMetadata applies thinkingLevel from metadata for Gemini 3 models.
// For standard Gemini API format (generationConfig.thinkingConfig path).
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal)).
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal))
// or numeric budget suffix (e.g., model(1000)) which gets converted to a thinkingLevel.
func ApplyGemini3ThinkingLevelFromMetadata(model string, metadata map[string]any, body []byte) []byte {
if !IsGemini3Model(model) {
// Use the alias from metadata if available for model type detection
lookupModel := ResolveOriginalModel(model, metadata)
if !IsGemini3Model(lookupModel) && !IsGemini3Model(model) {
return body
}
// Determine which model to use for validation
checkModel := model
if IsGemini3Model(lookupModel) {
checkModel = lookupModel
}
// First try to get effort string from metadata
effort, ok := ReasoningEffortFromMetadata(metadata)
if !ok || effort == "" {
return body
if ok && effort != "" {
if level, valid := ValidateGemini3ThinkingLevel(checkModel, effort); valid {
return ApplyGeminiThinkingLevel(body, level, nil)
}
}
// Validate and apply the thinkingLevel
if level, valid := ValidateGemini3ThinkingLevel(model, effort); valid {
return ApplyGeminiThinkingLevel(body, level, nil)
// Fallback: check for numeric budget and convert to thinkingLevel
budget, _, _, matched := ThinkingFromMetadata(metadata)
if matched && budget != nil {
if level, valid := ThinkingBudgetToGemini3Level(checkModel, *budget); valid {
return ApplyGeminiThinkingLevel(body, level, nil)
}
}
return body
}
// ApplyGemini3ThinkingLevelFromMetadataCLI applies thinkingLevel from metadata for Gemini 3 models.
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal)).
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal))
// or numeric budget suffix (e.g., model(1000)) which gets converted to a thinkingLevel.
func ApplyGemini3ThinkingLevelFromMetadataCLI(model string, metadata map[string]any, body []byte) []byte {
if !IsGemini3Model(model) {
// Use the alias from metadata if available for model type detection
lookupModel := ResolveOriginalModel(model, metadata)
if !IsGemini3Model(lookupModel) && !IsGemini3Model(model) {
return body
}
// Determine which model to use for validation
checkModel := model
if IsGemini3Model(lookupModel) {
checkModel = lookupModel
}
// First try to get effort string from metadata
effort, ok := ReasoningEffortFromMetadata(metadata)
if !ok || effort == "" {
return body
if ok && effort != "" {
if level, valid := ValidateGemini3ThinkingLevel(checkModel, effort); valid {
return ApplyGeminiCLIThinkingLevel(body, level, nil)
}
}
// Validate and apply the thinkingLevel
if level, valid := ValidateGemini3ThinkingLevel(model, effort); valid {
return ApplyGeminiCLIThinkingLevel(body, level, nil)
// Fallback: check for numeric budget and convert to thinkingLevel
budget, _, _, matched := ThinkingFromMetadata(metadata)
if matched && budget != nil {
if level, valid := ThinkingBudgetToGemini3Level(checkModel, *budget); valid {
return ApplyGeminiCLIThinkingLevel(body, level, nil)
}
}
return body
}
@@ -326,15 +367,17 @@ func ApplyGemini3ThinkingLevelFromMetadataCLI(model string, metadata map[string]
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
// Returns the modified body if thinkingConfig was added, otherwise returns the original.
// For Gemini 3 models, uses thinkingLevel instead of thinkingBudget per Google's documentation.
func ApplyDefaultThinkingIfNeededCLI(model string, body []byte) []byte {
if !ModelHasDefaultThinking(model) {
func ApplyDefaultThinkingIfNeededCLI(model string, metadata map[string]any, body []byte) []byte {
// Use the alias from metadata if available for model property lookup
lookupModel := ResolveOriginalModel(model, metadata)
if !ModelHasDefaultThinking(lookupModel) && !ModelHasDefaultThinking(model) {
return body
}
if gjson.GetBytes(body, "request.generationConfig.thinkingConfig").Exists() {
return body
}
// Gemini 3 models use thinkingLevel instead of thinkingBudget
if IsGemini3Model(model) {
if IsGemini3Model(lookupModel) || IsGemini3Model(model) {
// Don't set a default - let the API use its dynamic default ("high")
// Only set includeThoughts
updated, _ := sjson.SetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts", true)

View File

@@ -0,0 +1,56 @@
package util
import (
"testing"
)
func TestSanitizeFunctionName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"Normal", "valid_name", "valid_name"},
{"With Dots", "name.with.dots", "name.with.dots"},
{"With Colons", "name:with:colons", "name:with:colons"},
{"With Dashes", "name-with-dashes", "name-with-dashes"},
{"Mixed Allowed", "name.with_dots:colons-dashes", "name.with_dots:colons-dashes"},
{"Invalid Characters", "name!with@invalid#chars", "name_with_invalid_chars"},
{"Spaces", "name with spaces", "name_with_spaces"},
{"Non-ASCII", "name_with_你好_chars", "name_with____chars"},
{"Starts with digit", "123name", "_123name"},
{"Starts with dot", ".name", "_.name"},
{"Starts with colon", ":name", "_:name"},
{"Starts with dash", "-name", "_-name"},
{"Starts with invalid char", "!name", "_name"},
{"Exactly 64 chars", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact"},
{"Too long (65 chars)", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charactX", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact"},
{"Very long", "this_is_a_very_long_name_that_exceeds_the_sixty_four_character_limit_for_function_names", "this_is_a_very_long_name_that_exceeds_the_sixty_four_character_l"},
{"Starts with digit (64 chars total)", "1234567890123456789012345678901234567890123456789012345678901234", "_123456789012345678901234567890123456789012345678901234567890123"},
{"Starts with invalid char (64 chars total)", "!234567890123456789012345678901234567890123456789012345678901234", "_234567890123456789012345678901234567890123456789012345678901234"},
{"Empty", "", ""},
{"Single character invalid", "@", "_"},
{"Single character valid", "a", "a"},
{"Single character digit", "1", "_1"},
{"Single character underscore", "_", "_"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := SanitizeFunctionName(tt.input)
if got != tt.expected {
t.Errorf("SanitizeFunctionName(%q) = %v, want %v", tt.input, got, tt.expected)
}
// Verify Gemini compliance
if len(got) > 64 {
t.Errorf("SanitizeFunctionName(%q) result too long: %d", tt.input, len(got))
}
if len(got) > 0 {
first := got[0]
if !((first >= 'a' && first <= 'z') || (first >= 'A' && first <= 'Z') || first == '_') {
t.Errorf("SanitizeFunctionName(%q) result starts with invalid char: %c", tt.input, first)
}
}
})
}
}

View File

@@ -12,9 +12,18 @@ func ModelSupportsThinking(model string) bool {
if model == "" {
return false
}
// First check the global dynamic registry
if info := registry.GetGlobalRegistry().GetModelInfo(model); info != nil {
return info.Thinking != nil
}
// Fallback: check static model definitions
if info := registry.LookupStaticModelInfo(model); info != nil {
return info.Thinking != nil
}
// Fallback: check Antigravity static config
if cfg := registry.GetAntigravityModelConfig()[model]; cfg != nil {
return cfg.Thinking != nil
}
return false
}
@@ -63,11 +72,19 @@ func thinkingRangeFromRegistry(model string) (found bool, min int, max int, zero
if model == "" {
return false, 0, 0, false, false
}
info := registry.GetGlobalRegistry().GetModelInfo(model)
if info == nil || info.Thinking == nil {
return false, 0, 0, false, false
// First check global dynamic registry
if info := registry.GetGlobalRegistry().GetModelInfo(model); info != nil && info.Thinking != nil {
return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed
}
return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed
// Fallback: check static model definitions
if info := registry.LookupStaticModelInfo(model); info != nil && info.Thinking != nil {
return true, info.Thinking.Min, info.Thinking.Max, info.Thinking.ZeroAllowed, info.Thinking.DynamicAllowed
}
// Fallback: check Antigravity static config
if cfg := registry.GetAntigravityModelConfig()[model]; cfg != nil && cfg.Thinking != nil {
return true, cfg.Thinking.Min, cfg.Thinking.Max, cfg.Thinking.ZeroAllowed, cfg.Thinking.DynamicAllowed
}
return false, 0, 0, false, false
}
// GetModelThinkingLevels returns the discrete reasoning effort levels for the model.

View File

@@ -7,10 +7,11 @@ import (
)
const (
ThinkingBudgetMetadataKey = "thinking_budget"
ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts"
ReasoningEffortMetadataKey = "reasoning_effort"
ThinkingOriginalModelMetadataKey = "thinking_original_model"
ThinkingBudgetMetadataKey = "thinking_budget"
ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts"
ReasoningEffortMetadataKey = "reasoning_effort"
ThinkingOriginalModelMetadataKey = "thinking_original_model"
ModelMappingOriginalModelMetadataKey = "model_mapping_original_model"
)
// NormalizeThinkingModel parses dynamic thinking suffixes on model names and returns
@@ -215,6 +216,13 @@ func ResolveOriginalModel(model string, metadata map[string]any) string {
}
if metadata != nil {
if v, ok := metadata[ModelMappingOriginalModelMetadataKey]; ok {
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
if base := normalize(s); base != "" {
return base
}
}
}
if v, ok := metadata[ThinkingOriginalModelMetadataKey]; ok {
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
if base := normalize(s); base != "" {

View File

@@ -8,12 +8,52 @@ import (
"io/fs"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
)
var functionNameSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_.:-]`)
// SanitizeFunctionName ensures a function name matches the requirements for Gemini/Vertex AI.
// It replaces invalid characters with underscores, ensures it starts with a letter or underscore,
// and truncates it to 64 characters if necessary.
// Regex Rule: [^a-zA-Z0-9_.:-] replaced with _.
func SanitizeFunctionName(name string) string {
if name == "" {
return ""
}
// Replace invalid characters with underscore
sanitized := functionNameSanitizer.ReplaceAllString(name, "_")
// Ensure it starts with a letter or underscore
// Re-reading requirements: Must start with a letter or an underscore.
if len(sanitized) > 0 {
first := sanitized[0]
if !((first >= 'a' && first <= 'z') || (first >= 'A' && first <= 'Z') || first == '_') {
// If it starts with an allowed character but not allowed at the beginning (digit, dot, colon, dash),
// we must prepend an underscore.
// To stay within the 64-character limit while prepending, we must truncate first.
if len(sanitized) >= 64 {
sanitized = sanitized[:63]
}
sanitized = "_" + sanitized
}
} else {
sanitized = "_"
}
// Truncate to 64 characters
if len(sanitized) > 64 {
sanitized = sanitized[:64]
}
return sanitized
}
// SetLogLevel configures the logrus log level based on the configuration.
// It sets the log level to DebugLevel if debug mode is enabled, otherwise to InfoLevel.
func SetLogLevel(cfg *config.Config) {

View File

@@ -6,6 +6,7 @@ import (
"crypto/sha256"
"encoding/hex"
"os"
"reflect"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -126,7 +127,7 @@ func (w *Watcher) reloadConfig() bool {
}
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix
forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelMappings, newConfig.OAuthModelMappings))
log.Infof("config successfully reloaded, triggering client reload")
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)

View File

@@ -90,6 +90,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
}
oldModels := SummarizeGeminiModels(o.Models)
newModels := SummarizeGeminiModels(n.Models)
if oldModels.hash != newModels.hash {
changes = append(changes, fmt.Sprintf("gemini[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
}
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
if oldExcluded.hash != newExcluded.hash {
@@ -120,6 +125,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
}
oldModels := SummarizeClaudeModels(o.Models)
newModels := SummarizeClaudeModels(n.Models)
if oldModels.hash != newModels.hash {
changes = append(changes, fmt.Sprintf("claude[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
}
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
if oldExcluded.hash != newExcluded.hash {
@@ -150,6 +160,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if !equalStringMap(o.Headers, n.Headers) {
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
}
oldModels := SummarizeCodexModels(o.Models)
newModels := SummarizeCodexModels(n.Models)
if oldModels.hash != newModels.hash {
changes = append(changes, fmt.Sprintf("codex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
}
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
if oldExcluded.hash != newExcluded.hash {
@@ -185,10 +200,18 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings {
changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings))
}
oldUpstreamAPIKeysCount := len(oldCfg.AmpCode.UpstreamAPIKeys)
newUpstreamAPIKeysCount := len(newCfg.AmpCode.UpstreamAPIKeys)
if !equalUpstreamAPIKeys(oldCfg.AmpCode.UpstreamAPIKeys, newCfg.AmpCode.UpstreamAPIKeys) {
changes = append(changes, fmt.Sprintf("ampcode.upstream-api-keys: updated (%d -> %d entries)", oldUpstreamAPIKeysCount, newUpstreamAPIKeysCount))
}
if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
changes = append(changes, entries...)
}
if entries, _ := DiffOAuthModelMappingChanges(oldCfg.OAuthModelMappings, newCfg.OAuthModelMappings); len(entries) > 0 {
changes = append(changes, entries...)
}
// Remote management (never print the key)
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
@@ -301,3 +324,43 @@ func formatProxyURL(raw string) string {
}
return scheme + "://" + host
}
func equalStringSet(a, b []string) bool {
if len(a) == 0 && len(b) == 0 {
return true
}
aSet := make(map[string]struct{}, len(a))
for _, k := range a {
aSet[strings.TrimSpace(k)] = struct{}{}
}
bSet := make(map[string]struct{}, len(b))
for _, k := range b {
bSet[strings.TrimSpace(k)] = struct{}{}
}
if len(aSet) != len(bSet) {
return false
}
for k := range aSet {
if _, ok := bSet[k]; !ok {
return false
}
}
return true
}
// equalUpstreamAPIKeys compares two slices of AmpUpstreamAPIKeyEntry for equality.
// Comparison is done by count and content (upstream key and client keys).
func equalUpstreamAPIKeys(a, b []config.AmpUpstreamAPIKeyEntry) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if strings.TrimSpace(a[i].UpstreamAPIKey) != strings.TrimSpace(b[i].UpstreamAPIKey) {
return false
}
if !equalStringSet(a[i].APIKeys, b[i].APIKeys) {
return false
}
}
return true
}

View File

@@ -56,6 +56,36 @@ func ComputeClaudeModelsHash(models []config.ClaudeModel) string {
return hashJoined(keys)
}
// ComputeCodexModelsHash returns a stable hash for Codex model aliases.
func ComputeCodexModelsHash(models []config.CodexModel) string {
keys := normalizeModelPairs(func(out func(key string)) {
for _, model := range models {
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if name == "" && alias == "" {
continue
}
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
}
})
return hashJoined(keys)
}
// ComputeGeminiModelsHash returns a stable hash for Gemini model aliases.
func ComputeGeminiModelsHash(models []config.GeminiModel) string {
keys := normalizeModelPairs(func(out func(key string)) {
for _, model := range models {
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if name == "" && alias == "" {
continue
}
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
}
})
return hashJoined(keys)
}
// ComputeExcludedModelsHash returns a normalized hash for excluded model lists.
func ComputeExcludedModelsHash(excluded []string) string {
if len(excluded) == 0 {

View File

@@ -81,6 +81,15 @@ func TestComputeClaudeModelsHash_Empty(t *testing.T) {
}
}
func TestComputeCodexModelsHash_Empty(t *testing.T) {
if got := ComputeCodexModelsHash(nil); got != "" {
t.Fatalf("expected empty hash for nil models, got %q", got)
}
if got := ComputeCodexModelsHash([]config.CodexModel{}); got != "" {
t.Fatalf("expected empty hash for empty slice, got %q", got)
}
}
func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
a := []config.ClaudeModel{
{Name: "m1", Alias: "a1"},
@@ -95,6 +104,20 @@ func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
}
}
func TestComputeCodexModelsHash_IgnoresBlankAndDedup(t *testing.T) {
a := []config.CodexModel{
{Name: "m1", Alias: "a1"},
{Name: " "},
{Name: "M1", Alias: "A1"},
}
b := []config.CodexModel{
{Name: "m1", Alias: "a1"},
}
if h1, h2 := ComputeCodexModelsHash(a), ComputeCodexModelsHash(b); h1 == "" || h1 != h2 {
t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2)
}
}
func TestComputeExcludedModelsHash_Normalizes(t *testing.T) {
hash1 := ComputeExcludedModelsHash([]string{" A ", "b", "a"})
hash2 := ComputeExcludedModelsHash([]string{"a", " b", "A"})
@@ -157,3 +180,15 @@ func TestComputeClaudeModelsHash_Deterministic(t *testing.T) {
t.Fatalf("expected different hash when models change, got %s", h3)
}
}
func TestComputeCodexModelsHash_Deterministic(t *testing.T) {
models := []config.CodexModel{{Name: "a", Alias: "A"}, {Name: "b"}}
h1 := ComputeCodexModelsHash(models)
h2 := ComputeCodexModelsHash(models)
if h1 == "" || h1 != h2 {
t.Fatalf("expected deterministic hash, got %s / %s", h1, h2)
}
if h3 := ComputeCodexModelsHash([]config.CodexModel{{Name: "a"}}); h3 == h1 {
t.Fatalf("expected different hash when models change, got %s", h3)
}
}

View File

@@ -0,0 +1,121 @@
package diff
import (
"crypto/sha256"
"encoding/hex"
"sort"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
type GeminiModelsSummary struct {
hash string
count int
}
type ClaudeModelsSummary struct {
hash string
count int
}
type CodexModelsSummary struct {
hash string
count int
}
type VertexModelsSummary struct {
hash string
count int
}
// SummarizeGeminiModels hashes Gemini model aliases for change detection.
func SummarizeGeminiModels(models []config.GeminiModel) GeminiModelsSummary {
if len(models) == 0 {
return GeminiModelsSummary{}
}
keys := normalizeModelPairs(func(out func(key string)) {
for _, model := range models {
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if name == "" && alias == "" {
continue
}
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
}
})
return GeminiModelsSummary{
hash: hashJoined(keys),
count: len(keys),
}
}
// SummarizeClaudeModels hashes Claude model aliases for change detection.
func SummarizeClaudeModels(models []config.ClaudeModel) ClaudeModelsSummary {
if len(models) == 0 {
return ClaudeModelsSummary{}
}
keys := normalizeModelPairs(func(out func(key string)) {
for _, model := range models {
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if name == "" && alias == "" {
continue
}
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
}
})
return ClaudeModelsSummary{
hash: hashJoined(keys),
count: len(keys),
}
}
// SummarizeCodexModels hashes Codex model aliases for change detection.
func SummarizeCodexModels(models []config.CodexModel) CodexModelsSummary {
if len(models) == 0 {
return CodexModelsSummary{}
}
keys := normalizeModelPairs(func(out func(key string)) {
for _, model := range models {
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if name == "" && alias == "" {
continue
}
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
}
})
return CodexModelsSummary{
hash: hashJoined(keys),
count: len(keys),
}
}
// SummarizeVertexModels hashes Vertex-compatible model aliases for change detection.
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
if len(models) == 0 {
return VertexModelsSummary{}
}
names := make([]string, 0, len(models))
for _, model := range models {
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if name == "" && alias == "" {
continue
}
if alias != "" {
name = alias
}
names = append(names, name)
}
if len(names) == 0 {
return VertexModelsSummary{}
}
sort.Strings(names)
sum := sha256.Sum256([]byte(strings.Join(names, "|")))
return VertexModelsSummary{
hash: hex.EncodeToString(sum[:]),
count: len(names),
}
}

View File

@@ -116,36 +116,3 @@ func SummarizeAmpModelMappings(mappings []config.AmpModelMapping) AmpModelMappin
count: len(entries),
}
}
type VertexModelsSummary struct {
hash string
count int
}
// SummarizeVertexModels hashes vertex-compatible models for change detection.
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
if len(models) == 0 {
return VertexModelsSummary{}
}
names := make([]string, 0, len(models))
for _, m := range models {
name := strings.TrimSpace(m.Name)
alias := strings.TrimSpace(m.Alias)
if name == "" && alias == "" {
continue
}
if alias != "" {
name = alias
}
names = append(names, name)
}
if len(names) == 0 {
return VertexModelsSummary{}
}
sort.Strings(names)
sum := sha256.Sum256([]byte(strings.Join(names, "|")))
return VertexModelsSummary{
hash: hex.EncodeToString(sum[:]),
count: len(names),
}
}

View File

@@ -0,0 +1,98 @@
package diff
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"sort"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
type OAuthModelMappingsSummary struct {
hash string
count int
}
// SummarizeOAuthModelMappings summarizes OAuth model mappings per channel.
func SummarizeOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string]OAuthModelMappingsSummary {
if len(entries) == 0 {
return nil
}
out := make(map[string]OAuthModelMappingsSummary, len(entries))
for k, v := range entries {
key := strings.ToLower(strings.TrimSpace(k))
if key == "" {
continue
}
out[key] = summarizeOAuthModelMappingList(v)
}
if len(out) == 0 {
return nil
}
return out
}
// DiffOAuthModelMappingChanges compares OAuth model mappings maps.
func DiffOAuthModelMappingChanges(oldMap, newMap map[string][]config.ModelNameMapping) ([]string, []string) {
oldSummary := SummarizeOAuthModelMappings(oldMap)
newSummary := SummarizeOAuthModelMappings(newMap)
keys := make(map[string]struct{}, len(oldSummary)+len(newSummary))
for k := range oldSummary {
keys[k] = struct{}{}
}
for k := range newSummary {
keys[k] = struct{}{}
}
changes := make([]string, 0, len(keys))
affected := make([]string, 0, len(keys))
for key := range keys {
oldInfo, okOld := oldSummary[key]
newInfo, okNew := newSummary[key]
switch {
case okOld && !okNew:
changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: removed", key))
affected = append(affected, key)
case !okOld && okNew:
changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: added (%d entries)", key, newInfo.count))
affected = append(affected, key)
case okOld && okNew && oldInfo.hash != newInfo.hash:
changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count))
affected = append(affected, key)
}
}
sort.Strings(changes)
sort.Strings(affected)
return changes, affected
}
func summarizeOAuthModelMappingList(list []config.ModelNameMapping) OAuthModelMappingsSummary {
if len(list) == 0 {
return OAuthModelMappingsSummary{}
}
seen := make(map[string]struct{}, len(list))
normalized := make([]string, 0, len(list))
for _, mapping := range list {
name := strings.ToLower(strings.TrimSpace(mapping.Name))
alias := strings.ToLower(strings.TrimSpace(mapping.Alias))
if name == "" || alias == "" {
continue
}
key := name + "->" + alias
if _, exists := seen[key]; exists {
continue
}
seen[key] = struct{}{}
normalized = append(normalized, key)
}
if len(normalized) == 0 {
return OAuthModelMappingsSummary{}
}
sort.Strings(normalized)
sum := sha256.Sum256([]byte(strings.Join(normalized, "|")))
return OAuthModelMappingsSummary{
hash: hex.EncodeToString(sum[:]),
count: len(normalized),
}
}

View File

@@ -66,6 +66,9 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea
if base != "" {
attrs["base_url"] = base
}
if hash := diff.ComputeGeminiModelsHash(entry.Models); hash != "" {
attrs["models_hash"] = hash
}
addConfigHeadersToAttrs(entry.Headers, attrs)
a := &coreauth.Auth{
ID: id,
@@ -151,6 +154,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
if ck.BaseURL != "" {
attrs["base_url"] = ck.BaseURL
}
if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
attrs["models_hash"] = hash
}
addConfigHeadersToAttrs(ck.Headers, attrs)
proxyURL := strings.TrimSpace(ck.ProxyURL)
a := &coreauth.Auth{

View File

@@ -104,8 +104,8 @@ func BuildErrorResponseBody(status int, errText string) []byte {
// Returning 0 disables keep-alives (default when unset).
func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
seconds := defaultStreamingKeepAliveSeconds
if cfg != nil && cfg.Streaming.KeepAliveSeconds != nil {
seconds = *cfg.Streaming.KeepAliveSeconds
if cfg != nil {
seconds = cfg.Streaming.KeepAliveSeconds
}
if seconds <= 0 {
return 0
@@ -116,8 +116,8 @@ func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
retries := defaultStreamingBootstrapRetries
if cfg != nil && cfg.Streaming.BootstrapRetries != nil {
retries = *cfg.Streaming.BootstrapRetries
if cfg != nil {
retries = cfg.Streaming.BootstrapRetries
}
if retries < 0 {
retries = 0
@@ -619,7 +619,22 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
}
body := BuildErrorResponseBody(status, errText)
c.Set("API_RESPONSE", bytes.Clone(body))
// Append first to preserve upstream response logs, then drop duplicate payloads if already recorded.
var previous []byte
if existing, exists := c.Get("API_RESPONSE"); exists {
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
previous = bytes.Clone(existingBytes)
}
}
appendAPIResponse(c, body)
trimmedErrText := strings.TrimSpace(errText)
trimmedBody := bytes.TrimSpace(body)
if len(previous) > 0 {
if (trimmedErrText != "" && bytes.Contains(previous, []byte(trimmedErrText))) ||
(len(trimmedBody) > 0 && bytes.Contains(previous, trimmedBody)) {
c.Set("API_RESPONSE", previous)
}
}
if !c.Writer.Written() {
c.Writer.Header().Set("Content-Type", "application/json")

View File

@@ -94,12 +94,11 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
})
bootstrapRetries := 1
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: &bootstrapRetries,
BootstrapRetries: 1,
},
}, manager, nil)
}, manager)
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")

67
sdk/api/management.go Normal file
View File

@@ -0,0 +1,67 @@
// Package api exposes helpers for embedding CLIProxyAPI.
//
// It wraps internal management handler types so external projects can integrate
// management endpoints without importing internal packages.
package api
import (
"github.com/gin-gonic/gin"
internalmanagement "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// ManagementTokenRequester exposes a limited subset of management endpoints for requesting tokens.
type ManagementTokenRequester interface {
RequestAnthropicToken(*gin.Context)
RequestGeminiCLIToken(*gin.Context)
RequestCodexToken(*gin.Context)
RequestAntigravityToken(*gin.Context)
RequestQwenToken(*gin.Context)
RequestIFlowToken(*gin.Context)
RequestIFlowCookieToken(*gin.Context)
GetAuthStatus(c *gin.Context)
}
type managementTokenRequester struct {
handler *internalmanagement.Handler
}
// NewManagementTokenRequester creates a limited management handler exposing only token request endpoints.
func NewManagementTokenRequester(cfg *config.Config, manager *coreauth.Manager) ManagementTokenRequester {
return &managementTokenRequester{
handler: internalmanagement.NewHandlerWithoutConfigFilePath(cfg, manager),
}
}
func (m *managementTokenRequester) RequestAnthropicToken(c *gin.Context) {
m.handler.RequestAnthropicToken(c)
}
func (m *managementTokenRequester) RequestGeminiCLIToken(c *gin.Context) {
m.handler.RequestGeminiCLIToken(c)
}
func (m *managementTokenRequester) RequestCodexToken(c *gin.Context) {
m.handler.RequestCodexToken(c)
}
func (m *managementTokenRequester) RequestAntigravityToken(c *gin.Context) {
m.handler.RequestAntigravityToken(c)
}
func (m *managementTokenRequester) RequestQwenToken(c *gin.Context) {
m.handler.RequestQwenToken(c)
}
func (m *managementTokenRequester) RequestIFlowToken(c *gin.Context) {
m.handler.RequestIFlowToken(c)
}
func (m *managementTokenRequester) RequestIFlowCookieToken(c *gin.Context) {
m.handler.RequestIFlowCookieToken(c)
}
func (m *managementTokenRequester) GetAuthStatus(c *gin.Context) {
m.handler.GetAuthStatus(c)
}

View File

@@ -111,6 +111,9 @@ type Manager struct {
requestRetry atomic.Int32
maxRetryInterval atomic.Int64
// modelNameMappings stores global model name alias mappings (alias -> upstream name) keyed by channel.
modelNameMappings atomic.Value
// Optional HTTP RoundTripper provider injected by host.
rtProvider RoundTripperProvider
@@ -203,10 +206,10 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
if auth == nil {
return nil, nil
}
auth.EnsureIndex()
if auth.ID == "" {
auth.ID = uuid.NewString()
}
auth.EnsureIndex()
m.mu.Lock()
m.auths[auth.ID] = auth.Clone()
m.mu.Unlock()
@@ -221,7 +224,7 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
return nil, nil
}
m.mu.Lock()
if existing, ok := m.auths[auth.ID]; ok && existing != nil && !auth.indexAssigned && auth.Index == 0 {
if existing, ok := m.auths[auth.ID]; ok && existing != nil && !auth.indexAssigned && auth.Index == "" {
auth.Index = existing.Index
auth.indexAssigned = existing.indexAssigned
}
@@ -263,7 +266,6 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
rotated := m.rotateProviders(req.Model, normalized)
defer m.advanceProviderCursor(req.Model, normalized)
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
@@ -302,7 +304,6 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
rotated := m.rotateProviders(req.Model, normalized)
defer m.advanceProviderCursor(req.Model, normalized)
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
@@ -341,7 +342,6 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
rotated := m.rotateProviders(req.Model, normalized)
defer m.advanceProviderCursor(req.Model, normalized)
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
@@ -413,6 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
@@ -474,6 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
@@ -535,6 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
rerr := &Error{Message: errStream.Error()}
@@ -595,6 +598,7 @@ func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]
keys := []string{
util.ThinkingOriginalModelMetadataKey,
util.GeminiOriginalModelMetadataKey,
util.ModelMappingOriginalModelMetadataKey,
}
var out map[string]any
for _, key := range keys {
@@ -640,13 +644,20 @@ func (m *Manager) normalizeProviders(providers []string) []string {
return result
}
// rotateProviders returns a rotated view of the providers list starting from the
// current offset for the model, and atomically increments the offset for the next call.
// This ensures concurrent requests get different starting providers.
func (m *Manager) rotateProviders(model string, providers []string) []string {
if len(providers) == 0 {
return nil
}
m.mu.RLock()
// Atomic read-and-increment: get current offset and advance cursor in one lock
m.mu.Lock()
offset := m.providerOffsets[model]
m.mu.RUnlock()
m.providerOffsets[model] = (offset + 1) % len(providers)
m.mu.Unlock()
if len(providers) > 0 {
offset %= len(providers)
}
@@ -662,19 +673,6 @@ func (m *Manager) rotateProviders(model string, providers []string) []string {
return rotated
}
func (m *Manager) advanceProviderCursor(model string, providers []string) {
if len(providers) == 0 {
m.mu.Lock()
delete(m.providerOffsets, model)
m.mu.Unlock()
return
}
m.mu.Lock()
current := m.providerOffsets[model]
m.providerOffsets[model] = (current + 1) % len(providers)
m.mu.Unlock()
}
func (m *Manager) retrySettings() (int, time.Duration) {
if m == nil {
return 0, 0

View File

@@ -0,0 +1,171 @@
package auth
import (
"strings"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
type modelNameMappingTable struct {
// reverse maps channel -> alias (lower) -> original upstream model name.
reverse map[string]map[string]string
}
func compileModelNameMappingTable(mappings map[string][]internalconfig.ModelNameMapping) *modelNameMappingTable {
if len(mappings) == 0 {
return &modelNameMappingTable{}
}
out := &modelNameMappingTable{
reverse: make(map[string]map[string]string, len(mappings)),
}
for rawChannel, entries := range mappings {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(entries) == 0 {
continue
}
rev := make(map[string]string, len(entries))
for _, entry := range entries {
name := strings.TrimSpace(entry.Name)
alias := strings.TrimSpace(entry.Alias)
if name == "" || alias == "" {
continue
}
if strings.EqualFold(name, alias) {
continue
}
aliasKey := strings.ToLower(alias)
if _, exists := rev[aliasKey]; exists {
continue
}
rev[aliasKey] = name
}
if len(rev) > 0 {
out.reverse[channel] = rev
}
}
if len(out.reverse) == 0 {
out.reverse = nil
}
return out
}
// SetOAuthModelMappings updates the OAuth model name mapping table used during execution.
// The mapping is applied per-auth channel to resolve the upstream model name while keeping the
// client-visible model name unchanged for translation/response formatting.
func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.ModelNameMapping) {
if m == nil {
return
}
table := compileModelNameMappingTable(mappings)
// atomic.Value requires non-nil store values.
if table == nil {
table = &modelNameMappingTable{}
}
m.modelNameMappings.Store(table)
}
// applyOAuthModelMapping resolves the upstream model from OAuth model mappings
// and returns the resolved model along with updated metadata. If a mapping exists,
// the returned model is the upstream model and metadata contains the original
// requested model for response translation.
func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) {
upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel)
if upstreamModel == "" {
return requestedModel, metadata
}
out := make(map[string]any, 1)
if len(metadata) > 0 {
out = make(map[string]any, len(metadata)+1)
for k, v := range metadata {
out[k] = v
}
}
// Store the requested alias (e.g., "gp") so downstream can use it to look up
// model metadata from the global registry where it was registered under this alias.
out[util.ModelMappingOriginalModelMetadataKey] = requestedModel
return upstreamModel, out
}
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {
if m == nil || auth == nil {
return ""
}
channel := modelMappingChannel(auth)
if channel == "" {
return ""
}
key := strings.ToLower(strings.TrimSpace(requestedModel))
if key == "" {
return ""
}
raw := m.modelNameMappings.Load()
table, _ := raw.(*modelNameMappingTable)
if table == nil || table.reverse == nil {
return ""
}
rev := table.reverse[channel]
if rev == nil {
return ""
}
original := strings.TrimSpace(rev[key])
if original == "" || strings.EqualFold(original, requestedModel) {
return ""
}
return original
}
// modelMappingChannel extracts the OAuth model mapping channel from an Auth object.
// It determines the provider and auth kind from the Auth's attributes and delegates
// to OAuthModelMappingChannel for the actual channel resolution.
func modelMappingChannel(auth *Auth) string {
if auth == nil {
return ""
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
authKind := ""
if auth.Attributes != nil {
authKind = strings.ToLower(strings.TrimSpace(auth.Attributes["auth_kind"]))
}
if authKind == "" {
if kind, _ := auth.AccountInfo(); strings.EqualFold(kind, "api_key") {
authKind = "apikey"
}
}
return OAuthModelMappingChannel(provider, authKind)
}
// OAuthModelMappingChannel returns the OAuth model mapping channel name for a given provider
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
// OAuth model mappings (e.g., API key authentication).
//
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
func OAuthModelMappingChannel(provider, authKind string) string {
provider = strings.ToLower(strings.TrimSpace(provider))
authKind = strings.ToLower(strings.TrimSpace(authKind))
switch provider {
case "gemini":
// gemini provider uses gemini-api-key config, not oauth-model-mappings.
// OAuth-based gemini auth is converted to "gemini-cli" by the synthesizer.
return ""
case "vertex":
if authKind == "apikey" {
return ""
}
return "vertex"
case "claude":
if authKind == "apikey" {
return ""
}
return "claude"
case "codex":
if authKind == "apikey" {
return ""
}
return "codex"
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow":
return provider
default:
return ""
}
}

View File

@@ -1,11 +1,12 @@
package auth
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
@@ -15,8 +16,8 @@ import (
type Auth struct {
// ID uniquely identifies the auth record across restarts.
ID string `json:"id"`
// Index is a monotonically increasing runtime identifier used for diagnostics.
Index uint64 `json:"-"`
// Index is a stable runtime identifier derived from auth metadata (not persisted).
Index string `json:"-"`
// Provider is the upstream provider key (e.g. "gemini", "claude").
Provider string `json:"provider"`
// Prefix optionally namespaces models for routing (e.g., "teamA/gemini-3-pro-preview").
@@ -94,12 +95,6 @@ type ModelState struct {
UpdatedAt time.Time `json:"updated_at"`
}
var authIndexCounter atomic.Uint64
func nextAuthIndex() uint64 {
return authIndexCounter.Add(1) - 1
}
// Clone shallow copies the Auth structure, duplicating maps to avoid accidental mutation.
func (a *Auth) Clone() *Auth {
if a == nil {
@@ -128,15 +123,41 @@ func (a *Auth) Clone() *Auth {
return &copyAuth
}
// EnsureIndex returns the global index, assigning one if it was not set yet.
func (a *Auth) EnsureIndex() uint64 {
if a == nil {
return 0
func stableAuthIndex(seed string) string {
seed = strings.TrimSpace(seed)
if seed == "" {
return ""
}
if a.indexAssigned {
sum := sha256.Sum256([]byte(seed))
return hex.EncodeToString(sum[:8])
}
// EnsureIndex returns a stable index derived from the auth file name or API key.
func (a *Auth) EnsureIndex() string {
if a == nil {
return ""
}
if a.indexAssigned && a.Index != "" {
return a.Index
}
idx := nextAuthIndex()
seed := strings.TrimSpace(a.FileName)
if seed != "" {
seed = "file:" + seed
} else if a.Attributes != nil {
if apiKey := strings.TrimSpace(a.Attributes["api_key"]); apiKey != "" {
seed = "api_key:" + apiKey
}
}
if seed == "" {
if id := strings.TrimSpace(a.ID); id != "" {
seed = "id:" + id
} else {
return ""
}
}
idx := stableAuthIndex(seed)
a.Index = idx
a.indexAssigned = true
return idx

View File

@@ -215,6 +215,7 @@ func (b *Builder) Build() (*Service, error) {
}
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.
coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider())
coreManager.SetOAuthModelMappings(b.cfg.OAuthModelMappings)
service := &Service{
cfg: b.cfg,

View File

@@ -13,6 +13,7 @@ type ModelRegistry interface {
ClearModelQuotaExceeded(clientID, modelID string)
ClientSupportsModel(clientID, modelID string) bool
GetAvailableModels(handlerType string) []map[string]any
GetAvailableModelsByProvider(provider string) []*ModelInfo
}
// GlobalModelRegistry returns the shared registry instance.

View File

@@ -556,6 +556,9 @@ func (s *Service) Run(ctx context.Context) error {
s.cfgMu.Lock()
s.cfg = newCfg
s.cfgMu.Unlock()
if s.coreManager != nil {
s.coreManager.SetOAuthModelMappings(newCfg.OAuthModelMappings)
}
s.rebindExecutors()
}
@@ -681,6 +684,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
return
}
authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"]))
if authKind == "" {
if kind, _ := a.AccountInfo(); strings.EqualFold(kind, "api_key") {
authKind = "apikey"
}
}
if a.Attributes != nil {
if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") {
GlobalModelRegistry().UnregisterClient(a.ID)
@@ -706,6 +714,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "gemini":
models = registry.GetGeminiModels()
if entry := s.resolveConfigGeminiKey(a); entry != nil {
if len(entry.Models) > 0 {
models = buildGeminiConfigModels(entry)
}
if authKind == "apikey" {
excluded = entry.ExcludedModels
}
@@ -745,6 +756,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "codex":
models = registry.GetOpenAIModels()
if entry := s.resolveConfigCodexKey(a); entry != nil {
if len(entry.Models) > 0 {
models = buildCodexConfigModels(entry)
}
if authKind == "apikey" {
excluded = entry.ExcludedModels
}
@@ -842,6 +856,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
}
}
}
models = applyOAuthModelMappings(s.cfg, provider, authKind, models)
if len(models) > 0 {
key := provider
if key == "" {
@@ -1113,17 +1128,22 @@ func matchWildcard(pattern, value string) bool {
return true
}
func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
if entry == nil || len(entry.Models) == 0 {
type modelEntry interface {
GetName() string
GetAlias() string
}
func buildConfigModels[T modelEntry](models []T, ownedBy, modelType string) []*ModelInfo {
if len(models) == 0 {
return nil
}
now := time.Now().Unix()
out := make([]*ModelInfo, 0, len(entry.Models))
seen := make(map[string]struct{}, len(entry.Models))
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
out := make([]*ModelInfo, 0, len(models))
seen := make(map[string]struct{}, len(models))
for i := range models {
model := models[i]
name := strings.TrimSpace(model.GetName())
alias := strings.TrimSpace(model.GetAlias())
if alias == "" {
alias = name
}
@@ -1139,52 +1159,135 @@ func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
if display == "" {
display = alias
}
out = append(out, &ModelInfo{
info := &ModelInfo{
ID: alias,
Object: "model",
Created: now,
OwnedBy: "vertex",
Type: "vertex",
OwnedBy: ownedBy,
Type: modelType,
DisplayName: display,
})
}
if name != "" {
if upstream := registry.LookupStaticModelInfo(name); upstream != nil && upstream.Thinking != nil {
info.Thinking = upstream.Thinking
}
}
out = append(out, info)
}
return out
}
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
if entry == nil || len(entry.Models) == 0 {
func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
if entry == nil {
return nil
}
now := time.Now().Unix()
out := make([]*ModelInfo, 0, len(entry.Models))
seen := make(map[string]struct{}, len(entry.Models))
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
alias := strings.TrimSpace(model.Alias)
if alias == "" {
alias = name
}
if alias == "" {
return buildConfigModels(entry.Models, "google", "vertex")
}
func buildGeminiConfigModels(entry *config.GeminiKey) []*ModelInfo {
if entry == nil {
return nil
}
return buildConfigModels(entry.Models, "google", "gemini")
}
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
if entry == nil {
return nil
}
return buildConfigModels(entry.Models, "anthropic", "claude")
}
func buildCodexConfigModels(entry *config.CodexKey) []*ModelInfo {
if entry == nil {
return nil
}
return buildConfigModels(entry.Models, "openai", "openai")
}
func rewriteModelInfoName(name, oldID, newID string) string {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return name
}
oldID = strings.TrimSpace(oldID)
newID = strings.TrimSpace(newID)
if oldID == "" || newID == "" {
return name
}
if strings.EqualFold(oldID, newID) {
return name
}
if strings.HasSuffix(trimmed, "/"+oldID) {
prefix := strings.TrimSuffix(trimmed, oldID)
return prefix + newID
}
if trimmed == "models/"+oldID {
return "models/" + newID
}
return name
}
func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo {
if cfg == nil || len(models) == 0 {
return models
}
channel := coreauth.OAuthModelMappingChannel(provider, authKind)
if channel == "" || len(cfg.OAuthModelMappings) == 0 {
return models
}
mappings := cfg.OAuthModelMappings[channel]
if len(mappings) == 0 {
return models
}
forward := make(map[string]string, len(mappings))
for i := range mappings {
name := strings.TrimSpace(mappings[i].Name)
alias := strings.TrimSpace(mappings[i].Alias)
if name == "" || alias == "" {
continue
}
key := strings.ToLower(alias)
if _, exists := seen[key]; exists {
if strings.EqualFold(name, alias) {
continue
}
seen[key] = struct{}{}
display := name
if display == "" {
display = alias
key := strings.ToLower(name)
if _, exists := forward[key]; exists {
continue
}
out = append(out, &ModelInfo{
ID: alias,
Object: "model",
Created: now,
OwnedBy: "claude",
Type: "claude",
DisplayName: display,
})
forward[key] = alias
}
if len(forward) == 0 {
return models
}
out := make([]*ModelInfo, 0, len(models))
seen := make(map[string]struct{}, len(models))
for _, model := range models {
if model == nil {
continue
}
id := strings.TrimSpace(model.ID)
if id == "" {
continue
}
mappedID := id
if to, ok := forward[strings.ToLower(id)]; ok && strings.TrimSpace(to) != "" {
mappedID = strings.TrimSpace(to)
}
uniqueKey := strings.ToLower(mappedID)
if _, exists := seen[uniqueKey]; exists {
continue
}
seen[uniqueKey] = struct{}{}
if mappedID == id {
out = append(out, model)
continue
}
clone := *model
clone.ID = mappedID
if clone.Name != "" {
clone.Name = rewriteModelInfoName(clone.Name, id, mappedID)
}
out = append(out, &clone)
}
return out
}

View File

@@ -14,7 +14,7 @@ type Record struct {
Model string
APIKey string
AuthID string
AuthIndex uint64
AuthIndex string
Source string
RequestedAt time.Time
Failed bool

View File

@@ -16,6 +16,7 @@ type StreamingConfig = internalconfig.StreamingConfig
type TLSConfig = internalconfig.TLSConfig
type RemoteManagement = internalconfig.RemoteManagement
type AmpCode = internalconfig.AmpCode
type ModelNameMapping = internalconfig.ModelNameMapping
type PayloadConfig = internalconfig.PayloadConfig
type PayloadRule = internalconfig.PayloadRule
type PayloadModelRule = internalconfig.PayloadModelRule

View File

@@ -56,6 +56,10 @@ func setupAmpRouter(h *management.Handler) *gin.Engine {
mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey)
mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey)
mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey)
mgmt.GET("/ampcode/upstream-api-keys", h.GetAmpUpstreamAPIKeys)
mgmt.PUT("/ampcode/upstream-api-keys", h.PutAmpUpstreamAPIKeys)
mgmt.PATCH("/ampcode/upstream-api-keys", h.PatchAmpUpstreamAPIKeys)
mgmt.DELETE("/ampcode/upstream-api-keys", h.DeleteAmpUpstreamAPIKeys)
mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost)
mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost)
mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings)
@@ -188,6 +192,90 @@ func TestPutAmpUpstreamAPIKey(t *testing.T) {
}
}
func TestPutAmpUpstreamAPIKeys_PersistsAndReturns(t *testing.T) {
h, configPath := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value":[{"upstream-api-key":" u1 ","api-keys":[" k1 ","","k2"]}]}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
}
// Verify it was persisted to disk
loaded, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("failed to load config from disk: %v", err)
}
if len(loaded.AmpCode.UpstreamAPIKeys) != 1 {
t.Fatalf("expected 1 upstream-api-keys entry, got %d", len(loaded.AmpCode.UpstreamAPIKeys))
}
entry := loaded.AmpCode.UpstreamAPIKeys[0]
if entry.UpstreamAPIKey != "u1" {
t.Fatalf("expected upstream-api-key u1, got %q", entry.UpstreamAPIKey)
}
if len(entry.APIKeys) != 2 || entry.APIKeys[0] != "k1" || entry.APIKeys[1] != "k2" {
t.Fatalf("expected api-keys [k1 k2], got %#v", entry.APIKeys)
}
// Verify it is returned by GET /ampcode
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]config.AmpCode
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if got := resp["ampcode"].UpstreamAPIKeys; len(got) != 1 || got[0].UpstreamAPIKey != "u1" {
t.Fatalf("expected upstream-api-keys to be present after update, got %#v", got)
}
}
func TestDeleteAmpUpstreamAPIKeys_ClearsAll(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
// Seed with one entry
putBody := `{"value":[{"upstream-api-key":"u1","api-keys":["k1"]}]}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(putBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
}
deleteBody := `{"value":[]}`
req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(deleteBody))
req.Header.Set("Content-Type", "application/json")
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-keys", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string][]config.AmpUpstreamAPIKeyEntry
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp["upstream-api-keys"] != nil && len(resp["upstream-api-keys"]) != 0 {
t.Fatalf("expected cleared list, got %#v", resp["upstream-api-keys"])
}
}
// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key.
func TestDeleteAmpUpstreamAPIKey(t *testing.T) {
h, _ := newAmpTestHandler(t)

View File

@@ -0,0 +1,211 @@
package test
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
)
// TestModelAliasThinkingSuffix tests the 32 test cases defined in docs/thinking_suffix_test_cases.md
// These tests verify the thinking suffix parsing and application logic across different providers.
func TestModelAliasThinkingSuffix(t *testing.T) {
tests := []struct {
id int
name string
provider string
requestModel string
suffixType string
expectedField string // "thinkingBudget", "thinkingLevel", "budget_tokens", "reasoning_effort", "enable_thinking"
expectedValue any
upstreamModel string // The upstream model after alias resolution
isAlias bool
}{
// === 1. Antigravity Provider ===
// 1.1 Budget-only models (Gemini 2.5)
{1, "antigravity_original_numeric", "antigravity", "gemini-2.5-computer-use-preview-10-2025(1000)", "numeric", "thinkingBudget", 1000, "gemini-2.5-computer-use-preview-10-2025", false},
{2, "antigravity_alias_numeric", "antigravity", "gp(1000)", "numeric", "thinkingBudget", 1000, "gemini-2.5-computer-use-preview-10-2025", true},
// 1.2 Budget+Levels models (Gemini 3)
{3, "antigravity_original_numeric_to_level", "antigravity", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
{4, "antigravity_original_level", "antigravity", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
{5, "antigravity_alias_numeric_to_level", "antigravity", "gf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
{6, "antigravity_alias_level", "antigravity", "gf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
// === 2. Gemini CLI Provider ===
// 2.1 Budget-only models
{7, "gemini_cli_original_numeric", "gemini-cli", "gemini-2.5-pro(8192)", "numeric", "thinkingBudget", 8192, "gemini-2.5-pro", false},
{8, "gemini_cli_alias_numeric", "gemini-cli", "g25p(8192)", "numeric", "thinkingBudget", 8192, "gemini-2.5-pro", true},
// 2.2 Budget+Levels models
{9, "gemini_cli_original_numeric_to_level", "gemini-cli", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
{10, "gemini_cli_original_level", "gemini-cli", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
{11, "gemini_cli_alias_numeric_to_level", "gemini-cli", "gf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
{12, "gemini_cli_alias_level", "gemini-cli", "gf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
// === 3. Vertex Provider ===
// 3.1 Budget-only models
{13, "vertex_original_numeric", "vertex", "gemini-2.5-pro(16384)", "numeric", "thinkingBudget", 16384, "gemini-2.5-pro", false},
{14, "vertex_alias_numeric", "vertex", "vg25p(16384)", "numeric", "thinkingBudget", 16384, "gemini-2.5-pro", true},
// 3.2 Budget+Levels models
{15, "vertex_original_numeric_to_level", "vertex", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
{16, "vertex_original_level", "vertex", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
{17, "vertex_alias_numeric_to_level", "vertex", "vgf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
{18, "vertex_alias_level", "vertex", "vgf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
// === 4. AI Studio Provider ===
// 4.1 Budget-only models
{19, "aistudio_original_numeric", "aistudio", "gemini-2.5-pro(12000)", "numeric", "thinkingBudget", 12000, "gemini-2.5-pro", false},
{20, "aistudio_alias_numeric", "aistudio", "ag25p(12000)", "numeric", "thinkingBudget", 12000, "gemini-2.5-pro", true},
// 4.2 Budget+Levels models
{21, "aistudio_original_numeric_to_level", "aistudio", "gemini-3-flash-preview(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", false},
{22, "aistudio_original_level", "aistudio", "gemini-3-flash-preview(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", false},
{23, "aistudio_alias_numeric_to_level", "aistudio", "agf(1000)", "numeric", "thinkingLevel", "low", "gemini-3-flash-preview", true},
{24, "aistudio_alias_level", "aistudio", "agf(low)", "level", "thinkingLevel", "low", "gemini-3-flash-preview", true},
// === 5. Claude Provider ===
{25, "claude_original_numeric", "claude", "claude-sonnet-4-5-20250929(16384)", "numeric", "budget_tokens", 16384, "claude-sonnet-4-5-20250929", false},
{26, "claude_alias_numeric", "claude", "cs45(16384)", "numeric", "budget_tokens", 16384, "claude-sonnet-4-5-20250929", true},
// === 6. Codex Provider ===
{27, "codex_original_level", "codex", "gpt-5(high)", "level", "reasoning_effort", "high", "gpt-5", false},
{28, "codex_alias_level", "codex", "g5(high)", "level", "reasoning_effort", "high", "gpt-5", true},
// === 7. Qwen Provider ===
{29, "qwen_original_level", "qwen", "qwen3-coder-plus(high)", "level", "enable_thinking", true, "qwen3-coder-plus", false},
{30, "qwen_alias_level", "qwen", "qcp(high)", "level", "enable_thinking", true, "qwen3-coder-plus", true},
// === 8. iFlow Provider ===
{31, "iflow_original_level", "iflow", "glm-4.7(high)", "level", "reasoning_effort", "high", "glm-4.7", false},
{32, "iflow_alias_level", "iflow", "glm(high)", "level", "reasoning_effort", "high", "glm-4.7", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Step 1: Parse model suffix (simulates SDK layer normalization)
// For "gp(1000)" -> requestedModel="gp", metadata={thinking_budget: 1000}
requestedModel, metadata := util.NormalizeThinkingModel(tt.requestModel)
// Verify suffix was parsed
if metadata == nil && (tt.suffixType == "numeric" || tt.suffixType == "level") {
t.Errorf("Case #%d: NormalizeThinkingModel(%q) metadata is nil", tt.id, tt.requestModel)
return
}
// Step 2: Simulate OAuth model mapping
// Real flow: applyOAuthModelMapping stores requestedModel (the alias) in metadata
if tt.isAlias {
if metadata == nil {
metadata = make(map[string]any)
}
metadata[util.ModelMappingOriginalModelMetadataKey] = requestedModel
}
// Step 3: Verify metadata extraction
switch tt.suffixType {
case "numeric":
budget, _, _, matched := util.ThinkingFromMetadata(metadata)
if !matched {
t.Errorf("Case #%d: ThinkingFromMetadata did not match", tt.id)
return
}
if budget == nil {
t.Errorf("Case #%d: expected budget in metadata", tt.id)
return
}
// For thinkingBudget/budget_tokens, verify the parsed budget value
if tt.expectedField == "thinkingBudget" || tt.expectedField == "budget_tokens" {
expectedBudget := tt.expectedValue.(int)
if *budget != expectedBudget {
t.Errorf("Case #%d: budget = %d, want %d", tt.id, *budget, expectedBudget)
}
}
// For thinkingLevel (Gemini 3), verify conversion from budget to level
if tt.expectedField == "thinkingLevel" {
level, ok := util.ThinkingBudgetToGemini3Level(tt.upstreamModel, *budget)
if !ok {
t.Errorf("Case #%d: ThinkingBudgetToGemini3Level failed", tt.id)
return
}
expectedLevel := tt.expectedValue.(string)
if level != expectedLevel {
t.Errorf("Case #%d: converted level = %q, want %q", tt.id, level, expectedLevel)
}
}
case "level":
_, _, effort, matched := util.ThinkingFromMetadata(metadata)
if !matched {
t.Errorf("Case #%d: ThinkingFromMetadata did not match", tt.id)
return
}
if effort == nil {
t.Errorf("Case #%d: expected effort in metadata", tt.id)
return
}
if tt.expectedField == "thinkingLevel" || tt.expectedField == "reasoning_effort" {
expectedEffort := tt.expectedValue.(string)
if *effort != expectedEffort {
t.Errorf("Case #%d: effort = %q, want %q", tt.id, *effort, expectedEffort)
}
}
}
// Step 4: Test Gemini-specific thinkingLevel conversion for Gemini 3 models
if tt.expectedField == "thinkingLevel" && util.IsGemini3Model(tt.upstreamModel) {
body := []byte(`{"request":{"contents":[]}}`)
// Build metadata simulating real OAuth flow:
// - requestedModel (alias like "gf") is stored in model_mapping_original_model
// - upstreamModel is passed as the model parameter
testMetadata := make(map[string]any)
if tt.isAlias {
// Real flow: applyOAuthModelMapping stores requestedModel (the alias)
testMetadata[util.ModelMappingOriginalModelMetadataKey] = requestedModel
}
// Copy parsed metadata (thinking_budget, reasoning_effort, etc.)
for k, v := range metadata {
testMetadata[k] = v
}
result := util.ApplyGemini3ThinkingLevelFromMetadataCLI(tt.upstreamModel, testMetadata, body)
levelVal := gjson.GetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel")
expectedLevel := tt.expectedValue.(string)
if !levelVal.Exists() {
t.Errorf("Case #%d: expected thinkingLevel in result", tt.id)
} else if levelVal.String() != expectedLevel {
t.Errorf("Case #%d: thinkingLevel = %q, want %q", tt.id, levelVal.String(), expectedLevel)
}
}
// Step 5: Test Gemini 2.5 thinkingBudget application using real ApplyThinkingMetadataCLI flow
if tt.expectedField == "thinkingBudget" && util.IsGemini25Model(tt.upstreamModel) {
body := []byte(`{"request":{"contents":[]}}`)
// Build metadata simulating real OAuth flow:
// - requestedModel (alias like "gp") is stored in model_mapping_original_model
// - upstreamModel is passed as the model parameter
testMetadata := make(map[string]any)
if tt.isAlias {
// Real flow: applyOAuthModelMapping stores requestedModel (the alias)
testMetadata[util.ModelMappingOriginalModelMetadataKey] = requestedModel
}
// Copy parsed metadata (thinking_budget, reasoning_effort, etc.)
for k, v := range metadata {
testMetadata[k] = v
}
// Use the exported ApplyThinkingMetadataCLI which includes the fallback logic
result := executor.ApplyThinkingMetadataCLI(body, testMetadata, tt.upstreamModel)
budgetVal := gjson.GetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget")
expectedBudget := tt.expectedValue.(int)
if !budgetVal.Exists() {
t.Errorf("Case #%d: expected thinkingBudget in result", tt.id)
} else if int(budgetVal.Int()) != expectedBudget {
t.Errorf("Case #%d: thinkingBudget = %d, want %d", tt.id, int(budgetVal.Int()), expectedBudget)
}
}
})
}
}