Compare commits

...

179 Commits

Author SHA1 Message Date
Luis Pater
51611c25d7 Merge branch 'router-for-me:main' into main 2026-01-21 22:12:28 +08:00
Luis Pater
eb1bbaa63b Merge pull request #119 from linlang781/main
支持Kiro sso idc
2026-01-21 22:11:58 +08:00
yuechenglong.5
4c8026ac3d chore(build): 更新 .gitignore 文件
- 添加 *.bak 文件扩展名到忽略列表
2026-01-21 21:38:47 +08:00
gogoing1024
8aeb4b7d54 Merge pull request #1 from gogoing1024/main
Merge pull request #1 from linlang781/main
2026-01-21 21:09:34 +08:00
gogoing1024
b2172cb047 Merge pull request #1 from linlang781/main
1
2026-01-21 21:07:24 +08:00
Luis Pater
ef4508dbc8 refactor(cache, translator): remove session ID from signature caching and clean up logic 2026-01-21 13:37:10 +08:00
Luis Pater
f775e46fe2 refactor(translator): remove session ID logic from signature caching and associated tests 2026-01-21 12:45:07 +08:00
Luis Pater
65ad5c0c9d refactor(cache): simplify signature caching by removing sessionID parameter 2026-01-21 12:38:05 +08:00
Luis Pater
88bf4e77ec fix(translator): update HasValidSignature to require modelName parameter for improved validation 2026-01-21 11:31:37 +08:00
yuechenglong.5
194f66ca9c feat(kiro): 添加后台令牌刷新通知机制
- 在 BackgroundRefresher 中添加 onTokenRefreshed 回调函数和并发安全锁
- 实现 WithOnTokenRefreshed 选项函数用于设置刷新成功回调
- 在 RefreshManager 中添加 SetOnTokenRefreshed 方法支持运行时更新回调
- 为 KiroExecutor 添加 reloadAuthFromFile 方法实现文件重新加载回退机制
- 在 Watcher 中实现 NotifyTokenRefreshed 方法处理刷新通知并更新内存Auth对象
- 通过 Service.GetWatcher 连接刷新器回调到 Watcher 通知链路
- 添加方案A和方案B双重保障解决后台刷新与内存对象时间差问题
2026-01-21 11:03:07 +08:00
Luis Pater
a4f8015caa test(logging): add unit tests for GinLogrusRecovery middleware panic handling 2026-01-21 10:57:27 +08:00
Luis Pater
ffd129909e Merge pull request #1130 from router-for-me/agty
fix(executor): only strip maxOutputTokens for non-claude models
2026-01-21 10:50:39 +08:00
hkfires
9332316383 fix(translator): preserve thinking blocks by skipping signature 2026-01-21 10:49:20 +08:00
hkfires
6dcbbf64c3 fix(executor): only strip maxOutputTokens for non-claude models 2026-01-21 10:49:20 +08:00
yuechenglong.5
c9aa1ff99d Merge remote-tracking branch 'origin/main'
# Conflicts:
#	internal/auth/kiro/oauth_web.go
2026-01-21 10:31:55 +08:00
Luis Pater
2ce3553612 feat(cache): handle gemini family in signature cache with fallback validator logic 2026-01-21 10:11:21 +08:00
Luis Pater
2e14f787d4 feat(translator): enhance ConvertGeminiRequestToAntigravity with model name and refine reasoning block handling 2026-01-21 08:31:23 +08:00
Luis Pater
523b41ccd2 test(responses): add comprehensive tests for SSE event ordering and response transformations 2026-01-21 07:08:59 +08:00
Luis Pater
c6fa1d0e67 Merge pull request #1117 from router-for-me/cache
fix(translator): enhance signature cache clearing logic and update test cases with model name
2026-01-20 23:18:48 +08:00
Luis Pater
ac56e1e88b Merge pull request #1116 from bexcodex/fix/antigravity
Fix antigravity malformed_function_call
2026-01-20 22:40:00 +08:00
781456868@qq.com
a9ee971e1c fix(kiro): improve auto-refresh and IDC auth file handling
Amp-Thread-ID: https://ampcode.com/threads/T-019bdb94-80e3-7302-be0f-a69937826d13
Co-authored-by: Amp <amp@ampcode.com>
2026-01-20 21:57:45 +08:00
781456868@qq.com
73cef3a25a Merge remote-tracking branch 'upstream/main' 2026-01-20 21:57:16 +08:00
hkfires
9b72ea9efa fix(translator): enhance signature cache clearing logic and update test cases with model name 2026-01-20 20:02:29 +08:00
bexcodex
9f364441e8 Fix antigravity malformed_function_call 2026-01-20 19:54:54 +08:00
Luis Pater
e49a1c07bf chore(translator): update cache functions to include model name parameter in tests 2026-01-20 18:36:51 +08:00
Luis Pater
5364a2471d fix(endpoint_compat): update GetModelInfo to include missing parameter for improved registry compatibility 2026-01-20 13:56:57 +08:00
Luis Pater
fef4fdb0eb Merge pull request #117 from router-for-me/plus
v6.7.15
2026-01-20 13:50:53 +08:00
Luis Pater
c2bf600a39 Merge branch 'main' into plus 2026-01-20 13:50:41 +08:00
Luis Pater
8d9f4edf9b feat(translator): unify model group references by introducing GetModelGroup helper function 2026-01-20 13:45:25 +08:00
Luis Pater
020e61d0da feat(translator): improve signature handling by associating with model name in cache functions 2026-01-20 13:31:36 +08:00
Luis Pater
6184c43319 Fixed: #1109
feat(translator): enhance session ID derivation with user_id parsing in Claude
2026-01-20 12:35:40 +08:00
Luis Pater
2cbe4a790c chore(translator): remove unnecessary whitespace in gemini_openai_response code 2026-01-20 11:47:33 +08:00
Luis Pater
68b3565d7b Merge branch 'main' into dev (PR #961) 2026-01-20 11:42:22 +08:00
Luis Pater
3f385a8572 feat(auth): add "antigravity" provider to ignored access_token fields in filestore 2026-01-20 11:38:31 +08:00
Luis Pater
9823dc35e1 feat(auth): hash account ID for improved uniqueness in credential filenames 2026-01-20 11:37:52 +08:00
Luis Pater
059bfee91b feat(auth): add hashed account ID to credential filenames for team plans 2026-01-20 11:36:29 +08:00
Luis Pater
7beaf0eaa2 Merge pull request #869 2026-01-20 11:16:53 +08:00
Luis Pater
1fef90ff58 Merge pull request #877 from zhiqing0205/main
feat(codex): include plan type in auth filename
2026-01-20 11:11:25 +08:00
Luis Pater
8447fd27a0 fix(login): remove emojis from interactive prompt messages 2026-01-20 11:09:56 +08:00
Luis Pater
7831cba9f6 refactor(claude): remove redundant system instructions check in Claude executor 2026-01-20 11:02:52 +08:00
Luis Pater
e02b2d58d5 Merge pull request #868 2026-01-20 10:57:24 +08:00
Luis Pater
28726632a9 Merge pull request #861 from umairimtiaz9/fix/gemini-cli-backend-project-id
fix(auth): use backend project ID for free tier Gemini CLI OAuth users
2026-01-20 10:32:17 +08:00
yuechenglong.5
0f63d973be Merge remote-tracking branch 'origin/main' 2026-01-20 10:20:03 +08:00
Luis Pater
3b26129c82 Merge pull request #1108 from router-for-me/modelinfo
feat(registry): support provider-specific model info lookup
2026-01-20 10:18:42 +08:00
Luis Pater
d4bb4e6624 refactor(antigravity): remove unused client signature handling in thinking objects 2026-01-20 10:17:55 +08:00
yuechenglong.5
fa2abd560a chore: cherry-pick 文档更新和删除测试文件
- docs: 添加 Kiro OAuth web 认证端点说明 (ace7c0c)
- chore: 删除包含敏感数据的测试文件 (8f06f6a)
- 保留本地修改: refresh_manager, token_repository 等
2026-01-20 10:17:39 +08:00
Luis Pater
0766c49f93 Merge pull request #994 from adrenjc/fix/cross-model-thinking-signature
fix(antigravity): prevent corrupted thought signature when switching models
2026-01-20 10:14:05 +08:00
Luis Pater
a7ffc77e3d Merge branch 'dev' into fix/cross-model-thinking-signature 2026-01-20 10:10:43 +08:00
hkfires
e641fde25c feat(registry): support provider-specific model info lookup 2026-01-20 10:01:17 +08:00
yuechenglong.5
564c2d763e Merge upstream/main (08779cc) - sync with original repo updates 2026-01-20 09:52:11 +08:00
Luis Pater
5717c7f2f4 Merge pull request #1103 from dinhkarate/feat/imagen
feat(vertex): add Imagen image generation model support
2026-01-20 07:11:18 +08:00
dinhkarate
8734d4cb90 feat(vertex): add Imagen image generation model support
Add support for Imagen 3.0 and 4.0 image generation models in Vertex AI:

- Add 5 Imagen model definitions (4.0, 4.0-ultra, 4.0-fast, 3.0, 3.0-fast)
- Implement :predict action routing for Imagen models
- Convert Imagen request/response format to match Gemini structure like gemini-3-pro-image
- Transform prompts to Imagen's instances/parameters format
- Convert base64 image responses to Gemini-compatible inline data
2026-01-20 01:26:37 +07:00
Luis Pater
08779cc8a8 Merge branch 'router-for-me:main' into main 2026-01-19 21:00:58 +08:00
Luis Pater
5baa753539 Merge pull request #1099 from router-for-me/claude
refactor(claude): move max_tokens constraint enforcement to Apply method
2026-01-19 20:55:59 +08:00
781456868@qq.com
92fb6b012a feat(kiro): add manual token refresh button to OAuth web UI
Amp-Thread-ID: https://ampcode.com/threads/T-019bd642-9806-75d8-9101-27812e0eb6ab
Co-authored-by: Amp <amp@ampcode.com>
2026-01-19 20:55:51 +08:00
Luis Pater
ead98e4bca Merge pull request #1101 from router-for-me/argy
fix(executor): stop rewriting thinkingLevel for gemini
2026-01-19 20:55:22 +08:00
781456868@qq.com
8f06f6a9ed chore: remove test files containing sensitive data
Amp-Thread-ID: https://ampcode.com/threads/T-019bd618-7e42-715a-960d-dd45425851e3
Co-authored-by: Amp <amp@ampcode.com>
2026-01-19 20:31:33 +08:00
781456868@qq.com
ace7c0ccb4 docs: add Kiro OAuth web authentication endpoint /v0/oauth/kiro 2026-01-19 20:28:40 +08:00
781456868@qq.com
f87fe0a0e8 feat: proactive token refresh 10 minutes before expiry
Amp-Thread-ID: https://ampcode.com/threads/T-019bd618-7e42-715a-960d-dd45425851e3
Co-authored-by: Amp <amp@ampcode.com>
2026-01-19 20:09:38 +08:00
781456868@qq.com
87edc6f35e Merge remote-tracking branch 'upstream/main' 2026-01-19 20:09:17 +08:00
hkfires
1d2fe55310 fix(executor): stop rewriting thinkingLevel for gemini 2026-01-19 19:49:39 +08:00
hkfires
c175821cc4 feat(registry): expand antigravity model config
Remove static Name mapping and add entries for claude-sonnet-4-5,
tab_flash_lite_preview, and gpt-oss-120b-medium configs
2026-01-19 19:32:00 +08:00
hkfires
239a28793c feat(claude): clamp thinking budget to max_tokens constraints 2026-01-19 16:32:20 +08:00
hkfires
c421d653e7 refactor(claude): move max_tokens constraint enforcement to Apply method 2026-01-19 15:50:35 +08:00
Luis Pater
2542c2920d Merge pull request #1096 from router-for-me/usage
feat(translator): report cached token usage in Claude output
2026-01-19 11:52:18 +08:00
hkfires
52e46ced1b fix(translator): avoid forcing RFC 8259 system prompt 2026-01-19 11:33:27 +08:00
hkfires
cf9daf470c feat(translator): report cached token usage in Claude output 2026-01-19 11:23:44 +08:00
Luis Pater
ac7738bdeb Merge pull request #114 from router-for-me/plus
v6.7.9
2026-01-19 04:03:26 +08:00
Luis Pater
2d9f6c104c Merge branch 'main' into plus 2026-01-19 04:03:17 +08:00
Luis Pater
5d0460ece2 Merge pull request #112 from clstb/main
Add Github Copilot support for management interface
2026-01-19 04:02:09 +08:00
Luis Pater
140d6211cc feat(translator): add reasoning state tracking and improve reasoning summary handling
- Introduced `oaiToResponsesStateReasoning` to track reasoning data.
- Enhanced logic for emitting reasoning summary events and managing state transitions.
- Updated output generation to handle multiple reasoning entries consistently.
2026-01-19 03:58:28 +08:00
Luis Pater
60f9a1442c Merge pull request #1088 from router-for-me/thinking
Thinking
2026-01-18 17:01:59 +08:00
hkfires
cb6caf3f87 fix(thinking): update ValidateConfig to include fromSuffix parameter and adjust budget validation logic 2026-01-18 16:37:14 +08:00
781456868@qq.com
c9301a6d18 docs: update README with new features and Docker deployment guide 2026-01-18 15:07:29 +08:00
781456868@qq.com
0e77e93e5d feat: add Kiro OAuth web, rate limiter, metrics, fingerprint, background refresh and model converter 2026-01-18 15:04:29 +08:00
Luis Pater
99c7abbbf1 Merge pull request #1067 from router-for-me/auth-files
refactor(auth): simplify filename prefixes for qwen and iflow tokens
2026-01-18 13:41:59 +08:00
Luis Pater
8f511ac33c Merge pull request #1076 from sususu98/fix/antigravity-enum-string
fix(antigravity): convert non-string enum values to strings for Gemini API
2026-01-18 13:40:53 +08:00
Luis Pater
1046152119 Merge pull request #1068 from 0xtbug/dev
docs(readme): add ZeroLimit to projects based on CLIProxyAPI
2026-01-18 13:37:50 +08:00
Luis Pater
f88228f1c5 Merge pull request #1081 from router-for-me/thinking
Refine thinking validation and cross‑provider payload conversion
2026-01-18 13:34:28 +08:00
Luis Pater
62e2b672d9 refactor(logging): centralize log directory resolution logic
- Introduced `ResolveLogDirectory` function in `logging` package to standardize log directory determination across components.
- Replaced redundant logic in `server`, `global_logger`, and `handlers` with the new utility function.
2026-01-18 12:40:57 +08:00
hkfires
03005b5d29 refactor(thinking): add Gemini family provider grouping for strict validation 2026-01-18 11:30:53 +08:00
hkfires
c7e8830a56 refactor(thinking): pass source and target formats to ApplyThinking for cross-format validation
Update ApplyThinking signature to accept fromFormat and toFormat parameters
instead of a single provider string. This enables:

- Proper level-to-budget conversion when source is level-based (openai/codex)
  and target is budget-based (gemini/claude)
- Strict budget range validation when source and target formats match
- Level clamping to nearest supported level for cross-format requests
- Format alias resolution in SDK translator registry for codex/openai-response

Also adds ErrBudgetOutOfRange error code and improves iflow config extraction
to fall back to openai format when iflow-specific config is not present.
2026-01-18 10:30:15 +08:00
hkfires
d5ef4a6d15 refactor(translator): remove registry model lookups from thinking config conversions 2026-01-18 10:30:14 +08:00
hkfires
97b67e0e49 test(thinking): split E2E coverage into suffix and body parameter test functions
Refactor thinking configuration tests by separating model name suffix-based
scenarios from request body parameter-based scenarios into distinct test
functions with independent case numbering.

Architectural improvements:
- Extract thinkingTestCase struct to package level for shared usage
- Add getTestModels() helper returning complete model fixture set
- Introduce runThinkingTests() runner with protocol-specific field detection
- Register level-subset-model fixture with constrained low/high level support
- Extend iflow protocol handling for glm-test and minimax-test models
- Add same-protocol strict boundary validation cases (80-89)
- Replace error responses with clamped values for boundary-exceeding budgets
2026-01-18 10:30:14 +08:00
sususu98
dd6d78cb31 fix(antigravity): convert non-string enum values to strings for Gemini API
Gemini API requires all enum values in function declarations to be
strings. Some MCP tools (e.g., roxybrowser) define schemas with numeric
enums like `"enum": [0, 1, 2]`, causing INVALID_ARGUMENT errors.

Add convertEnumValuesToStrings() to automatically convert numeric and
boolean enum values to their string representations during schema
transformation.
2026-01-18 02:00:02 +00:00
Luis Pater
46433a25f8 fix(translator): add check for empty text to prevent invalid serialization in gemini and antigravity 2026-01-18 00:50:10 +08:00
clstb
b4e070697d feat: support github copilot in management ui 2026-01-17 17:22:45 +01:00
Tubagus
c8843edb81 Update README_CN.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-17 11:33:29 +07:00
Tubagus
f89feb881c Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-17 11:33:18 +07:00
Tubagus
dbba71028e docs(readme): add ZeroLimit to projects based on CLIProxyAPI 2026-01-17 11:30:15 +07:00
Tubagus
8549a92e9a docs(readme): add ZeroLimit to projects based on CLIProxyAPI
Added ZeroLimit app to the list of projects in README.
2026-01-17 11:29:22 +07:00
hkfires
109cffc010 refactor(auth): simplify filename prefixes for qwen and iflow tokens 2026-01-17 12:20:58 +08:00
Luis Pater
f8f3ad84fc Fixed: #1064
feat(translator): improve system message handling and content indexing across translators

- Updated logic for processing system messages in `claude`, `gemini`, `gemini-cli`, and `antigravity` translators.
- Introduced indexing for `systemInstruction.parts` to ensure proper ordering and handling of multi-part content.
- Added safeguards for accurate content transformation and serialization.
2026-01-17 05:40:56 +08:00
Luis Pater
93d7883513 Merge pull request #110 from PancakeZik/fix/system-prompt-reinjection
fix: prevent system prompt re-injection on subsequent turns
2026-01-17 05:19:11 +08:00
Luis Pater
015a3e8a83 Merge branch 'router-for-me:main' into main 2026-01-17 05:17:38 +08:00
Luis Pater
bc7167e9fe feat(runtime): add model alias support and enhance payload rule matching
- Introduced `payloadModelAliases` and `payloadModelCandidates` functions to support model aliases for improved flexibility.
- Updated rule matching logic to handle multiple model candidates.
- Refactored variable naming in executor to improve code clarity and consistency.
2026-01-17 05:05:24 +08:00
Luis Pater
384578a88c feat(cliproxy, gemini): improve ID matching logic and enrich normalized model output
- Enhanced ID matching in `cliproxy` by adding additional conditions to better handle ID equality cases.
- Updated `gemini` handlers to include `displayName` and `description` in normalized models for enriched metadata.
2026-01-17 04:44:09 +08:00
Joao
6b074653f2 fix: prevent system prompt re-injection on subsequent turns
When tool results are sent back to the model, the system prompt was being
re-injected into the user message content, causing the model to think the
user had pasted the system prompt again. This was especially noticeable
after multiple tool uses.

The fix checks if there is conversation history (len(history) > 0). If so,
it's a subsequent turn and we skip system prompt injection. The system
prompt is only injected on the first turn (len(history) == 0).

This ensures:
- First turn: system prompt is injected
- Tool result turns: system prompt is NOT re-injected
- New conversations: system prompt is injected fresh
2026-01-16 20:16:44 +00:00
Luis Pater
65b4e1ec6c feat(codex): enable instruction toggling and update role terminology
- Added conditional logic for Codex instruction injection based on configuration.
- Updated role terminology from "user" to "developer" for better alignment with context.
2026-01-17 04:12:29 +08:00
Luis Pater
06afa29f2d Merge branch 'router-for-me:main' into main 2026-01-16 20:01:35 +08:00
Luis Pater
6600d58ba2 feat(codex): enhance input transformation and remove unused safety_identifier field
- Added logic to transform `inputResults` into structured JSON for improved processing.
- Removed redundant `safety_identifier` field in executor payload to streamline requests.
2026-01-16 19:59:01 +08:00
Luis Pater
25e9be3ced Merge pull request #103 from ChrAlpha/feat/add-gpt-5.2-codex-copilot
feat(openai): responses API support for GitHub Copilot provider
2026-01-16 18:33:53 +08:00
Luis Pater
ccb2aaf2fe Merge branch 'router-for-me:main' into main 2026-01-16 18:29:56 +08:00
Luis Pater
961c6f67da Merge pull request #100 from novadev94/fix/readd_kiro_auto
fix(kiro): re-add kiro-auto to registry
2026-01-16 18:29:43 +08:00
Luis Pater
dc4305f75a Merge pull request #107 from zccing/main
fix(kiro): correct Amazon Q endpoint URL path
2026-01-16 18:28:45 +08:00
Chén Mù
4dc7af5a5d Merge pull request #1054 from router-for-me/codex
fix(codex): ensure instructions field exists
2026-01-16 15:40:12 +08:00
hkfires
902bea24b4 fix(codex): ensure instructions field exists 2026-01-16 15:38:10 +08:00
Cc
778cf4af9e feat(kiro): add agent-mode and optout headers for non-IDC auth
- Add x-amzn-kiro-agent-mode: vibe for non-IDC auth (Social, Builder ID)
  IDC auth continues to use "spec" mode
- Add x-amzn-codewhisperer-optout: true for all auth types
  This opts out of data sharing for service improvement (privacy)

These changes align with other Kiro implementations (kiro.rs, KiroGate,
kiro-gateway, AIClient-2-API) and make requests more similar to real
Kiro IDE clients.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 14:21:38 +08:00
hkfires
c3ef46f409 feat(config): supplement missing default aliases during antigravity migration 2026-01-16 13:37:46 +08:00
Cc
4721c58d9c fix(kiro): correct Amazon Q endpoint URL path
The Q endpoint was using `/` which caused all requests to fail with
400 or UnknownOperationException. Changed to `/generateAssistantResponse`
which is the correct path for the Q endpoint.

This fix restores the Q endpoint failover functionality.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 13:22:43 +08:00
Luis Pater
aa0b63e214 refactor(config): clarify Codex instruction toggle documentation 2026-01-16 12:50:09 +08:00
Luis Pater
3c4e7997c3 Merge branch 'router-for-me:main' into main 2026-01-16 12:47:23 +08:00
Luis Pater
1afc3a5f65 feat(auth): add support for kiro OAuth model alias
- Introduced `kiro` channel and alias resolution in `oauth_model_alias` logic.
- Updated supported channels documentation and examples to include `kiro` and `github-copilot`.
- Enhanced unit tests to validate `kiro` alias functionality.
2026-01-16 12:47:05 +08:00
Luis Pater
ea3d22831e refactor(codex): update terminology to "official instructions" for clarity 2026-01-16 12:44:57 +08:00
Luis Pater
3b4d6d359b Merge pull request #1049 from router-for-me/codex
feat(codex): add config toggle for codex instructions injection
2026-01-16 12:38:35 +08:00
hkfires
48cba39a12 feat(codex): add config toggle for codex instructions injection 2026-01-16 12:30:12 +08:00
Luis Pater
bca244df67 Merge branch 'router-for-me:main' into main 2026-01-16 11:37:33 +08:00
Luis Pater
cec4e251bd feat(translator): preserve text field in serialized output during chat completions processing 2026-01-16 11:35:34 +08:00
Luis Pater
526dd866ba refactor(gemini): replace static model handling with dynamic model registry lookup 2026-01-16 10:39:16 +08:00
Luis Pater
c29839d2ed Merge remote-tracking branch 'origin/main' into pr-104
# Conflicts:
#	config.example.yaml
#	internal/config/config.go
#	sdk/cliproxy/auth/model_name_mappings.go
2026-01-16 09:40:07 +08:00
Luis Pater
b31ddc7bf1 Merge branch 'dev' 2026-01-16 08:21:59 +08:00
Chén Mù
22e1ad3d8a Merge pull request #1018 from pikeman20/main
feat(docker): use environment variables for volume paths
2026-01-16 08:19:23 +08:00
Luis Pater
f571b1deb0 feat(config): add support for raw JSON payload rules
- Introduced `default-raw` and `override-raw` rules to handle raw JSON values.
- Enhanced `PayloadConfig` to validate and sanitize raw JSON payload rules.
- Updated executor logic to apply `default-raw` and `override-raw` rules.
- Extended example YAML to include usage of raw JSON rules.
2026-01-16 08:15:28 +08:00
Luis Pater
67f8732683 Merge pull request #1033 from router-for-me/reasoning
Refactor thinking
2026-01-15 20:33:13 +08:00
hkfires
2b387e169b feat(iflow): add iflow-rome model definition 2026-01-15 20:23:55 +08:00
hkfires
199cf480b0 refactor(thinking): remove support for non-standard thinking configurations
This change removes the translation logic for several non-standard, proprietary extensions used to configure thinking/reasoning. Specifically, support for `extra_body.google.thinking_config` and the Anthropic-style `thinking` object has been dropped from the OpenAI request translators.

This simplification streamlines the translators, focusing them on the standard `reasoning_effort` parameter. It also removes the need to look up model information from the registry within these components.

BREAKING CHANGE: Support for non-standard thinking configurations via `extra_body.google.thinking_config` and the Anthropic-style `thinking` object has been removed. Clients should now use the standard `reasoning_effort` parameter to control reasoning.
2026-01-15 19:32:12 +08:00
ChrAlpha
18daa023cb fix(openai): improve error handling for response conversion failures 2026-01-15 19:13:54 +08:00
hkfires
4ad6189487 refactor(thinking): extract antigravity logic into a dedicated provider 2026-01-15 19:08:22 +08:00
ChrAlpha
8950d92682 feat(openai): implement endpoint resolution and response handling for Chat and Responses models 2026-01-15 18:30:01 +08:00
hkfires
fe5b3c80cb refactor(config): rename oauth-model-mappings to oauth-model-alias 2026-01-15 18:03:26 +08:00
ChrAlpha
0ffcce3ec8 feat(registry): add supported endpoints for GitHub Copilot models
Enhance model definitions by including supported API endpoints for each model. This allows for better integration and usage tracking with the GitHub Copilot API.
2026-01-15 16:32:28 +08:00
hkfires
e0ffec885c fix(aistudio): remove levels from model definitions 2026-01-15 16:06:46 +08:00
hkfires
ff4ff6bc2f feat(thinking): support zero as a valid thinking budget for capable models 2026-01-15 15:41:10 +08:00
ChrAlpha
f4fcfc5867 feat(registry): add GPT-5.2-Codex model to GitHub Copilot provider
Add gpt-5.2-codex model definition to GetGitHubCopilotModels() function,
  enabling access to OpenAI GPT-5.2 Codex through the GitHub Copilot API.
2026-01-15 14:14:09 +08:00
Luis Pater
7248f65c36 feat(auth): prevent filestore writes on unchanged metadata
- Added `metadataEqualIgnoringTimestamps` to compare metadata while ignoring volatile fields.
- Prevented redundant writes caused by changes in timestamp-related fields.
- Improved efficiency in filestore operations by skipping unnecessary updates.
2026-01-15 14:05:23 +08:00
hkfires
5c40a2db21 refactor(thinking): simplify ModeNone and budget validation logic 2026-01-15 14:03:08 +08:00
Luis Pater
d6111344c5 Merge branch 'router-for-me:main' into main 2026-01-15 13:30:28 +08:00
Luis Pater
086eb3df7a refactor(auth): simplify file handling logic and remove redundant comparison functions
feat(auth): fetch and update Antigravity project ID from metadata during filestore operations

- Added support to retrieve and update `project_id` using the access token if missing in metadata.
- Integrated HTTP client to fetch project ID dynamically.
- Enhanced metadata persistence logic.
2026-01-15 13:29:14 +08:00
hkfires
ee2976cca0 refactor(thinking): improve logging for user-defined models 2026-01-15 13:06:41 +08:00
hkfires
8bc6df329f fix(auth): apply API key model mapping to request model 2026-01-15 13:06:41 +08:00
hkfires
bcd4d9595f fix(thinking): refine ModeNone handling based on provider capabilities 2026-01-15 13:06:41 +08:00
hkfires
5a77b7728e refactor(thinking): improve budget clamping and logging with provider/model context 2026-01-15 13:06:41 +08:00
hkfires
1fbbba6f59 feat(logging): order log fields for improved readability 2026-01-15 13:06:41 +08:00
hkfires
847be0e99d fix(auth): use base model name for auth matching by stripping suffix 2026-01-15 13:06:41 +08:00
hkfires
f6a2d072e6 refactor(thinking): refine configuration logging 2026-01-15 13:06:41 +08:00
hkfires
ed8b0f25ee fix(thinking): use LookupModelInfo for model data 2026-01-15 13:06:41 +08:00
hkfires
6e4a602c60 fix(thinking): map reasoning_effort to thinkingConfig 2026-01-15 13:06:40 +08:00
hkfires
2262479365 refactor(thinking): remove legacy utilities and simplify model mapping 2026-01-15 13:06:40 +08:00
hkfires
33d66959e9 test(thinking): remove legacy unit and integration tests 2026-01-15 13:06:40 +08:00
hkfires
7f1b2b3f6e fix(thinking): improve model lookup and validation 2026-01-15 13:06:40 +08:00
hkfires
40ee065eff fix(thinking): use static lookup to avoid alias issues 2026-01-15 13:06:40 +08:00
hkfires
a75fb6af90 refactor(antigravity): remove hardcoded model aliases 2026-01-15 13:06:39 +08:00
hkfires
72f2125668 fix(executor): properly handle thinking application errors 2026-01-15 13:06:39 +08:00
hkfires
e8f5888d8e fix(thinking): fix auth matching for thinking suffix and json field conflicts 2026-01-15 13:06:39 +08:00
hkfires
0b06d637e7 refactor: improve thinking logic 2026-01-15 13:06:39 +08:00
Luis Pater
496f6770a5 Merge branch 'router-for-me:main' into main 2026-01-15 12:09:22 +08:00
Luis Pater
5a7e5bd870 feat(auth): add Antigravity onboarding with tier selection
- Updated `ideType` to `ANTIGRAVITY` in request payload.
- Introduced tier-selection logic to determine default tier for onboarding.
- Added `antigravityOnboardUser` function for project ID retrieval via polling.
- Enhanced error handling and response decoding for onboarding flow.
2026-01-15 11:43:02 +08:00
Luis Pater
6f8a8f8136 feat(selector): add priority support for auth selection 2026-01-15 07:08:24 +08:00
pikeman20
5df195ea82 feat(docker): use environment variables for volume paths
This change introduces environment variable interpolation for volume paths, allowing users to customize where configuration, authentication, and log data are stored.

Why: Makes the project easier to deploy on various hosting environments that require decoupled data management without needing to modify the core docker-compose.yml..

Key points:

Defaults to existing paths (./config.yaml, ./auths, ./logs) to ensure zero breaking changes for current users.

Follows the existing naming convention used in the project.

Enhances portability for CI/CD and cloud-native deployments.
2026-01-15 05:42:51 +07:00
Nova
f82f70df5c fix(kiro): re-add kiro-auto to registry
Reference: https://github.com/router-for-me/CLIProxyAPIPlus/pull/16
Revert: a594338bc5
2026-01-15 03:26:22 +07:00
Luis Pater
5a2bf191fc Merge pull request #98 from router-for-me/plus
v6.6.105
2026-01-15 03:31:04 +08:00
Luis Pater
a235fb1507 Merge branch 'main' into plus 2026-01-15 03:30:56 +08:00
Luis Pater
0d66522ed8 Merge pull request #95 from ZqinKing/main
feat(kiro): 实现动态工具压缩功能
2026-01-15 03:29:49 +08:00
Luis Pater
b163f8ed9e Fixed: #1004
feat(translator): add function name to response output item serialization

- Included `item.name` in the serialized response output to enhance output item handling.
2026-01-15 03:27:00 +08:00
ZqinKing
83e5f60b8b fix(kiro): scale description compression by needed size
Compute a size-reduction based keep ratio and use it to trim
tool descriptions, avoiding forced minimum truncation when the
target size already fits. This aligns compression with actual
payload reduction needs and prevents over-compression.
2026-01-14 16:22:46 +08:00
ZqinKing
5b433f962f feat(kiro): 实现动态工具压缩功能
## 背景
当 Claude Code 发送过多工具信息时,可能超出 Kiro API 请求限制导致 500 错误。
现有的工具描述截断(KiroMaxToolDescLen = 10237)只能限制单个工具的描述长度,
无法解决整体工具列表过大的问题。

## 解决方案
实现动态工具压缩功能,采用两步压缩策略:
1. 先检查原始大小,超过 20KB 才进行压缩
2. 第一步:简化 input_schema,只保留 type/enum/required 字段
3. 第二步:按比例缩短 description(最短 50 字符)
4. 保留全部工具和 skills 可调用,不丢弃任何工具

## 新增文件
- internal/translator/kiro/claude/tool_compression.go
  - calculateToolsSize(): 计算工具列表的 JSON 序列化大小
  - simplifyInputSchema(): 简化 input_schema,递归处理嵌套 properties
  - compressToolDescription(): 按比例压缩描述,支持 UTF-8 安全截断
  - compressToolsIfNeeded(): 主压缩函数,实现两步压缩策略

- internal/translator/kiro/claude/tool_compression_test.go
  - 完整的单元测试覆盖所有新增函数
  - 测试 UTF-8 安全性
  - 测试压缩效果

## 修改文件
- internal/translator/kiro/common/constants.go
  - 新增 ToolCompressionTargetSize = 20KB (压缩目标大小阈值)
  - 新增 MinToolDescriptionLength = 50 (描述最短长度)

- internal/translator/kiro/claude/kiro_claude_request.go
  - 在 convertClaudeToolsToKiro() 函数末尾调用 compressToolsIfNeeded()

## 测试结果
- 70KB 工具压缩至 17KB (74.7% 压缩率)
- 所有单元测试通过

## 预期效果
- 80KB+ tools 压缩至 ~15KB
- 不影响工具调用功能
2026-01-14 11:07:07 +08:00
Luis Pater
a1da6ff5ac Fixed: #499 #985
feat(oauth): add support for customizable OAuth callback ports

- Introduced `oauth-callback-port` flag to override default callback ports.
- Updated SDK and login flows for `iflow`, `gemini`, `antigravity`, `codex`, `claude`, and `openai` to respect configurable callback ports.
- Refactored internal OAuth servers to dynamically assign ports based on the provided options.
- Revised tests and documentation to reflect the new flag and behavior.
2026-01-14 04:29:15 +08:00
adrenjc
5977af96a0 fix(antigravity): prevent corrupted thought signature when switching models
When switching from Claude models (e.g., Opus 4.5) to Gemini models
(e.g., Flash) mid-conversation via Antigravity OAuth, the client-provided
thinking signatures from Claude would cause "Corrupted thought signature"
errors since they are incompatible with Gemini API.

Changes:
- Remove fallback to client-provided signatures in thinking block handling
- Only use cached signatures (from same-session Gemini responses)
- Skip thinking blocks without valid cached signatures
- tool_use blocks continue to use skip_thought_signature_validator when
  no valid signature is available

This ensures cross-model switching works correctly while preserving
signature validation for same-model conversations.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-13 18:24:05 +08:00
extremk
5bb9c2a2bd Add candidate count parameter to OpenAI request 2026-01-10 18:50:13 +08:00
extremk
0b5bbe9234 Add candidate count handling in OpenAI request 2026-01-10 18:49:29 +08:00
extremk
14c74e5e84 Handle 'n' parameter for candidate count in requests
Added handling for the 'n' parameter to set candidate count in generationConfig.
2026-01-10 18:48:33 +08:00
extremk
6448d0ee7c Add candidate count handling in OpenAI request 2026-01-10 18:47:41 +08:00
extremk
b0c17af2cf Enhance Gemini to OpenAI response conversion
Refactor response handling to support multiple candidates and improve parameter management.
2026-01-10 18:46:25 +08:00
zhiqing0205
aa8526edc0 fix(codex): use unicode title casing for plan 2026-01-06 10:24:02 +08:00
zhiqing0205
ac3ca0ad8e feat(codex): include plan type in auth filename 2026-01-06 02:25:56 +08:00
FakerL
08d21b76e2 Update sdk/auth/filestore.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-05 21:38:26 +08:00
Zhi Yang
33aa665555 fix(auth): persist access_token on refresh for providers that need it
Previously, metadataEqualIgnoringTimestamps() ignored access_token for all
providers, which prevented refreshed tokens from being persisted to disk/database.
This caused tokens to be lost on server restart for providers like iFlow.

This change makes the behavior provider-specific:
- Providers like gemini/gemini-cli that issue new tokens on every refresh and
  can re-fetch when needed will continue to ignore access_token (optimization)
- Other providers like iFlow will now persist access_token changes

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-05 13:25:46 +00:00
maoring24
00280b6fe8 feat(claude): add native request cloaking for non-claude-code clients
integrate claude-cloak functionality to disguise api requests:
- add CloakConfig with mode (auto/always/never) and strict-mode options
- generate fake user_id in claude code format (user_[hex]_account__session_[uuid])
- inject claude code system prompt (configurable strict mode)
- obfuscate sensitive words with zero-width characters
- auto-detect claude code clients via user-agent

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 20:32:51 +08:00
CodeIgnitor
52760a4eaa fix(auth): use backend project ID for free tier Gemini CLI OAuth users
Fixes issue where free tier users cannot access Gemini 3 preview models
due to frontend/backend project ID mapping.

## Problem
Google's Gemini API uses a frontend/backend project mapping system for
free tier users:
- Frontend projects (e.g., gen-lang-client-*) are user-visible
- Backend projects (e.g., mystical-victor-*) host actual API access
- Only backend projects have access to preview models (gemini-3-*)

Previously, CLIProxyAPI ignored the backend project ID returned by
Google's onboarding API and kept using the frontend ID, preventing
access to preview models.

## Solution
### CLI (internal/cmd/login.go)
- Detect free tier users (gen-lang-client-* projects or FREE/LEGACY tier)
- Show interactive prompt allowing users to choose frontend or backend
- Default to backend (recommended for preview model access)
- Pro users: maintain original behavior (keep frontend ID)

### Web UI (internal/api/handlers/management/auth_files.go)
- Detect free tier users using same logic
- Automatically use backend project ID (recommended choice)
- Pro users: maintain original behavior (keep frontend ID)

### Deduplication (internal/cmd/login.go)
- Add deduplication when user selects ALL projects
- Prevents redundant API calls when multiple frontend projects map to
  same backend
- Skips duplicate project IDs in activation loop

## Impact
- Free tier users: Can now access gemini-3-pro-preview and
  gemini-3-flash-preview models
- Pro users: No change in behavior (backward compatible)
- Only affects Gemini CLI OAuth (not antigravity or API key auth)

## Testing
- Tested with free tier account selecting single project
- Tested with free tier account selecting ALL projects
- Verified deduplication prevents redundant onboarding calls
- Confirmed pro user behavior unchanged
2026-01-05 02:41:24 +05:00
167 changed files with 18750 additions and 5577 deletions

1
.gitignore vendored
View File

@@ -50,3 +50,4 @@ _bmad-output/*
# macOS
.DS_Store
._*
*.bak

View File

@@ -13,6 +13,82 @@ The Plus release stays in lockstep with the mainline features.
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
## New Features (Plus Enhanced)
- **OAuth Web Authentication**: Browser-based OAuth login for Kiro with beautiful web UI
- **Rate Limiter**: Built-in request rate limiting to prevent API abuse
- **Background Token Refresh**: Automatic token refresh 10 minutes before expiration
- **Metrics & Monitoring**: Request metrics collection for monitoring and debugging
- **Device Fingerprint**: Device fingerprint generation for enhanced security
- **Cooldown Management**: Smart cooldown mechanism for API rate limits
- **Usage Checker**: Real-time usage monitoring and quota management
- **Model Converter**: Unified model name conversion across providers
- **UTF-8 Stream Processing**: Improved streaming response handling
## Kiro Authentication
### Web-based OAuth Login
Access the Kiro OAuth web interface at:
```
http://your-server:8080/v0/oauth/kiro
```
This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
- AWS Builder ID login
- AWS Identity Center (IDC) login
- Token import from Kiro IDE
## Quick Deployment with Docker
### One-Command Deployment
```bash
# Create deployment directory
mkdir -p ~/cli-proxy && cd ~/cli-proxy
# Create docker-compose.yml
cat > docker-compose.yml << 'EOF'
services:
cli-proxy-api:
image: 17600006524/cli-proxy-api-plus:latest
container_name: cli-proxy-api-plus
ports:
- "8317:8317"
volumes:
- ./config.yaml:/CLIProxyAPI/config.yaml
- ./auths:/root/.cli-proxy-api
- ./logs:/CLIProxyAPI/logs
restart: unless-stopped
EOF
# Download example config
curl -o config.yaml https://raw.githubusercontent.com/linlang781/CLIProxyAPIPlus/main/config.example.yaml
# Pull and start
docker compose pull && docker compose up -d
```
### Configuration
Edit `config.yaml` before starting:
```yaml
# Basic configuration example
server:
port: 8317
# Add your provider configurations here
```
### Update to Latest Version
```bash
cd ~/cli-proxy
docker compose pull && docker compose up -d
```
## Contributing
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.

View File

@@ -13,6 +13,82 @@
- 新增 GitHub Copilot 支持OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
## 新增功能 (Plus 增强版)
- **OAuth Web 认证**: 基于浏览器的 Kiro OAuth 登录,提供美观的 Web UI
- **请求限流器**: 内置请求限流,防止 API 滥用
- **后台令牌刷新**: 过期前 10 分钟自动刷新令牌
- **监控指标**: 请求指标收集,用于监控和调试
- **设备指纹**: 设备指纹生成,增强安全性
- **冷却管理**: 智能冷却机制,应对 API 速率限制
- **用量检查器**: 实时用量监控和配额管理
- **模型转换器**: 跨供应商的统一模型名称转换
- **UTF-8 流处理**: 改进的流式响应处理
## Kiro 认证
### 网页端 OAuth 登录
访问 Kiro OAuth 网页认证界面:
```
http://your-server:8080/v0/oauth/kiro
```
提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
- AWS Builder ID 登录
- AWS Identity Center (IDC) 登录
- 从 Kiro IDE 导入令牌
## Docker 快速部署
### 一键部署
```bash
# 创建部署目录
mkdir -p ~/cli-proxy && cd ~/cli-proxy
# 创建 docker-compose.yml
cat > docker-compose.yml << 'EOF'
services:
cli-proxy-api:
image: 17600006524/cli-proxy-api-plus:latest
container_name: cli-proxy-api-plus
ports:
- "8317:8317"
volumes:
- ./config.yaml:/CLIProxyAPI/config.yaml
- ./auths:/root/.cli-proxy-api
- ./logs:/CLIProxyAPI/logs
restart: unless-stopped
EOF
# 下载示例配置
curl -o config.yaml https://raw.githubusercontent.com/linlang781/CLIProxyAPIPlus/main/config.example.yaml
# 拉取并启动
docker compose pull && docker compose up -d
```
### 配置说明
启动前请编辑 `config.yaml`
```yaml
# 基本配置示例
server:
port: 8317
# 在此添加你的供应商配置
```
### 更新到最新版本
```bash
cd ~/cli-proxy
docker compose pull && docker compose up -d
```
## 贡献
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。

View File

@@ -17,6 +17,7 @@ import (
"github.com/joho/godotenv"
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -74,6 +75,7 @@ func main() {
var iflowLogin bool
var iflowCookie bool
var noBrowser bool
var oauthCallbackPort int
var antigravityLogin bool
var kiroLogin bool
var kiroGoogleLogin bool
@@ -96,6 +98,7 @@ func main() {
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
@@ -454,7 +457,8 @@ func main() {
// Create login options to be used in authentication flows.
options := &cmd.LoginOptions{
NoBrowser: noBrowser,
NoBrowser: noBrowser,
CallbackPort: oauthCallbackPort,
}
// Register the shared token store once so all components use the same persistence backend.
@@ -530,6 +534,13 @@ func main() {
}
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
// 初始化并启动 Kiro token 后台刷新
if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg)
defer kiro.StopGlobalRefreshManager()
}
cmd.StartService(cfg, configFilePath, password)
}
}

View File

@@ -90,6 +90,10 @@ nonstream-keepalive-interval: 0
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
# When true, enable official Codex instructions injection for Codex API requests.
# When false (default), CodexInstructionsForModel returns immediately without modification.
codex-instructions-enabled: false
# Gemini API keys
# gemini-api-key:
# - api-key: "AIzaSy...01"
@@ -142,6 +146,15 @@ nonstream-keepalive-interval: 0
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
# - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
# cloak: # optional: request cloaking for non-Claude-Code clients
# mode: "auto" # "auto" (default): cloak only when client is not Claude Code
# # "always": always apply cloaking
# # "never": never apply cloaking
# strict-mode: false # false (default): prepend Claude Code prompt to user system messages
# # true: strip all user system messages, keep only Claude Code prompt
# sensitive-words: # optional: words to obfuscate with zero-width characters
# - "API"
# - "proxy"
# Kiro (AWS CodeWhisperer) configuration
# Note: Kiro API currently only operates in us-east-1 region
@@ -216,12 +229,27 @@ nonstream-keepalive-interval: 0
# - from: "claude-haiku-4-5-20251001"
# to: "gemini-2.5-flash"
# Global OAuth model name mappings (per channel)
# These mappings rename model IDs for both model listing and request routing.
# Global OAuth model name aliases (per channel)
# These aliases rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
# NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# You can repeat the same name with different aliases to expose multiple client model names.
# oauth-model-mappings:
#oauth-model-alias:
# antigravity:
# - name: "rev19-uic3-1p"
# alias: "gemini-2.5-computer-use-preview-10-2025"
# - name: "gemini-3-pro-image"
# alias: "gemini-3-pro-image-preview"
# - name: "gemini-3-pro-high"
# alias: "gemini-3-pro-preview"
# - name: "gemini-3-flash"
# alias: "gemini-3-flash-preview"
# - name: "claude-sonnet-4-5"
# alias: "gemini-claude-sonnet-4-5"
# - name: "claude-sonnet-4-5-thinking"
# alias: "gemini-claude-sonnet-4-5-thinking"
# - name: "claude-opus-4-5-thinking"
# alias: "gemini-claude-opus-4-5-thinking"
# gemini-cli:
# - name: "gemini-2.5-pro" # original model name under this channel
# alias: "g2.5p" # client-visible alias
@@ -289,9 +317,21 @@ nonstream-keepalive-interval: 0
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
# params: # JSON path (gjson/sjson syntax) -> value
# "generationConfig.thinkingConfig.thinkingBudget": 32768
# default-raw: # Default raw rules set parameters using raw JSON when missing (must be valid JSON).
# - models:
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
# "generationConfig.responseJsonSchema": "{\"type\":\"object\",\"properties\":{\"answer\":{\"type\":\"string\"}}}"
# override: # Override rules always set parameters, overwriting any existing values.
# - models:
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
# params: # JSON path (gjson/sjson syntax) -> value
# "reasoning.effort": "high"
# override-raw: # Override raw rules always set parameters using raw JSON (must be valid JSON).
# - models:
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
# "response_format": "{\"type\":\"json_schema\",\"json_schema\":{\"name\":\"answer\",\"schema\":{\"type\":\"object\"}}}"

View File

@@ -22,7 +22,7 @@ services:
- "51121:51121"
- "11451:11451"
volumes:
- ./config.yaml:/CLIProxyAPI/config.yaml
- ./auths:/root/.cli-proxy-api
- ./logs:/CLIProxyAPI/logs
- ${CLI_PROXY_CONFIG_PATH:-./config.yaml}:/CLIProxyAPI/config.yaml
- ${CLI_PROXY_AUTH_PATH:-./auths}:/root/.cli-proxy-api
- ${CLI_PROXY_LOG_PATH:-./logs}:/CLIProxyAPI/logs
restart: unless-stopped

View File

@@ -3,6 +3,7 @@ package management
import (
"bytes"
"context"
"encoding/hex"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
@@ -24,6 +25,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
@@ -1387,9 +1389,16 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
claims, _ := codex.ParseJWTToken(tokenResp.IDToken)
email := ""
accountID := ""
planType := ""
if claims != nil {
email = claims.GetUserEmail()
accountID = claims.GetAccountID()
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
}
hashAccountID := ""
if accountID != "" {
digest := sha256.Sum256([]byte(accountID))
hashAccountID = hex.EncodeToString(digest[:])[:8]
}
// Build bundle compatible with existing storage
bundle := &codex.CodexAuthBundle{
@@ -1406,10 +1415,11 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
// Create token storage and persist
tokenStorage := openaiAuth.CreateTokenStorage(bundle)
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
record := &coreauth.Auth{
ID: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
ID: fileName,
Provider: "codex",
FileName: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
FileName: fileName,
Storage: tokenStorage,
Metadata: map[string]any{
"email": tokenStorage.Email,
@@ -1707,7 +1717,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
// Create token storage
tokenStorage := qwenAuth.CreateTokenStorage(tokenData)
tokenStorage.Email = fmt.Sprintf("qwen-%d", time.Now().UnixMilli())
tokenStorage.Email = fmt.Sprintf("%d", time.Now().UnixMilli())
record := &coreauth.Auth{
ID: fmt.Sprintf("qwen-%s.json", tokenStorage.Email),
Provider: "qwen",
@@ -1812,7 +1822,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
tokenStorage := authSvc.CreateTokenStorage(tokenData)
identifier := strings.TrimSpace(tokenStorage.Email)
if identifier == "" {
identifier = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
identifier = fmt.Sprintf("%d", time.Now().UnixMilli())
tokenStorage.Email = identifier
}
record := &coreauth.Auth{
@@ -1843,6 +1853,89 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestGitHubToken(c *gin.Context) {
ctx := context.Background()
fmt.Println("Initializing GitHub Copilot authentication...")
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
// Initialize Copilot auth service
// We need to import "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" first if not present
// Assuming copilot package is imported as "copilot"
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
// Initiate device flow
deviceCode, err := deviceClient.RequestDeviceCode(ctx)
if err != nil {
log.Errorf("Failed to initiate device flow: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate device flow"})
return
}
authURL := deviceCode.VerificationURI
userCode := deviceCode.UserCode
RegisterOAuthSession(state, "github")
go func() {
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
tokenData, errPoll := deviceClient.PollForToken(ctx, deviceCode)
if errPoll != nil {
SetOAuthSessionError(state, "Authentication failed")
fmt.Printf("Authentication failed: %v\n", errPoll)
return
}
username, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
if errUser != nil {
log.Warnf("Failed to fetch user info: %v", errUser)
username = "github-user"
}
tokenStorage := &copilot.CopilotTokenStorage{
AccessToken: tokenData.AccessToken,
TokenType: tokenData.TokenType,
Scope: tokenData.Scope,
Username: username,
Type: "github-copilot",
}
fileName := fmt.Sprintf("github-%s.json", username)
record := &coreauth.Auth{
ID: fileName,
Provider: "github",
FileName: fileName,
Storage: tokenStorage,
Metadata: map[string]any{
"email": username,
"username": username,
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
fmt.Println("You can now use GitHub Copilot services through this CLI")
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("github")
}()
c.JSON(200, gin.H{
"status": "ok",
"url": authURL,
"state": state,
"user_code": userCode,
"verification_uri": authURL,
})
}
func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
ctx := context.Background()
@@ -1897,15 +1990,17 @@ func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
fileName := iflowauth.SanitizeIFlowFileName(email)
if fileName == "" {
fileName = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
} else {
fileName = fmt.Sprintf("iflow-%s", fileName)
}
tokenStorage.Email = email
timestamp := time.Now().Unix()
record := &coreauth.Auth{
ID: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
ID: fmt.Sprintf("%s-%d.json", fileName, timestamp),
Provider: "iflow",
FileName: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
FileName: fmt.Sprintf("%s-%d.json", fileName, timestamp),
Storage: tokenStorage,
Metadata: map[string]any{
"email": email,
@@ -2112,7 +2207,20 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
finalProjectID := projectID
if responseProjectID != "" {
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
strings.EqualFold(tierID, "FREE") ||
strings.EqualFold(tierID, "LEGACY")
if isFreeUser {
// For free users, use backend project ID for preview model access
log.Infof("Gemini onboarding: frontend project %s maps to backend project %s", projectID, responseProjectID)
log.Infof("Using backend project ID: %s (recommended for preview model access)", responseProjectID)
finalProjectID = responseProjectID
} else {
// Pro users: keep requested project ID (original behavior)
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
}
} else {
finalProjectID = responseProjectID
}

View File

@@ -703,21 +703,21 @@ func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) {
h.persist(c)
}
// oauth-model-mappings: map[string][]ModelNameMapping
func (h *Handler) GetOAuthModelMappings(c *gin.Context) {
c.JSON(200, gin.H{"oauth-model-mappings": sanitizedOAuthModelMappings(h.cfg.OAuthModelMappings)})
// oauth-model-alias: map[string][]OAuthModelAlias
func (h *Handler) GetOAuthModelAlias(c *gin.Context) {
c.JSON(200, gin.H{"oauth-model-alias": sanitizedOAuthModelAlias(h.cfg.OAuthModelAlias)})
}
func (h *Handler) PutOAuthModelMappings(c *gin.Context) {
func (h *Handler) PutOAuthModelAlias(c *gin.Context) {
data, err := c.GetRawData()
if err != nil {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
var entries map[string][]config.ModelNameMapping
var entries map[string][]config.OAuthModelAlias
if err = json.Unmarshal(data, &entries); err != nil {
var wrapper struct {
Items map[string][]config.ModelNameMapping `json:"items"`
Items map[string][]config.OAuthModelAlias `json:"items"`
}
if err2 := json.Unmarshal(data, &wrapper); err2 != nil {
c.JSON(400, gin.H{"error": "invalid body"})
@@ -725,15 +725,15 @@ func (h *Handler) PutOAuthModelMappings(c *gin.Context) {
}
entries = wrapper.Items
}
h.cfg.OAuthModelMappings = sanitizedOAuthModelMappings(entries)
h.cfg.OAuthModelAlias = sanitizedOAuthModelAlias(entries)
h.persist(c)
}
func (h *Handler) PatchOAuthModelMappings(c *gin.Context) {
func (h *Handler) PatchOAuthModelAlias(c *gin.Context) {
var body struct {
Provider *string `json:"provider"`
Channel *string `json:"channel"`
Mappings []config.ModelNameMapping `json:"mappings"`
Provider *string `json:"provider"`
Channel *string `json:"channel"`
Aliases []config.OAuthModelAlias `json:"aliases"`
}
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(400, gin.H{"error": "invalid body"})
@@ -751,32 +751,32 @@ func (h *Handler) PatchOAuthModelMappings(c *gin.Context) {
return
}
normalizedMap := sanitizedOAuthModelMappings(map[string][]config.ModelNameMapping{channel: body.Mappings})
normalizedMap := sanitizedOAuthModelAlias(map[string][]config.OAuthModelAlias{channel: body.Aliases})
normalized := normalizedMap[channel]
if len(normalized) == 0 {
if h.cfg.OAuthModelMappings == nil {
if h.cfg.OAuthModelAlias == nil {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
if _, ok := h.cfg.OAuthModelMappings[channel]; !ok {
if _, ok := h.cfg.OAuthModelAlias[channel]; !ok {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
delete(h.cfg.OAuthModelMappings, channel)
if len(h.cfg.OAuthModelMappings) == 0 {
h.cfg.OAuthModelMappings = nil
delete(h.cfg.OAuthModelAlias, channel)
if len(h.cfg.OAuthModelAlias) == 0 {
h.cfg.OAuthModelAlias = nil
}
h.persist(c)
return
}
if h.cfg.OAuthModelMappings == nil {
h.cfg.OAuthModelMappings = make(map[string][]config.ModelNameMapping)
if h.cfg.OAuthModelAlias == nil {
h.cfg.OAuthModelAlias = make(map[string][]config.OAuthModelAlias)
}
h.cfg.OAuthModelMappings[channel] = normalized
h.cfg.OAuthModelAlias[channel] = normalized
h.persist(c)
}
func (h *Handler) DeleteOAuthModelMappings(c *gin.Context) {
func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) {
channel := strings.ToLower(strings.TrimSpace(c.Query("channel")))
if channel == "" {
channel = strings.ToLower(strings.TrimSpace(c.Query("provider")))
@@ -785,17 +785,17 @@ func (h *Handler) DeleteOAuthModelMappings(c *gin.Context) {
c.JSON(400, gin.H{"error": "missing channel"})
return
}
if h.cfg.OAuthModelMappings == nil {
if h.cfg.OAuthModelAlias == nil {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
if _, ok := h.cfg.OAuthModelMappings[channel]; !ok {
if _, ok := h.cfg.OAuthModelAlias[channel]; !ok {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
delete(h.cfg.OAuthModelMappings, channel)
if len(h.cfg.OAuthModelMappings) == 0 {
h.cfg.OAuthModelMappings = nil
delete(h.cfg.OAuthModelAlias, channel)
if len(h.cfg.OAuthModelAlias) == 0 {
h.cfg.OAuthModelAlias = nil
}
h.persist(c)
}
@@ -1042,26 +1042,26 @@ func normalizeVertexCompatKey(entry *config.VertexCompatKey) {
entry.Models = normalized
}
func sanitizedOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string][]config.ModelNameMapping {
func sanitizedOAuthModelAlias(entries map[string][]config.OAuthModelAlias) map[string][]config.OAuthModelAlias {
if len(entries) == 0 {
return nil
}
copied := make(map[string][]config.ModelNameMapping, len(entries))
for channel, mappings := range entries {
if len(mappings) == 0 {
copied := make(map[string][]config.OAuthModelAlias, len(entries))
for channel, aliases := range entries {
if len(aliases) == 0 {
continue
}
copied[channel] = append([]config.ModelNameMapping(nil), mappings...)
copied[channel] = append([]config.OAuthModelAlias(nil), aliases...)
}
if len(copied) == 0 {
return nil
}
cfg := config.Config{OAuthModelMappings: copied}
cfg.SanitizeOAuthModelMappings()
if len(cfg.OAuthModelMappings) == 0 {
cfg := config.Config{OAuthModelAlias: copied}
cfg.SanitizeOAuthModelAlias()
if len(cfg.OAuthModelAlias) == 0 {
return nil
}
return cfg.OAuthModelMappings
return cfg.OAuthModelAlias
}
// GetAmpCode returns the complete ampcode configuration.

View File

@@ -13,7 +13,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
)
const (
@@ -360,16 +360,7 @@ func (h *Handler) logDirectory() string {
if h.logDir != "" {
return h.logDir
}
if base := util.WritablePath(); base != "" {
return filepath.Join(base, "logs")
}
if h.configFilePath != "" {
dir := filepath.Dir(h.configFilePath)
if dir != "" && dir != "." {
return filepath.Join(dir, "logs")
}
}
return "logs"
return logging.ResolveLogDirectory(h.cfg)
}
func (h *Handler) collectLogFiles(dir string) ([]string, error) {

View File

@@ -238,6 +238,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
return "qwen", nil
case "kiro":
return "kiro", nil
case "github":
return "github", nil
default:
return "", errUnsupportedOAuthFlow
}

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -134,10 +135,11 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
}
// Normalize model (handles dynamic thinking suffixes)
normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName)
suffixResult := thinking.ParseSuffix(modelName)
normalizedModel := suffixResult.ModelName
thinkingSuffix := ""
if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) {
thinkingSuffix = modelName[len(normalizedModel):]
if suffixResult.HasSuffix {
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
}
resolveMappedModel := func() (string, []string) {
@@ -157,13 +159,13 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
// already specifies its own thinking suffix.
if thinkingSuffix != "" {
_, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel)
if mappedThinkingMetadata == nil {
mappedSuffixResult := thinking.ParseSuffix(mappedModel)
if !mappedSuffixResult.HasSuffix {
mappedModel += thinkingSuffix
}
}
mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel)
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
mappedProviders := util.GetProviderName(mappedBaseModel)
if len(mappedProviders) == 0 {
return "", nil

View File

@@ -8,6 +8,7 @@ import (
"sync"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
@@ -44,6 +45,11 @@ func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
// MapModel checks if a mapping exists for the requested model and if the
// target model has available local providers. Returns the mapped model name
// or empty string if no valid mapping exists.
//
// If the requested model contains a thinking suffix (e.g., "g25p(8192)"),
// the suffix is preserved in the returned model name (e.g., "gemini-2.5-pro(8192)").
// However, if the mapping target already contains a suffix, the config suffix
// takes priority over the user's suffix.
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
if requestedModel == "" {
return ""
@@ -52,16 +58,20 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
m.mu.RLock()
defer m.mu.RUnlock()
// Normalize the requested model for lookup
normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel))
// Extract thinking suffix from requested model using ParseSuffix
requestResult := thinking.ParseSuffix(requestedModel)
baseModel := requestResult.ModelName
// Check for direct mapping
targetModel, exists := m.mappings[normalizedRequest]
// Normalize the base model for lookup (case-insensitive)
normalizedBase := strings.ToLower(strings.TrimSpace(baseModel))
// Check for direct mapping using base model name
targetModel, exists := m.mappings[normalizedBase]
if !exists {
// Try regex mappings in order
base, _ := util.NormalizeThinkingModel(requestedModel)
// Try regex mappings in order using base model only
// (suffix is handled separately via ParseSuffix)
for _, rm := range m.regexps {
if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) {
if rm.re.MatchString(baseModel) {
targetModel = rm.to
exists = true
break
@@ -72,14 +82,28 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
}
}
// Verify target model has available providers
normalizedTarget, _ := util.NormalizeThinkingModel(targetModel)
providers := util.GetProviderName(normalizedTarget)
// Check if target model already has a thinking suffix (config priority)
targetResult := thinking.ParseSuffix(targetModel)
// Verify target model has available providers (use base model for lookup)
providers := util.GetProviderName(targetResult.ModelName)
if len(providers) == 0 {
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
return ""
}
// Suffix handling: config suffix takes priority, otherwise preserve user suffix
if targetResult.HasSuffix {
// Config's "to" already contains a suffix - use it as-is (config priority)
return targetModel
}
// Preserve user's thinking suffix on the mapped model
// (skip empty suffixes to avoid returning "model()")
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return targetModel + "(" + requestResult.RawSuffix + ")"
}
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
return targetModel
}

View File

@@ -217,10 +217,10 @@ func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
mapper := NewModelMapper(mappings)
// Incoming model has reasoning suffix but should match base via regex
// Incoming model has reasoning suffix, regex matches base, suffix is preserved
result := mapper.MapModel("gpt-5(high)")
if result != "gemini-2.5-pro" {
t.Errorf("Expected gemini-2.5-pro, got %s", result)
if result != "gemini-2.5-pro(high)" {
t.Errorf("Expected gemini-2.5-pro(high), got %s", result)
}
}
@@ -281,3 +281,95 @@ func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}
func TestModelMapper_SuffixPreservation(t *testing.T) {
reg := registry.GetGlobalRegistry()
// Register test models
reg.RegisterClient("test-client-suffix", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
reg.RegisterClient("test-client-suffix-2", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
defer reg.UnregisterClient("test-client-suffix")
defer reg.UnregisterClient("test-client-suffix-2")
tests := []struct {
name string
mappings []config.AmpModelMapping
input string
want string
}{
{
name: "numeric suffix preserved",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p(8192)",
want: "gemini-2.5-pro(8192)",
},
{
name: "level suffix preserved",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p(high)",
want: "gemini-2.5-pro(high)",
},
{
name: "no suffix unchanged",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p",
want: "gemini-2.5-pro",
},
{
name: "config suffix takes priority",
mappings: []config.AmpModelMapping{{From: "alias", To: "gemini-2.5-pro(medium)"}},
input: "alias(high)",
want: "gemini-2.5-pro(medium)",
},
{
name: "regex with suffix preserved",
mappings: []config.AmpModelMapping{{From: "^g25.*", To: "gemini-2.5-pro", Regex: true}},
input: "g25p(8192)",
want: "gemini-2.5-pro(8192)",
},
{
name: "auto suffix preserved",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p(auto)",
want: "gemini-2.5-pro(auto)",
},
{
name: "none suffix preserved",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p(none)",
want: "gemini-2.5-pro(none)",
},
{
name: "case insensitive base lookup with suffix",
mappings: []config.AmpModelMapping{{From: "G25P", To: "gemini-2.5-pro"}},
input: "g25p(high)",
want: "gemini-2.5-pro(high)",
},
{
name: "empty suffix filtered out",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p()",
want: "gemini-2.5-pro",
},
{
name: "incomplete suffix treated as no suffix",
mappings: []config.AmpModelMapping{{From: "g25p(high", To: "gemini-2.5-pro"}},
input: "g25p(high",
want: "gemini-2.5-pro",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mapper := NewModelMapper(tt.mappings)
got := mapper.MapModel(tt.input)
if got != tt.want {
t.Errorf("MapModel(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}

View File

@@ -23,9 +23,11 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
@@ -254,15 +256,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
}
managementasset.SetCurrentConfig(cfg)
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
misc.SetCodexInstructionsEnabled(cfg.CodexInstructionsEnabled)
// Initialize management handler
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
if optionState.localPassword != "" {
s.mgmt.SetLocalPassword(optionState.localPassword)
}
logDir := filepath.Join(s.currentPath, "logs")
if base := util.WritablePath(); base != "" {
logDir = filepath.Join(base, "logs")
}
logDir := logging.ResolveLogDirectory(cfg)
s.mgmt.SetLogDirectory(logDir)
s.localPassword = optionState.localPassword
@@ -293,6 +293,11 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
s.registerManagementRoutes()
}
// === CLIProxyAPIPlus 扩展: 注册 Kiro OAuth Web 路由 ===
kiroOAuthHandler := kiro.NewOAuthWebHandler(cfg)
kiroOAuthHandler.RegisterRoutes(engine)
log.Info("Kiro OAuth Web routes registered at /v0/oauth/kiro/*")
if optionState.keepAliveEnabled {
s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
}
@@ -621,10 +626,10 @@ func (s *Server) registerManagementRoutes() {
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
mgmt.GET("/oauth-model-mappings", s.mgmt.GetOAuthModelMappings)
mgmt.PUT("/oauth-model-mappings", s.mgmt.PutOAuthModelMappings)
mgmt.PATCH("/oauth-model-mappings", s.mgmt.PatchOAuthModelMappings)
mgmt.DELETE("/oauth-model-mappings", s.mgmt.DeleteOAuthModelMappings)
mgmt.GET("/oauth-model-alias", s.mgmt.GetOAuthModelAlias)
mgmt.PUT("/oauth-model-alias", s.mgmt.PutOAuthModelAlias)
mgmt.PATCH("/oauth-model-alias", s.mgmt.PatchOAuthModelAlias)
mgmt.DELETE("/oauth-model-alias", s.mgmt.DeleteOAuthModelAlias)
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
@@ -641,6 +646,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
}
@@ -933,6 +939,16 @@ func (s *Server) UpdateClients(cfg *config.Config) {
log.Debugf("disable_cooling toggled to %t", cfg.DisableCooling)
}
}
if oldCfg == nil || oldCfg.CodexInstructionsEnabled != cfg.CodexInstructionsEnabled {
misc.SetCodexInstructionsEnabled(cfg.CodexInstructionsEnabled)
if oldCfg != nil {
log.Debugf("codex_instructions_enabled updated from %t to %t", oldCfg.CodexInstructionsEnabled, cfg.CodexInstructionsEnabled)
} else {
log.Debugf("codex_instructions_enabled toggled to %t", cfg.CodexInstructionsEnabled)
}
}
if s.handlers != nil && s.handlers.AuthManager != nil {
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
}

View File

@@ -0,0 +1,57 @@
package codex
import (
"fmt"
"strings"
"unicode"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
// CredentialFileName returns the filename used to persist Codex OAuth credentials.
// When planType is available (e.g. "plus", "team"), it is appended after the email
// as a suffix to disambiguate subscriptions.
func CredentialFileName(email, planType, hashAccountID string, includeProviderPrefix bool) string {
email = strings.TrimSpace(email)
plan := normalizePlanTypeForFilename(planType)
prefix := ""
if includeProviderPrefix {
prefix = "codex"
}
if plan == "" {
return fmt.Sprintf("%s-%s.json", prefix, email)
} else if plan == "team" {
return fmt.Sprintf("%s-%s-%s-%s.json", prefix, hashAccountID, email, plan)
}
return fmt.Sprintf("%s-%s-%s.json", prefix, email, plan)
}
func normalizePlanTypeForFilename(planType string) string {
planType = strings.TrimSpace(planType)
if planType == "" {
return ""
}
parts := strings.FieldsFunc(planType, func(r rune) bool {
return !unicode.IsLetter(r) && !unicode.IsDigit(r)
})
if len(parts) == 0 {
return ""
}
for i, part := range parts {
parts[i] = titleToken(part)
}
return strings.Join(parts, "-")
}
func titleToken(token string) string {
token = strings.TrimSpace(token)
if token == "" {
return ""
}
return cases.Title(language.English).String(token)
}

View File

@@ -29,8 +29,9 @@ import (
)
const (
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOauthClientSecret = "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"
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOauthClientSecret = "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"
geminiDefaultCallbackPort = 8085
)
var (
@@ -49,8 +50,9 @@ type GeminiAuth struct {
// WebLoginOptions customizes the interactive OAuth flow.
type WebLoginOptions struct {
NoBrowser bool
Prompt func(string) (string, error)
NoBrowser bool
CallbackPort int
Prompt func(string) (string, error)
}
// NewGeminiAuth creates a new instance of GeminiAuth.
@@ -72,6 +74,12 @@ func NewGeminiAuth() *GeminiAuth {
// - *http.Client: An HTTP client configured with authentication
// - error: An error if the client configuration fails, nil otherwise
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
callbackPort := geminiDefaultCallbackPort
if opts != nil && opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
// Configure proxy settings for the HTTP client if a proxy URL is provided.
proxyURL, err := url.Parse(cfg.ProxyURL)
if err == nil {
@@ -106,7 +114,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
conf := &oauth2.Config{
ClientID: geminiOauthClientID,
ClientSecret: geminiOauthClientSecret,
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
RedirectURL: callbackURL, // This will be used by the local server.
Scopes: geminiOauthScopes,
Endpoint: google.Endpoint,
}
@@ -218,14 +226,20 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
// - error: An error if the token acquisition fails, nil otherwise
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
callbackPort := geminiDefaultCallbackPort
if opts != nil && opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
// Use a channel to pass the authorization code from the HTTP handler to the main function.
codeChan := make(chan string, 1)
errChan := make(chan error, 1)
// Create a new HTTP server with its own multiplexer.
mux := http.NewServeMux()
server := &http.Server{Addr: ":8085", Handler: mux}
config.RedirectURL = "http://localhost:8085/oauth2callback"
server := &http.Server{Addr: fmt.Sprintf(":%d", callbackPort), Handler: mux}
config.RedirectURL = callbackURL
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
if err := r.URL.Query().Get("error"); err != "" {
@@ -277,13 +291,13 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Check if browser is available
if !browser.IsAvailable() {
log.Warn("No browser available on this system")
util.PrintSSHTunnelInstructions(8085)
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
} else {
if err := browser.OpenURL(authURL); err != nil {
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
log.Warn(codex.GetUserFriendlyMessage(authErr))
util.PrintSSHTunnelInstructions(8085)
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
// Log platform info for debugging
@@ -294,7 +308,7 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
}
}
} else {
util.PrintSSHTunnelInstructions(8085)
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL)
}

View File

@@ -5,10 +5,12 @@ package kiro
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
)
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
@@ -85,6 +87,87 @@ type KiroModel struct {
// KiroIDETokenFile is the default path to Kiro IDE's token file
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
// Default retry configuration for file reading
const (
defaultTokenReadMaxAttempts = 10 // Maximum retry attempts
defaultTokenReadBaseDelay = 50 * time.Millisecond // Base delay between retries
)
// isTransientFileError checks if the error is a transient file access error
// that may be resolved by retrying (e.g., file locked by another process on Windows).
func isTransientFileError(err error) bool {
if err == nil {
return false
}
// Check for OS-level file access errors (Windows sharing violation, etc.)
var pathErr *os.PathError
if errors.As(err, &pathErr) {
// Windows sharing violation (ERROR_SHARING_VIOLATION = 32)
// Windows lock violation (ERROR_LOCK_VIOLATION = 33)
errStr := pathErr.Err.Error()
if strings.Contains(errStr, "being used by another process") ||
strings.Contains(errStr, "sharing violation") ||
strings.Contains(errStr, "lock violation") {
return true
}
}
// Check error message for common transient patterns
errMsg := strings.ToLower(err.Error())
transientPatterns := []string{
"being used by another process",
"sharing violation",
"lock violation",
"access is denied",
"unexpected end of json",
"unexpected eof",
}
for _, pattern := range transientPatterns {
if strings.Contains(errMsg, pattern) {
return true
}
}
return false
}
// LoadKiroIDETokenWithRetry loads token data from Kiro IDE's token file with retry logic.
// This handles transient file access errors (e.g., file locked by Kiro IDE during write).
// maxAttempts: maximum number of retry attempts (default 10 if <= 0)
// baseDelay: base delay between retries with exponential backoff (default 50ms if <= 0)
func LoadKiroIDETokenWithRetry(maxAttempts int, baseDelay time.Duration) (*KiroTokenData, error) {
if maxAttempts <= 0 {
maxAttempts = defaultTokenReadMaxAttempts
}
if baseDelay <= 0 {
baseDelay = defaultTokenReadBaseDelay
}
var lastErr error
for attempt := 0; attempt < maxAttempts; attempt++ {
token, err := LoadKiroIDEToken()
if err == nil {
return token, nil
}
lastErr = err
// Only retry for transient errors
if !isTransientFileError(err) {
return nil, err
}
// Exponential backoff: delay * 2^attempt, capped at 500ms
delay := baseDelay * time.Duration(1<<uint(attempt))
if delay > 500*time.Millisecond {
delay = 500 * time.Millisecond
}
time.Sleep(delay)
}
return nil, fmt.Errorf("failed to read token file after %d attempts: %w", maxAttempts, lastErr)
}
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
func LoadKiroIDEToken() (*KiroTokenData, error) {
homeDir, err := os.UserHomeDir()

View File

@@ -280,6 +280,11 @@ func (k *KiroAuth) CreateTokenStorage(tokenData *KiroTokenData) *KiroTokenStorag
AuthMethod: tokenData.AuthMethod,
Provider: tokenData.Provider,
LastRefresh: time.Now().Format(time.RFC3339),
ClientID: tokenData.ClientID,
ClientSecret: tokenData.ClientSecret,
Region: tokenData.Region,
StartURL: tokenData.StartURL,
Email: tokenData.Email,
}
}
@@ -311,4 +316,19 @@ func (k *KiroAuth) UpdateTokenStorage(storage *KiroTokenStorage, tokenData *Kiro
storage.AuthMethod = tokenData.AuthMethod
storage.Provider = tokenData.Provider
storage.LastRefresh = time.Now().Format(time.RFC3339)
if tokenData.ClientID != "" {
storage.ClientID = tokenData.ClientID
}
if tokenData.ClientSecret != "" {
storage.ClientSecret = tokenData.ClientSecret
}
if tokenData.Region != "" {
storage.Region = tokenData.Region
}
if tokenData.StartURL != "" {
storage.StartURL = tokenData.StartURL
}
if tokenData.Email != "" {
storage.Email = tokenData.Email
}
}

View File

@@ -0,0 +1,224 @@
package kiro
import (
"context"
"log"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"golang.org/x/sync/semaphore"
)
type Token struct {
ID string
AccessToken string
RefreshToken string
ExpiresAt time.Time
LastVerified time.Time
ClientID string
ClientSecret string
AuthMethod string
Provider string
StartURL string
Region string
}
type TokenRepository interface {
FindOldestUnverified(limit int) []*Token
UpdateToken(token *Token) error
}
type RefresherOption func(*BackgroundRefresher)
func WithInterval(interval time.Duration) RefresherOption {
return func(r *BackgroundRefresher) {
r.interval = interval
}
}
func WithBatchSize(size int) RefresherOption {
return func(r *BackgroundRefresher) {
r.batchSize = size
}
}
func WithConcurrency(concurrency int) RefresherOption {
return func(r *BackgroundRefresher) {
r.concurrency = concurrency
}
}
type BackgroundRefresher struct {
interval time.Duration
batchSize int
concurrency int
tokenRepo TokenRepository
stopCh chan struct{}
wg sync.WaitGroup
oauth *KiroOAuth
ssoClient *SSOOIDCClient
callbackMu sync.RWMutex // 保护回调函数的并发访问
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
}
func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher {
r := &BackgroundRefresher{
interval: time.Minute,
batchSize: 50,
concurrency: 10,
tokenRepo: repo,
stopCh: make(chan struct{}),
oauth: nil, // Lazy init - will be set when config available
ssoClient: nil, // Lazy init - will be set when config available
}
for _, opt := range opts {
opt(r)
}
return r
}
// WithConfig sets the configuration for OAuth and SSO clients.
func WithConfig(cfg *config.Config) RefresherOption {
return func(r *BackgroundRefresher) {
r.oauth = NewKiroOAuth(cfg)
r.ssoClient = NewSSOOIDCClient(cfg)
}
}
// WithOnTokenRefreshed sets the callback function to be called when a token is successfully refreshed.
// The callback receives the token ID (filename) and the new token data.
// This allows external components (e.g., Watcher) to be notified of token updates.
func WithOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) RefresherOption {
return func(r *BackgroundRefresher) {
r.callbackMu.Lock()
r.onTokenRefreshed = callback
r.callbackMu.Unlock()
}
}
func (r *BackgroundRefresher) Start(ctx context.Context) {
r.wg.Add(1)
go func() {
defer r.wg.Done()
ticker := time.NewTicker(r.interval)
defer ticker.Stop()
r.refreshBatch(ctx)
for {
select {
case <-ctx.Done():
return
case <-r.stopCh:
return
case <-ticker.C:
r.refreshBatch(ctx)
}
}
}()
}
func (r *BackgroundRefresher) Stop() {
close(r.stopCh)
r.wg.Wait()
}
func (r *BackgroundRefresher) refreshBatch(ctx context.Context) {
tokens := r.tokenRepo.FindOldestUnverified(r.batchSize)
if len(tokens) == 0 {
return
}
sem := semaphore.NewWeighted(int64(r.concurrency))
var wg sync.WaitGroup
for i, token := range tokens {
if i > 0 {
select {
case <-ctx.Done():
return
case <-r.stopCh:
return
case <-time.After(100 * time.Millisecond):
}
}
if err := sem.Acquire(ctx, 1); err != nil {
return
}
wg.Add(1)
go func(t *Token) {
defer wg.Done()
defer sem.Release(1)
r.refreshSingle(ctx, t)
}(token)
}
wg.Wait()
}
func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) {
var newTokenData *KiroTokenData
var err error
switch token.AuthMethod {
case "idc":
newTokenData, err = r.ssoClient.RefreshTokenWithRegion(
ctx,
token.ClientID,
token.ClientSecret,
token.RefreshToken,
token.Region,
token.StartURL,
)
case "builder-id":
newTokenData, err = r.ssoClient.RefreshToken(
ctx,
token.ClientID,
token.ClientSecret,
token.RefreshToken,
)
default:
newTokenData, err = r.oauth.RefreshToken(ctx, token.RefreshToken)
}
if err != nil {
log.Printf("failed to refresh token %s: %v", token.ID, err)
return
}
token.AccessToken = newTokenData.AccessToken
token.RefreshToken = newTokenData.RefreshToken
token.LastVerified = time.Now()
if newTokenData.ExpiresAt != "" {
if expTime, parseErr := time.Parse(time.RFC3339, newTokenData.ExpiresAt); parseErr == nil {
token.ExpiresAt = expTime
}
}
if err := r.tokenRepo.UpdateToken(token); err != nil {
log.Printf("failed to update token %s: %v", token.ID, err)
return
}
// 方案 A: 刷新成功后触发回调,通知 Watcher 更新内存中的 Auth 对象
r.callbackMu.RLock()
callback := r.onTokenRefreshed
r.callbackMu.RUnlock()
if callback != nil {
// 使用 defer recover 隔离回调 panic防止崩溃整个进程
func() {
defer func() {
if rec := recover(); rec != nil {
log.Printf("background refresh: callback panic for token %s: %v", token.ID, rec)
}
}()
log.Printf("background refresh: notifying token refresh callback for %s", token.ID)
callback(token.ID, newTokenData)
}()
}
}

View File

@@ -0,0 +1,112 @@
package kiro
import (
"sync"
"time"
)
const (
CooldownReason429 = "rate_limit_exceeded"
CooldownReasonSuspended = "account_suspended"
CooldownReasonQuotaExhausted = "quota_exhausted"
DefaultShortCooldown = 1 * time.Minute
MaxShortCooldown = 5 * time.Minute
LongCooldown = 24 * time.Hour
)
type CooldownManager struct {
mu sync.RWMutex
cooldowns map[string]time.Time
reasons map[string]string
}
func NewCooldownManager() *CooldownManager {
return &CooldownManager{
cooldowns: make(map[string]time.Time),
reasons: make(map[string]string),
}
}
func (cm *CooldownManager) SetCooldown(tokenKey string, duration time.Duration, reason string) {
cm.mu.Lock()
defer cm.mu.Unlock()
cm.cooldowns[tokenKey] = time.Now().Add(duration)
cm.reasons[tokenKey] = reason
}
func (cm *CooldownManager) IsInCooldown(tokenKey string) bool {
cm.mu.RLock()
defer cm.mu.RUnlock()
endTime, exists := cm.cooldowns[tokenKey]
if !exists {
return false
}
return time.Now().Before(endTime)
}
func (cm *CooldownManager) GetRemainingCooldown(tokenKey string) time.Duration {
cm.mu.RLock()
defer cm.mu.RUnlock()
endTime, exists := cm.cooldowns[tokenKey]
if !exists {
return 0
}
remaining := time.Until(endTime)
if remaining < 0 {
return 0
}
return remaining
}
func (cm *CooldownManager) GetCooldownReason(tokenKey string) string {
cm.mu.RLock()
defer cm.mu.RUnlock()
return cm.reasons[tokenKey]
}
func (cm *CooldownManager) ClearCooldown(tokenKey string) {
cm.mu.Lock()
defer cm.mu.Unlock()
delete(cm.cooldowns, tokenKey)
delete(cm.reasons, tokenKey)
}
func (cm *CooldownManager) CleanupExpired() {
cm.mu.Lock()
defer cm.mu.Unlock()
now := time.Now()
for tokenKey, endTime := range cm.cooldowns {
if now.After(endTime) {
delete(cm.cooldowns, tokenKey)
delete(cm.reasons, tokenKey)
}
}
}
func (cm *CooldownManager) StartCleanupRoutine(interval time.Duration, stopCh <-chan struct{}) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
cm.CleanupExpired()
case <-stopCh:
return
}
}
}
func CalculateCooldownFor429(retryCount int) time.Duration {
duration := DefaultShortCooldown * time.Duration(1<<retryCount)
if duration > MaxShortCooldown {
return MaxShortCooldown
}
return duration
}
func CalculateCooldownUntilNextDay() time.Duration {
now := time.Now()
nextDay := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
return time.Until(nextDay)
}

View File

@@ -0,0 +1,240 @@
package kiro
import (
"sync"
"testing"
"time"
)
func TestNewCooldownManager(t *testing.T) {
cm := NewCooldownManager()
if cm == nil {
t.Fatal("expected non-nil CooldownManager")
}
if cm.cooldowns == nil {
t.Error("expected non-nil cooldowns map")
}
if cm.reasons == nil {
t.Error("expected non-nil reasons map")
}
}
func TestSetCooldown(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
if !cm.IsInCooldown("token1") {
t.Error("expected token to be in cooldown")
}
if cm.GetCooldownReason("token1") != CooldownReason429 {
t.Errorf("expected reason %s, got %s", CooldownReason429, cm.GetCooldownReason("token1"))
}
}
func TestIsInCooldown_NotSet(t *testing.T) {
cm := NewCooldownManager()
if cm.IsInCooldown("nonexistent") {
t.Error("expected non-existent token to not be in cooldown")
}
}
func TestIsInCooldown_Expired(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
time.Sleep(10 * time.Millisecond)
if cm.IsInCooldown("token1") {
t.Error("expected expired cooldown to return false")
}
}
func TestGetRemainingCooldown(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Second, CooldownReason429)
remaining := cm.GetRemainingCooldown("token1")
if remaining <= 0 || remaining > 1*time.Second {
t.Errorf("expected remaining cooldown between 0 and 1s, got %v", remaining)
}
}
func TestGetRemainingCooldown_NotSet(t *testing.T) {
cm := NewCooldownManager()
remaining := cm.GetRemainingCooldown("nonexistent")
if remaining != 0 {
t.Errorf("expected 0 remaining for non-existent, got %v", remaining)
}
}
func TestGetRemainingCooldown_Expired(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
time.Sleep(10 * time.Millisecond)
remaining := cm.GetRemainingCooldown("token1")
if remaining != 0 {
t.Errorf("expected 0 remaining for expired, got %v", remaining)
}
}
func TestGetCooldownReason(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
reason := cm.GetCooldownReason("token1")
if reason != CooldownReasonSuspended {
t.Errorf("expected reason %s, got %s", CooldownReasonSuspended, reason)
}
}
func TestGetCooldownReason_NotSet(t *testing.T) {
cm := NewCooldownManager()
reason := cm.GetCooldownReason("nonexistent")
if reason != "" {
t.Errorf("expected empty reason for non-existent, got %s", reason)
}
}
func TestClearCooldown(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
cm.ClearCooldown("token1")
if cm.IsInCooldown("token1") {
t.Error("expected cooldown to be cleared")
}
if cm.GetCooldownReason("token1") != "" {
t.Error("expected reason to be cleared")
}
}
func TestClearCooldown_NonExistent(t *testing.T) {
cm := NewCooldownManager()
cm.ClearCooldown("nonexistent")
}
func TestCleanupExpired(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("expired1", 1*time.Millisecond, CooldownReason429)
cm.SetCooldown("expired2", 1*time.Millisecond, CooldownReason429)
cm.SetCooldown("active", 1*time.Hour, CooldownReason429)
time.Sleep(10 * time.Millisecond)
cm.CleanupExpired()
if cm.GetCooldownReason("expired1") != "" {
t.Error("expected expired1 to be cleaned up")
}
if cm.GetCooldownReason("expired2") != "" {
t.Error("expected expired2 to be cleaned up")
}
if cm.GetCooldownReason("active") != CooldownReason429 {
t.Error("expected active to remain")
}
}
func TestCalculateCooldownFor429_FirstRetry(t *testing.T) {
duration := CalculateCooldownFor429(0)
if duration != DefaultShortCooldown {
t.Errorf("expected %v for retry 0, got %v", DefaultShortCooldown, duration)
}
}
func TestCalculateCooldownFor429_Exponential(t *testing.T) {
d1 := CalculateCooldownFor429(1)
d2 := CalculateCooldownFor429(2)
if d2 <= d1 {
t.Errorf("expected d2 > d1, got d1=%v, d2=%v", d1, d2)
}
}
func TestCalculateCooldownFor429_MaxCap(t *testing.T) {
duration := CalculateCooldownFor429(10)
if duration > MaxShortCooldown {
t.Errorf("expected max %v, got %v", MaxShortCooldown, duration)
}
}
func TestCalculateCooldownUntilNextDay(t *testing.T) {
duration := CalculateCooldownUntilNextDay()
if duration <= 0 || duration > 24*time.Hour {
t.Errorf("expected duration between 0 and 24h, got %v", duration)
}
}
func TestCooldownManager_ConcurrentAccess(t *testing.T) {
cm := NewCooldownManager()
const numGoroutines = 50
const numOperations = 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
tokenKey := "token" + string(rune('a'+id%10))
for j := 0; j < numOperations; j++ {
switch j % 6 {
case 0:
cm.SetCooldown(tokenKey, time.Duration(j)*time.Millisecond, CooldownReason429)
case 1:
cm.IsInCooldown(tokenKey)
case 2:
cm.GetRemainingCooldown(tokenKey)
case 3:
cm.GetCooldownReason(tokenKey)
case 4:
cm.ClearCooldown(tokenKey)
case 5:
cm.CleanupExpired()
}
}
}(i)
}
wg.Wait()
}
func TestCooldownReasonConstants(t *testing.T) {
if CooldownReason429 != "rate_limit_exceeded" {
t.Errorf("unexpected CooldownReason429: %s", CooldownReason429)
}
if CooldownReasonSuspended != "account_suspended" {
t.Errorf("unexpected CooldownReasonSuspended: %s", CooldownReasonSuspended)
}
if CooldownReasonQuotaExhausted != "quota_exhausted" {
t.Errorf("unexpected CooldownReasonQuotaExhausted: %s", CooldownReasonQuotaExhausted)
}
}
func TestDefaultConstants(t *testing.T) {
if DefaultShortCooldown != 1*time.Minute {
t.Errorf("unexpected DefaultShortCooldown: %v", DefaultShortCooldown)
}
if MaxShortCooldown != 5*time.Minute {
t.Errorf("unexpected MaxShortCooldown: %v", MaxShortCooldown)
}
if LongCooldown != 24*time.Hour {
t.Errorf("unexpected LongCooldown: %v", LongCooldown)
}
}
func TestSetCooldown_OverwritesPrevious(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Hour, CooldownReason429)
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
reason := cm.GetCooldownReason("token1")
if reason != CooldownReasonSuspended {
t.Errorf("expected reason to be overwritten to %s, got %s", CooldownReasonSuspended, reason)
}
remaining := cm.GetRemainingCooldown("token1")
if remaining > 1*time.Minute {
t.Errorf("expected remaining <= 1 minute, got %v", remaining)
}
}

View File

@@ -0,0 +1,197 @@
package kiro
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"math/rand"
"net/http"
"sync"
"time"
)
// Fingerprint 多维度指纹信息
type Fingerprint struct {
SDKVersion string // 1.0.20-1.0.27
OSType string // darwin/windows/linux
OSVersion string // 10.0.22621
NodeVersion string // 18.x/20.x/22.x
KiroVersion string // 0.3.x-0.8.x
KiroHash string // SHA256
AcceptLanguage string
ScreenResolution string // 1920x1080
ColorDepth int // 24
HardwareConcurrency int // CPU 核心数
TimezoneOffset int
}
// FingerprintManager 指纹管理器
type FingerprintManager struct {
mu sync.RWMutex
fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
rng *rand.Rand
}
var (
sdkVersions = []string{
"1.0.20", "1.0.21", "1.0.22", "1.0.23",
"1.0.24", "1.0.25", "1.0.26", "1.0.27",
}
osTypes = []string{"darwin", "windows", "linux"}
osVersions = map[string][]string{
"darwin": {"14.0", "14.1", "14.2", "14.3", "14.4", "14.5", "15.0", "15.1"},
"windows": {"10.0.19041", "10.0.19042", "10.0.19043", "10.0.19044", "10.0.22621", "10.0.22631"},
"linux": {"5.15.0", "6.1.0", "6.2.0", "6.5.0", "6.6.0", "6.8.0"},
}
nodeVersions = []string{
"18.17.0", "18.18.0", "18.19.0", "18.20.0",
"20.9.0", "20.10.0", "20.11.0", "20.12.0", "20.13.0",
"22.0.0", "22.1.0", "22.2.0", "22.3.0",
}
kiroVersions = []string{
"0.3.0", "0.3.1", "0.4.0", "0.4.1", "0.5.0", "0.5.1",
"0.6.0", "0.6.1", "0.7.0", "0.7.1", "0.8.0", "0.8.1",
}
acceptLanguages = []string{
"en-US,en;q=0.9",
"en-GB,en;q=0.9",
"zh-CN,zh;q=0.9,en;q=0.8",
"zh-TW,zh;q=0.9,en;q=0.8",
"ja-JP,ja;q=0.9,en;q=0.8",
"ko-KR,ko;q=0.9,en;q=0.8",
"de-DE,de;q=0.9,en;q=0.8",
"fr-FR,fr;q=0.9,en;q=0.8",
}
screenResolutions = []string{
"1920x1080", "2560x1440", "3840x2160",
"1366x768", "1440x900", "1680x1050",
"2560x1600", "3440x1440",
}
colorDepths = []int{24, 32}
hardwareConcurrencies = []int{4, 6, 8, 10, 12, 16, 20, 24, 32}
timezoneOffsets = []int{-480, -420, -360, -300, -240, 0, 60, 120, 480, 540}
)
// NewFingerprintManager 创建指纹管理器
func NewFingerprintManager() *FingerprintManager {
return &FingerprintManager{
fingerprints: make(map[string]*Fingerprint),
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
// GetFingerprint 获取或生成 Token 关联的指纹
func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
fm.mu.RLock()
if fp, exists := fm.fingerprints[tokenKey]; exists {
fm.mu.RUnlock()
return fp
}
fm.mu.RUnlock()
fm.mu.Lock()
defer fm.mu.Unlock()
if fp, exists := fm.fingerprints[tokenKey]; exists {
return fp
}
fp := fm.generateFingerprint(tokenKey)
fm.fingerprints[tokenKey] = fp
return fp
}
// generateFingerprint 生成新的指纹
func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
osType := fm.randomChoice(osTypes)
osVersion := fm.randomChoice(osVersions[osType])
kiroVersion := fm.randomChoice(kiroVersions)
fp := &Fingerprint{
SDKVersion: fm.randomChoice(sdkVersions),
OSType: osType,
OSVersion: osVersion,
NodeVersion: fm.randomChoice(nodeVersions),
KiroVersion: kiroVersion,
AcceptLanguage: fm.randomChoice(acceptLanguages),
ScreenResolution: fm.randomChoice(screenResolutions),
ColorDepth: fm.randomIntChoice(colorDepths),
HardwareConcurrency: fm.randomIntChoice(hardwareConcurrencies),
TimezoneOffset: fm.randomIntChoice(timezoneOffsets),
}
fp.KiroHash = fm.generateKiroHash(tokenKey, kiroVersion, osType)
return fp
}
// generateKiroHash 生成 Kiro Hash
func (fm *FingerprintManager) generateKiroHash(tokenKey, kiroVersion, osType string) string {
data := fmt.Sprintf("%s:%s:%s:%d", tokenKey, kiroVersion, osType, time.Now().UnixNano())
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
}
// randomChoice 随机选择字符串
func (fm *FingerprintManager) randomChoice(choices []string) string {
return choices[fm.rng.Intn(len(choices))]
}
// randomIntChoice 随机选择整数
func (fm *FingerprintManager) randomIntChoice(choices []int) int {
return choices[fm.rng.Intn(len(choices))]
}
// ApplyToRequest 将指纹信息应用到 HTTP 请求头
func (fp *Fingerprint) ApplyToRequest(req *http.Request) {
req.Header.Set("X-Kiro-SDK-Version", fp.SDKVersion)
req.Header.Set("X-Kiro-OS-Type", fp.OSType)
req.Header.Set("X-Kiro-OS-Version", fp.OSVersion)
req.Header.Set("X-Kiro-Node-Version", fp.NodeVersion)
req.Header.Set("X-Kiro-Version", fp.KiroVersion)
req.Header.Set("X-Kiro-Hash", fp.KiroHash)
req.Header.Set("Accept-Language", fp.AcceptLanguage)
req.Header.Set("X-Screen-Resolution", fp.ScreenResolution)
req.Header.Set("X-Color-Depth", fmt.Sprintf("%d", fp.ColorDepth))
req.Header.Set("X-Hardware-Concurrency", fmt.Sprintf("%d", fp.HardwareConcurrency))
req.Header.Set("X-Timezone-Offset", fmt.Sprintf("%d", fp.TimezoneOffset))
}
// RemoveFingerprint 移除 Token 关联的指纹
func (fm *FingerprintManager) RemoveFingerprint(tokenKey string) {
fm.mu.Lock()
defer fm.mu.Unlock()
delete(fm.fingerprints, tokenKey)
}
// Count 返回当前管理的指纹数量
func (fm *FingerprintManager) Count() int {
fm.mu.RLock()
defer fm.mu.RUnlock()
return len(fm.fingerprints)
}
// BuildUserAgent 构建 User-Agent 字符串 (Kiro IDE 风格)
// 格式: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
func (fp *Fingerprint) BuildUserAgent() string {
return fmt.Sprintf(
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
fp.SDKVersion,
fp.OSType,
fp.OSVersion,
fp.NodeVersion,
fp.SDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}
// BuildAmzUserAgent 构建 X-Amz-User-Agent 字符串
// 格式: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
func (fp *Fingerprint) BuildAmzUserAgent() string {
return fmt.Sprintf(
"aws-sdk-js/%s KiroIDE-%s-%s",
fp.SDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}

View File

@@ -0,0 +1,227 @@
package kiro
import (
"net/http"
"sync"
"testing"
)
func TestNewFingerprintManager(t *testing.T) {
fm := NewFingerprintManager()
if fm == nil {
t.Fatal("expected non-nil FingerprintManager")
}
if fm.fingerprints == nil {
t.Error("expected non-nil fingerprints map")
}
if fm.rng == nil {
t.Error("expected non-nil rng")
}
}
func TestGetFingerprint_NewToken(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
if fp == nil {
t.Fatal("expected non-nil Fingerprint")
}
if fp.SDKVersion == "" {
t.Error("expected non-empty SDKVersion")
}
if fp.OSType == "" {
t.Error("expected non-empty OSType")
}
if fp.OSVersion == "" {
t.Error("expected non-empty OSVersion")
}
if fp.NodeVersion == "" {
t.Error("expected non-empty NodeVersion")
}
if fp.KiroVersion == "" {
t.Error("expected non-empty KiroVersion")
}
if fp.KiroHash == "" {
t.Error("expected non-empty KiroHash")
}
if fp.AcceptLanguage == "" {
t.Error("expected non-empty AcceptLanguage")
}
if fp.ScreenResolution == "" {
t.Error("expected non-empty ScreenResolution")
}
if fp.ColorDepth == 0 {
t.Error("expected non-zero ColorDepth")
}
if fp.HardwareConcurrency == 0 {
t.Error("expected non-zero HardwareConcurrency")
}
}
func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
fm := NewFingerprintManager()
fp1 := fm.GetFingerprint("token1")
fp2 := fm.GetFingerprint("token1")
if fp1 != fp2 {
t.Error("expected same fingerprint for same token")
}
}
func TestGetFingerprint_DifferentTokens(t *testing.T) {
fm := NewFingerprintManager()
fp1 := fm.GetFingerprint("token1")
fp2 := fm.GetFingerprint("token2")
if fp1 == fp2 {
t.Error("expected different fingerprints for different tokens")
}
}
func TestRemoveFingerprint(t *testing.T) {
fm := NewFingerprintManager()
fm.GetFingerprint("token1")
if fm.Count() != 1 {
t.Fatalf("expected count 1, got %d", fm.Count())
}
fm.RemoveFingerprint("token1")
if fm.Count() != 0 {
t.Errorf("expected count 0, got %d", fm.Count())
}
}
func TestRemoveFingerprint_NonExistent(t *testing.T) {
fm := NewFingerprintManager()
fm.RemoveFingerprint("nonexistent")
if fm.Count() != 0 {
t.Errorf("expected count 0, got %d", fm.Count())
}
}
func TestCount(t *testing.T) {
fm := NewFingerprintManager()
if fm.Count() != 0 {
t.Errorf("expected count 0, got %d", fm.Count())
}
fm.GetFingerprint("token1")
fm.GetFingerprint("token2")
fm.GetFingerprint("token3")
if fm.Count() != 3 {
t.Errorf("expected count 3, got %d", fm.Count())
}
}
func TestApplyToRequest(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
req, _ := http.NewRequest("GET", "http://example.com", nil)
fp.ApplyToRequest(req)
if req.Header.Get("X-Kiro-SDK-Version") != fp.SDKVersion {
t.Error("X-Kiro-SDK-Version header mismatch")
}
if req.Header.Get("X-Kiro-OS-Type") != fp.OSType {
t.Error("X-Kiro-OS-Type header mismatch")
}
if req.Header.Get("X-Kiro-OS-Version") != fp.OSVersion {
t.Error("X-Kiro-OS-Version header mismatch")
}
if req.Header.Get("X-Kiro-Node-Version") != fp.NodeVersion {
t.Error("X-Kiro-Node-Version header mismatch")
}
if req.Header.Get("X-Kiro-Version") != fp.KiroVersion {
t.Error("X-Kiro-Version header mismatch")
}
if req.Header.Get("X-Kiro-Hash") != fp.KiroHash {
t.Error("X-Kiro-Hash header mismatch")
}
if req.Header.Get("Accept-Language") != fp.AcceptLanguage {
t.Error("Accept-Language header mismatch")
}
if req.Header.Get("X-Screen-Resolution") != fp.ScreenResolution {
t.Error("X-Screen-Resolution header mismatch")
}
}
func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) {
fm := NewFingerprintManager()
for i := 0; i < 20; i++ {
fp := fm.GetFingerprint("token" + string(rune('a'+i)))
validVersions := osVersions[fp.OSType]
found := false
for _, v := range validVersions {
if v == fp.OSVersion {
found = true
break
}
}
if !found {
t.Errorf("OS version %s not valid for OS type %s", fp.OSVersion, fp.OSType)
}
}
}
func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
fm := NewFingerprintManager()
const numGoroutines = 100
const numOperations = 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
tokenKey := "token" + string(rune('a'+id%26))
switch j % 4 {
case 0:
fm.GetFingerprint(tokenKey)
case 1:
fm.Count()
case 2:
fp := fm.GetFingerprint(tokenKey)
req, _ := http.NewRequest("GET", "http://example.com", nil)
fp.ApplyToRequest(req)
case 3:
fm.RemoveFingerprint(tokenKey)
}
}
}(i)
}
wg.Wait()
}
func TestKiroHashUniqueness(t *testing.T) {
fm := NewFingerprintManager()
hashes := make(map[string]bool)
for i := 0; i < 100; i++ {
fp := fm.GetFingerprint("token" + string(rune(i)))
if hashes[fp.KiroHash] {
t.Errorf("duplicate KiroHash detected: %s", fp.KiroHash)
}
hashes[fp.KiroHash] = true
}
}
func TestKiroHashFormat(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
if len(fp.KiroHash) != 64 {
t.Errorf("expected KiroHash length 64 (SHA256 hex), got %d", len(fp.KiroHash))
}
for _, c := range fp.KiroHash {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("invalid hex character in KiroHash: %c", c)
}
}
}

View File

@@ -0,0 +1,174 @@
package kiro
import (
"math/rand"
"sync"
"time"
)
// Jitter configuration constants
const (
// JitterPercent is the default percentage of jitter to apply (±30%)
JitterPercent = 0.30
// Human-like delay ranges
ShortDelayMin = 50 * time.Millisecond // Minimum for rapid consecutive operations
ShortDelayMax = 200 * time.Millisecond // Maximum for rapid consecutive operations
NormalDelayMin = 1 * time.Second // Minimum for normal thinking time
NormalDelayMax = 3 * time.Second // Maximum for normal thinking time
LongDelayMin = 5 * time.Second // Minimum for reading/resting
LongDelayMax = 10 * time.Second // Maximum for reading/resting
// Probability thresholds for human-like behavior
ShortDelayProbability = 0.20 // 20% chance of short delay (consecutive ops)
LongDelayProbability = 0.05 // 5% chance of long delay (reading/resting)
NormalDelayProbability = 0.75 // 75% chance of normal delay (thinking)
)
var (
jitterRand *rand.Rand
jitterRandOnce sync.Once
jitterMu sync.Mutex
lastRequestTime time.Time
)
// initJitterRand initializes the random number generator for jitter calculations.
// Uses a time-based seed for unpredictable but reproducible randomness.
func initJitterRand() {
jitterRandOnce.Do(func() {
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
})
}
// RandomDelay generates a random delay between min and max duration.
// Thread-safe implementation using mutex protection.
func RandomDelay(min, max time.Duration) time.Duration {
initJitterRand()
jitterMu.Lock()
defer jitterMu.Unlock()
if min >= max {
return min
}
rangeMs := max.Milliseconds() - min.Milliseconds()
randomMs := jitterRand.Int63n(rangeMs)
return min + time.Duration(randomMs)*time.Millisecond
}
// JitterDelay adds jitter to a base delay.
// Applies ±jitterPercent variation to the base delay.
// For example, JitterDelay(1*time.Second, 0.30) returns a value between 700ms and 1300ms.
func JitterDelay(baseDelay time.Duration, jitterPercent float64) time.Duration {
initJitterRand()
jitterMu.Lock()
defer jitterMu.Unlock()
if jitterPercent <= 0 || jitterPercent > 1 {
jitterPercent = JitterPercent
}
// Calculate jitter range: base * jitterPercent
jitterRange := float64(baseDelay) * jitterPercent
// Generate random value in range [-jitterRange, +jitterRange]
jitter := (jitterRand.Float64()*2 - 1) * jitterRange
result := time.Duration(float64(baseDelay) + jitter)
if result < 0 {
return 0
}
return result
}
// JitterDelayDefault applies the default ±30% jitter to a base delay.
func JitterDelayDefault(baseDelay time.Duration) time.Duration {
return JitterDelay(baseDelay, JitterPercent)
}
// HumanLikeDelay generates a delay that mimics human behavior patterns.
// The delay is selected based on probability distribution:
// - 20% chance: Short delay (50-200ms) - simulates consecutive rapid operations
// - 75% chance: Normal delay (1-3s) - simulates thinking/reading time
// - 5% chance: Long delay (5-10s) - simulates breaks/reading longer content
//
// Returns the delay duration (caller should call time.Sleep with this value).
func HumanLikeDelay() time.Duration {
initJitterRand()
jitterMu.Lock()
defer jitterMu.Unlock()
// Track time since last request for adaptive behavior
now := time.Now()
timeSinceLastRequest := now.Sub(lastRequestTime)
lastRequestTime = now
// If requests are very close together, use short delay
if timeSinceLastRequest < 500*time.Millisecond && timeSinceLastRequest > 0 {
rangeMs := ShortDelayMax.Milliseconds() - ShortDelayMin.Milliseconds()
randomMs := jitterRand.Int63n(rangeMs)
return ShortDelayMin + time.Duration(randomMs)*time.Millisecond
}
// Otherwise, use probability-based selection
roll := jitterRand.Float64()
var min, max time.Duration
switch {
case roll < ShortDelayProbability:
// Short delay - consecutive operations
min, max = ShortDelayMin, ShortDelayMax
case roll < ShortDelayProbability+LongDelayProbability:
// Long delay - reading/resting
min, max = LongDelayMin, LongDelayMax
default:
// Normal delay - thinking time
min, max = NormalDelayMin, NormalDelayMax
}
rangeMs := max.Milliseconds() - min.Milliseconds()
randomMs := jitterRand.Int63n(rangeMs)
return min + time.Duration(randomMs)*time.Millisecond
}
// ApplyHumanLikeDelay applies human-like delay by sleeping.
// This is a convenience function that combines HumanLikeDelay with time.Sleep.
func ApplyHumanLikeDelay() {
delay := HumanLikeDelay()
if delay > 0 {
time.Sleep(delay)
}
}
// ExponentialBackoffWithJitter calculates retry delay using exponential backoff with jitter.
// Formula: min(baseDelay * 2^attempt + jitter, maxDelay)
// This helps prevent thundering herd problem when multiple clients retry simultaneously.
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
if attempt < 0 {
attempt = 0
}
// Calculate exponential backoff: baseDelay * 2^attempt
backoff := baseDelay * time.Duration(1<<uint(attempt))
if backoff > maxDelay {
backoff = maxDelay
}
// Add ±30% jitter
return JitterDelay(backoff, JitterPercent)
}
// ShouldSkipDelay determines if delay should be skipped based on context.
// Returns true for streaming responses, WebSocket connections, etc.
// This function can be extended to check additional skip conditions.
func ShouldSkipDelay(isStreaming bool) bool {
return isStreaming
}
// ResetLastRequestTime resets the last request time tracker.
// Useful for testing or when starting a new session.
func ResetLastRequestTime() {
jitterMu.Lock()
defer jitterMu.Unlock()
lastRequestTime = time.Time{}
}

View File

@@ -0,0 +1,187 @@
package kiro
import (
"math"
"sync"
"time"
)
// TokenMetrics holds performance metrics for a single token.
type TokenMetrics struct {
SuccessRate float64 // Success rate (0.0 - 1.0)
AvgLatency float64 // Average latency in milliseconds
QuotaRemaining float64 // Remaining quota (0.0 - 1.0)
LastUsed time.Time // Last usage timestamp
FailCount int // Consecutive failure count
TotalRequests int // Total request count
successCount int // Internal: successful request count
totalLatency float64 // Internal: cumulative latency
}
// TokenScorer manages token metrics and scoring.
type TokenScorer struct {
mu sync.RWMutex
metrics map[string]*TokenMetrics
// Scoring weights
successRateWeight float64
quotaWeight float64
latencyWeight float64
lastUsedWeight float64
failPenaltyMultiplier float64
}
// NewTokenScorer creates a new TokenScorer with default weights.
func NewTokenScorer() *TokenScorer {
return &TokenScorer{
metrics: make(map[string]*TokenMetrics),
successRateWeight: 0.4,
quotaWeight: 0.25,
latencyWeight: 0.2,
lastUsedWeight: 0.15,
failPenaltyMultiplier: 0.1,
}
}
// getOrCreateMetrics returns existing metrics or creates new ones.
func (s *TokenScorer) getOrCreateMetrics(tokenKey string) *TokenMetrics {
if m, ok := s.metrics[tokenKey]; ok {
return m
}
m := &TokenMetrics{
SuccessRate: 1.0,
QuotaRemaining: 1.0,
}
s.metrics[tokenKey] = m
return m
}
// RecordRequest records the result of a request for a token.
func (s *TokenScorer) RecordRequest(tokenKey string, success bool, latency time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
m := s.getOrCreateMetrics(tokenKey)
m.TotalRequests++
m.LastUsed = time.Now()
m.totalLatency += float64(latency.Milliseconds())
if success {
m.successCount++
m.FailCount = 0
} else {
m.FailCount++
}
// Update derived metrics
if m.TotalRequests > 0 {
m.SuccessRate = float64(m.successCount) / float64(m.TotalRequests)
m.AvgLatency = m.totalLatency / float64(m.TotalRequests)
}
}
// SetQuotaRemaining updates the remaining quota for a token.
func (s *TokenScorer) SetQuotaRemaining(tokenKey string, quota float64) {
s.mu.Lock()
defer s.mu.Unlock()
m := s.getOrCreateMetrics(tokenKey)
m.QuotaRemaining = quota
}
// GetMetrics returns a copy of the metrics for a token.
func (s *TokenScorer) GetMetrics(tokenKey string) *TokenMetrics {
s.mu.RLock()
defer s.mu.RUnlock()
if m, ok := s.metrics[tokenKey]; ok {
copy := *m
return &copy
}
return nil
}
// CalculateScore computes the score for a token (higher is better).
func (s *TokenScorer) CalculateScore(tokenKey string) float64 {
s.mu.RLock()
defer s.mu.RUnlock()
m, ok := s.metrics[tokenKey]
if !ok {
return 1.0 // New tokens get a high initial score
}
// Success rate component (0-1)
successScore := m.SuccessRate
// Quota component (0-1)
quotaScore := m.QuotaRemaining
// Latency component (normalized, lower is better)
// Using exponential decay: score = e^(-latency/1000)
// 1000ms latency -> ~0.37 score, 100ms -> ~0.90 score
latencyScore := math.Exp(-m.AvgLatency / 1000.0)
if m.TotalRequests == 0 {
latencyScore = 1.0
}
// Last used component (prefer tokens not recently used)
// Score increases as time since last use increases
timeSinceUse := time.Since(m.LastUsed).Seconds()
// Normalize: 60 seconds -> ~0.63 score, 0 seconds -> 0 score
lastUsedScore := 1.0 - math.Exp(-timeSinceUse/60.0)
if m.LastUsed.IsZero() {
lastUsedScore = 1.0
}
// Calculate weighted score
score := s.successRateWeight*successScore +
s.quotaWeight*quotaScore +
s.latencyWeight*latencyScore +
s.lastUsedWeight*lastUsedScore
// Apply consecutive failure penalty
if m.FailCount > 0 {
penalty := s.failPenaltyMultiplier * float64(m.FailCount)
score = score * math.Max(0, 1.0-penalty)
}
return score
}
// SelectBestToken selects the token with the highest score.
func (s *TokenScorer) SelectBestToken(tokens []string) string {
if len(tokens) == 0 {
return ""
}
if len(tokens) == 1 {
return tokens[0]
}
bestToken := tokens[0]
bestScore := s.CalculateScore(tokens[0])
for _, token := range tokens[1:] {
score := s.CalculateScore(token)
if score > bestScore {
bestScore = score
bestToken = token
}
}
return bestToken
}
// ResetMetrics clears all metrics for a token.
func (s *TokenScorer) ResetMetrics(tokenKey string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.metrics, tokenKey)
}
// ResetAllMetrics clears all stored metrics.
func (s *TokenScorer) ResetAllMetrics() {
s.mu.Lock()
defer s.mu.Unlock()
s.metrics = make(map[string]*TokenMetrics)
}

View File

@@ -0,0 +1,301 @@
package kiro
import (
"sync"
"testing"
"time"
)
func TestNewTokenScorer(t *testing.T) {
s := NewTokenScorer()
if s == nil {
t.Fatal("expected non-nil TokenScorer")
}
if s.metrics == nil {
t.Error("expected non-nil metrics map")
}
if s.successRateWeight != 0.4 {
t.Errorf("expected successRateWeight 0.4, got %f", s.successRateWeight)
}
if s.quotaWeight != 0.25 {
t.Errorf("expected quotaWeight 0.25, got %f", s.quotaWeight)
}
}
func TestRecordRequest_Success(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m == nil {
t.Fatal("expected non-nil metrics")
}
if m.TotalRequests != 1 {
t.Errorf("expected TotalRequests 1, got %d", m.TotalRequests)
}
if m.SuccessRate != 1.0 {
t.Errorf("expected SuccessRate 1.0, got %f", m.SuccessRate)
}
if m.FailCount != 0 {
t.Errorf("expected FailCount 0, got %d", m.FailCount)
}
if m.AvgLatency != 100 {
t.Errorf("expected AvgLatency 100, got %f", m.AvgLatency)
}
}
func TestRecordRequest_Failure(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", false, 200*time.Millisecond)
m := s.GetMetrics("token1")
if m.SuccessRate != 0.0 {
t.Errorf("expected SuccessRate 0.0, got %f", m.SuccessRate)
}
if m.FailCount != 1 {
t.Errorf("expected FailCount 1, got %d", m.FailCount)
}
}
func TestRecordRequest_MixedResults(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
s.RecordRequest("token1", true, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m.TotalRequests != 4 {
t.Errorf("expected TotalRequests 4, got %d", m.TotalRequests)
}
if m.SuccessRate != 0.75 {
t.Errorf("expected SuccessRate 0.75, got %f", m.SuccessRate)
}
if m.FailCount != 0 {
t.Errorf("expected FailCount 0 (reset on success), got %d", m.FailCount)
}
}
func TestRecordRequest_ConsecutiveFailures(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m.FailCount != 3 {
t.Errorf("expected FailCount 3, got %d", m.FailCount)
}
}
func TestSetQuotaRemaining(t *testing.T) {
s := NewTokenScorer()
s.SetQuotaRemaining("token1", 0.5)
m := s.GetMetrics("token1")
if m.QuotaRemaining != 0.5 {
t.Errorf("expected QuotaRemaining 0.5, got %f", m.QuotaRemaining)
}
}
func TestGetMetrics_NonExistent(t *testing.T) {
s := NewTokenScorer()
m := s.GetMetrics("nonexistent")
if m != nil {
t.Error("expected nil metrics for non-existent token")
}
}
func TestGetMetrics_ReturnsCopy(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
m1 := s.GetMetrics("token1")
m1.TotalRequests = 999
m2 := s.GetMetrics("token1")
if m2.TotalRequests == 999 {
t.Error("GetMetrics should return a copy")
}
}
func TestCalculateScore_NewToken(t *testing.T) {
s := NewTokenScorer()
score := s.CalculateScore("newtoken")
if score != 1.0 {
t.Errorf("expected score 1.0 for new token, got %f", score)
}
}
func TestCalculateScore_PerfectToken(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 50*time.Millisecond)
s.SetQuotaRemaining("token1", 1.0)
time.Sleep(100 * time.Millisecond)
score := s.CalculateScore("token1")
if score < 0.5 || score > 1.0 {
t.Errorf("expected high score for perfect token, got %f", score)
}
}
func TestCalculateScore_FailedToken(t *testing.T) {
s := NewTokenScorer()
for i := 0; i < 5; i++ {
s.RecordRequest("token1", false, 1000*time.Millisecond)
}
s.SetQuotaRemaining("token1", 0.1)
score := s.CalculateScore("token1")
if score > 0.5 {
t.Errorf("expected low score for failed token, got %f", score)
}
}
func TestCalculateScore_FailPenalty(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
scoreNoFail := s.CalculateScore("token1")
s.RecordRequest("token1", false, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
scoreWithFail := s.CalculateScore("token1")
if scoreWithFail >= scoreNoFail {
t.Errorf("expected lower score with consecutive failures: noFail=%f, withFail=%f", scoreNoFail, scoreWithFail)
}
}
func TestSelectBestToken_Empty(t *testing.T) {
s := NewTokenScorer()
best := s.SelectBestToken([]string{})
if best != "" {
t.Errorf("expected empty string for empty tokens, got %s", best)
}
}
func TestSelectBestToken_SingleToken(t *testing.T) {
s := NewTokenScorer()
best := s.SelectBestToken([]string{"token1"})
if best != "token1" {
t.Errorf("expected token1, got %s", best)
}
}
func TestSelectBestToken_MultipleTokens(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("bad", false, 1000*time.Millisecond)
s.RecordRequest("bad", false, 1000*time.Millisecond)
s.SetQuotaRemaining("bad", 0.1)
s.RecordRequest("good", true, 50*time.Millisecond)
s.SetQuotaRemaining("good", 0.9)
time.Sleep(50 * time.Millisecond)
best := s.SelectBestToken([]string{"bad", "good"})
if best != "good" {
t.Errorf("expected good token to be selected, got %s", best)
}
}
func TestResetMetrics(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.ResetMetrics("token1")
m := s.GetMetrics("token1")
if m != nil {
t.Error("expected nil metrics after reset")
}
}
func TestResetAllMetrics(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token2", true, 100*time.Millisecond)
s.RecordRequest("token3", true, 100*time.Millisecond)
s.ResetAllMetrics()
if s.GetMetrics("token1") != nil {
t.Error("expected nil metrics for token1 after reset all")
}
if s.GetMetrics("token2") != nil {
t.Error("expected nil metrics for token2 after reset all")
}
}
func TestTokenScorer_ConcurrentAccess(t *testing.T) {
s := NewTokenScorer()
const numGoroutines = 50
const numOperations = 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
tokenKey := "token" + string(rune('a'+id%10))
for j := 0; j < numOperations; j++ {
switch j % 6 {
case 0:
s.RecordRequest(tokenKey, j%2 == 0, time.Duration(j)*time.Millisecond)
case 1:
s.SetQuotaRemaining(tokenKey, float64(j%100)/100)
case 2:
s.GetMetrics(tokenKey)
case 3:
s.CalculateScore(tokenKey)
case 4:
s.SelectBestToken([]string{tokenKey, "token_x", "token_y"})
case 5:
if j%20 == 0 {
s.ResetMetrics(tokenKey)
}
}
}
}(i)
}
wg.Wait()
}
func TestAvgLatencyCalculation(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token1", true, 200*time.Millisecond)
s.RecordRequest("token1", true, 300*time.Millisecond)
m := s.GetMetrics("token1")
if m.AvgLatency != 200 {
t.Errorf("expected AvgLatency 200, got %f", m.AvgLatency)
}
}
func TestLastUsedUpdated(t *testing.T) {
s := NewTokenScorer()
before := time.Now()
s.RecordRequest("token1", true, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m.LastUsed.Before(before) {
t.Error("expected LastUsed to be after test start time")
}
if m.LastUsed.After(time.Now()) {
t.Error("expected LastUsed to be before or equal to now")
}
}
func TestDefaultQuotaForNewToken(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m.QuotaRemaining != 1.0 {
t.Errorf("expected default QuotaRemaining 1.0, got %f", m.QuotaRemaining)
}
}

View File

@@ -227,6 +227,7 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: "", // Caller should preserve original provider
Region: "us-east-1",
}, nil
}
@@ -285,6 +286,7 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: "", // Caller should preserve original provider
Region: "us-east-1",
}, nil
}

View File

@@ -0,0 +1,982 @@
// Package kiro provides OAuth Web authentication for Kiro.
package kiro
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"html/template"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
defaultSessionExpiry = 10 * time.Minute
pollIntervalSeconds = 5
)
type authSessionStatus string
const (
statusPending authSessionStatus = "pending"
statusSuccess authSessionStatus = "success"
statusFailed authSessionStatus = "failed"
)
type webAuthSession struct {
stateID string
deviceCode string
userCode string
authURL string
verificationURI string
expiresIn int
interval int
status authSessionStatus
startedAt time.Time
completedAt time.Time
expiresAt time.Time
error string
tokenData *KiroTokenData
ssoClient *SSOOIDCClient
clientID string
clientSecret string
region string
cancelFunc context.CancelFunc
authMethod string // "google", "github", "builder-id", "idc"
startURL string // Used for IDC
codeVerifier string // Used for social auth PKCE
codeChallenge string // Used for social auth PKCE
}
type OAuthWebHandler struct {
cfg *config.Config
sessions map[string]*webAuthSession
mu sync.RWMutex
onTokenObtained func(*KiroTokenData)
}
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
return &OAuthWebHandler{
cfg: cfg,
sessions: make(map[string]*webAuthSession),
}
}
func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) {
h.onTokenObtained = callback
}
func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) {
oauth := router.Group("/v0/oauth/kiro")
{
oauth.GET("", h.handleSelect)
oauth.GET("/start", h.handleStart)
oauth.GET("/callback", h.handleCallback)
oauth.GET("/social/callback", h.handleSocialCallback)
oauth.GET("/status", h.handleStatus)
oauth.POST("/import", h.handleImportToken)
oauth.POST("/refresh", h.handleManualRefresh)
}
}
func generateStateID() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func (h *OAuthWebHandler) handleSelect(c *gin.Context) {
h.renderSelectPage(c)
}
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
method := c.Query("method")
if method == "" {
c.Redirect(http.StatusFound, "/v0/oauth/kiro")
return
}
switch method {
case "google", "github":
// Google/GitHub social login is not supported for third-party apps
// due to AWS Cognito redirect_uri restrictions
h.renderError(c, "Google/GitHub login is not available for third-party applications. Please use AWS Builder ID or import your token from Kiro IDE.")
case "builder-id":
h.startBuilderIDAuth(c)
case "idc":
h.startIDCAuth(c)
default:
h.renderError(c, fmt.Sprintf("Unknown authentication method: %s", method))
}
}
func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) {
stateID, err := generateStateID()
if err != nil {
h.renderError(c, "Failed to generate state parameter")
return
}
codeVerifier, codeChallenge, err := generatePKCE()
if err != nil {
h.renderError(c, "Failed to generate PKCE parameters")
return
}
socialClient := NewSocialAuthClient(h.cfg)
var provider string
if method == "google" {
provider = string(ProviderGoogle)
} else {
provider = string(ProviderGitHub)
}
redirectURI := h.getSocialCallbackURL(c)
authURL := socialClient.buildLoginURL(provider, redirectURI, codeChallenge, stateID)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
session := &webAuthSession{
stateID: stateID,
authMethod: method,
authURL: authURL,
status: statusPending,
startedAt: time.Now(),
expiresIn: 600,
codeVerifier: codeVerifier,
codeChallenge: codeChallenge,
region: "us-east-1",
cancelFunc: cancel,
}
h.mu.Lock()
h.sessions[stateID] = session
h.mu.Unlock()
go func() {
<-ctx.Done()
h.mu.Lock()
if session.status == statusPending {
session.status = statusFailed
session.error = "Authentication timed out"
}
h.mu.Unlock()
}()
c.Redirect(http.StatusFound, authURL)
}
func (h *OAuthWebHandler) getSocialCallbackURL(c *gin.Context) string {
scheme := "http"
if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" {
scheme = "https"
}
return fmt.Sprintf("%s://%s/v0/oauth/kiro/social/callback", scheme, c.Request.Host)
}
func (h *OAuthWebHandler) startBuilderIDAuth(c *gin.Context) {
stateID, err := generateStateID()
if err != nil {
h.renderError(c, "Failed to generate state parameter")
return
}
region := defaultIDCRegion
startURL := builderIDStartURL
ssoClient := NewSSOOIDCClient(h.cfg)
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
if err != nil {
log.Errorf("OAuth Web: failed to register client: %v", err)
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
return
}
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
c.Request.Context(),
regResp.ClientID,
regResp.ClientSecret,
startURL,
region,
)
if err != nil {
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
return
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
session := &webAuthSession{
stateID: stateID,
deviceCode: authResp.DeviceCode,
userCode: authResp.UserCode,
authURL: authResp.VerificationURIComplete,
verificationURI: authResp.VerificationURI,
expiresIn: authResp.ExpiresIn,
interval: authResp.Interval,
status: statusPending,
startedAt: time.Now(),
ssoClient: ssoClient,
clientID: regResp.ClientID,
clientSecret: regResp.ClientSecret,
region: region,
authMethod: "builder-id",
startURL: startURL,
cancelFunc: cancel,
}
h.mu.Lock()
h.sessions[stateID] = session
h.mu.Unlock()
go h.pollForToken(ctx, session)
h.renderStartPage(c, session)
}
func (h *OAuthWebHandler) startIDCAuth(c *gin.Context) {
startURL := c.Query("startUrl")
region := c.Query("region")
if startURL == "" {
h.renderError(c, "Missing startUrl parameter for IDC authentication")
return
}
if region == "" {
region = defaultIDCRegion
}
stateID, err := generateStateID()
if err != nil {
h.renderError(c, "Failed to generate state parameter")
return
}
ssoClient := NewSSOOIDCClient(h.cfg)
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
if err != nil {
log.Errorf("OAuth Web: failed to register client: %v", err)
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
return
}
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
c.Request.Context(),
regResp.ClientID,
regResp.ClientSecret,
startURL,
region,
)
if err != nil {
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
return
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
session := &webAuthSession{
stateID: stateID,
deviceCode: authResp.DeviceCode,
userCode: authResp.UserCode,
authURL: authResp.VerificationURIComplete,
verificationURI: authResp.VerificationURI,
expiresIn: authResp.ExpiresIn,
interval: authResp.Interval,
status: statusPending,
startedAt: time.Now(),
ssoClient: ssoClient,
clientID: regResp.ClientID,
clientSecret: regResp.ClientSecret,
region: region,
authMethod: "idc",
startURL: startURL,
cancelFunc: cancel,
}
h.mu.Lock()
h.sessions[stateID] = session
h.mu.Unlock()
go h.pollForToken(ctx, session)
h.renderStartPage(c, session)
}
func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) {
defer session.cancelFunc()
interval := time.Duration(session.interval) * time.Second
if interval < time.Duration(pollIntervalSeconds)*time.Second {
interval = time.Duration(pollIntervalSeconds) * time.Second
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
h.mu.Lock()
if session.status == statusPending {
session.status = statusFailed
session.error = "Authentication timed out"
}
h.mu.Unlock()
return
case <-ticker.C:
tokenResp, err := h.ssoClient(session).CreateTokenWithRegion(
ctx,
session.clientID,
session.clientSecret,
session.deviceCode,
session.region,
)
if err != nil {
errStr := err.Error()
if errStr == ErrAuthorizationPending.Error() {
continue
}
if errStr == ErrSlowDown.Error() {
interval += 5 * time.Second
ticker.Reset(interval)
continue
}
h.mu.Lock()
session.status = statusFailed
session.error = errStr
session.completedAt = time.Now()
h.mu.Unlock()
log.Errorf("OAuth Web: token polling failed: %v", err)
return
}
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken)
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken)
tokenData := &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: profileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: session.authMethod,
Provider: "AWS",
ClientID: session.clientID,
ClientSecret: session.clientSecret,
Email: email,
Region: session.region,
StartURL: session.startURL,
}
h.mu.Lock()
session.status = statusSuccess
session.completedAt = time.Now()
session.expiresAt = expiresAt
session.tokenData = tokenData
h.mu.Unlock()
if h.onTokenObtained != nil {
h.onTokenObtained(tokenData)
}
// Save token to file
h.saveTokenToFile(tokenData)
log.Infof("OAuth Web: authentication successful for %s", email)
return
}
}
}
// saveTokenToFile saves the token data to the auth directory
func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
// Get auth directory from config or use default
authDir := ""
if h.cfg != nil && h.cfg.AuthDir != "" {
var err error
authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
if err != nil {
log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
}
}
// Fall back to default location
if authDir == "" {
home, err := os.UserHomeDir()
if err != nil {
log.Errorf("OAuth Web: failed to get home directory: %v", err)
return
}
authDir = filepath.Join(home, ".cli-proxy-api")
}
// Create directory if not exists
if err := os.MkdirAll(authDir, 0700); err != nil {
log.Errorf("OAuth Web: failed to create auth directory: %v", err)
return
}
// Generate filename based on auth method
// Format: kiro-{authMethod}.json or kiro-{authMethod}-{email}.json
fileName := fmt.Sprintf("kiro-%s.json", tokenData.AuthMethod)
if tokenData.Email != "" {
// Sanitize email for filename (replace @ and . with -)
sanitizedEmail := tokenData.Email
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, "@", "-")
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
fileName = fmt.Sprintf("kiro-%s-%s.json", tokenData.AuthMethod, sanitizedEmail)
}
authFilePath := filepath.Join(authDir, fileName)
// Convert to storage format and save
storage := &KiroTokenStorage{
Type: "kiro",
AccessToken: tokenData.AccessToken,
RefreshToken: tokenData.RefreshToken,
ProfileArn: tokenData.ProfileArn,
ExpiresAt: tokenData.ExpiresAt,
AuthMethod: tokenData.AuthMethod,
Provider: tokenData.Provider,
LastRefresh: time.Now().Format(time.RFC3339),
ClientID: tokenData.ClientID,
ClientSecret: tokenData.ClientSecret,
Region: tokenData.Region,
StartURL: tokenData.StartURL,
Email: tokenData.Email,
}
if err := storage.SaveTokenToFile(authFilePath); err != nil {
log.Errorf("OAuth Web: failed to save token to file: %v", err)
return
}
log.Infof("OAuth Web: token saved to %s", authFilePath)
}
func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient {
return session.ssoClient
}
func (h *OAuthWebHandler) handleCallback(c *gin.Context) {
stateID := c.Query("state")
errParam := c.Query("error")
if errParam != "" {
h.renderError(c, errParam)
return
}
if stateID == "" {
h.renderError(c, "Missing state parameter")
return
}
h.mu.RLock()
session, exists := h.sessions[stateID]
h.mu.RUnlock()
if !exists {
h.renderError(c, "Invalid or expired session")
return
}
if session.status == statusSuccess {
h.renderSuccess(c, session)
} else if session.status == statusFailed {
h.renderError(c, session.error)
} else {
c.Redirect(http.StatusFound, "/v0/oauth/kiro/start")
}
}
func (h *OAuthWebHandler) handleSocialCallback(c *gin.Context) {
stateID := c.Query("state")
code := c.Query("code")
errParam := c.Query("error")
if errParam != "" {
h.renderError(c, errParam)
return
}
if stateID == "" {
h.renderError(c, "Missing state parameter")
return
}
if code == "" {
h.renderError(c, "Missing authorization code")
return
}
h.mu.RLock()
session, exists := h.sessions[stateID]
h.mu.RUnlock()
if !exists {
h.renderError(c, "Invalid or expired session")
return
}
if session.authMethod != "google" && session.authMethod != "github" {
h.renderError(c, "Invalid session type for social callback")
return
}
socialClient := NewSocialAuthClient(h.cfg)
redirectURI := h.getSocialCallbackURL(c)
tokenReq := &CreateTokenRequest{
Code: code,
CodeVerifier: session.codeVerifier,
RedirectURI: redirectURI,
}
tokenResp, err := socialClient.CreateToken(c.Request.Context(), tokenReq)
if err != nil {
log.Errorf("OAuth Web: social token exchange failed: %v", err)
h.mu.Lock()
session.status = statusFailed
session.error = fmt.Sprintf("Token exchange failed: %v", err)
session.completedAt = time.Now()
h.mu.Unlock()
h.renderError(c, session.error)
return
}
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
email := ExtractEmailFromJWT(tokenResp.AccessToken)
var provider string
if session.authMethod == "google" {
provider = string(ProviderGoogle)
} else {
provider = string(ProviderGitHub)
}
tokenData := &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: session.authMethod,
Provider: provider,
Email: email,
Region: "us-east-1",
}
h.mu.Lock()
session.status = statusSuccess
session.completedAt = time.Now()
session.expiresAt = expiresAt
session.tokenData = tokenData
h.mu.Unlock()
if session.cancelFunc != nil {
session.cancelFunc()
}
if h.onTokenObtained != nil {
h.onTokenObtained(tokenData)
}
// Save token to file
h.saveTokenToFile(tokenData)
log.Infof("OAuth Web: social authentication successful for %s via %s", email, provider)
h.renderSuccess(c, session)
}
func (h *OAuthWebHandler) handleStatus(c *gin.Context) {
stateID := c.Query("state")
if stateID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"})
return
}
h.mu.RLock()
session, exists := h.sessions[stateID]
h.mu.RUnlock()
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
return
}
response := gin.H{
"status": string(session.status),
}
switch session.status {
case statusPending:
elapsed := time.Since(session.startedAt).Seconds()
remaining := float64(session.expiresIn) - elapsed
if remaining < 0 {
remaining = 0
}
response["remaining_seconds"] = int(remaining)
case statusSuccess:
response["completed_at"] = session.completedAt.Format(time.RFC3339)
response["expires_at"] = session.expiresAt.Format(time.RFC3339)
case statusFailed:
response["error"] = session.error
response["failed_at"] = session.completedAt.Format(time.RFC3339)
}
c.JSON(http.StatusOK, response)
}
func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) {
tmpl, err := template.New("start").Parse(oauthWebStartPageHTML)
if err != nil {
log.Errorf("OAuth Web: failed to parse template: %v", err)
c.String(http.StatusInternalServerError, "Template error")
return
}
data := map[string]interface{}{
"AuthURL": session.authURL,
"UserCode": session.userCode,
"ExpiresIn": session.expiresIn,
"StateID": session.stateID,
}
c.Header("Content-Type", "text/html; charset=utf-8")
if err := tmpl.Execute(c.Writer, data); err != nil {
log.Errorf("OAuth Web: failed to render template: %v", err)
}
}
func (h *OAuthWebHandler) renderSelectPage(c *gin.Context) {
tmpl, err := template.New("select").Parse(oauthWebSelectPageHTML)
if err != nil {
log.Errorf("OAuth Web: failed to parse select template: %v", err)
c.String(http.StatusInternalServerError, "Template error")
return
}
c.Header("Content-Type", "text/html; charset=utf-8")
if err := tmpl.Execute(c.Writer, nil); err != nil {
log.Errorf("OAuth Web: failed to render select template: %v", err)
}
}
func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) {
tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML)
if err != nil {
log.Errorf("OAuth Web: failed to parse error template: %v", err)
c.String(http.StatusInternalServerError, "Template error")
return
}
data := map[string]interface{}{
"Error": errMsg,
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.Status(http.StatusBadRequest)
if err := tmpl.Execute(c.Writer, data); err != nil {
log.Errorf("OAuth Web: failed to render error template: %v", err)
}
}
func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) {
tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML)
if err != nil {
log.Errorf("OAuth Web: failed to parse success template: %v", err)
c.String(http.StatusInternalServerError, "Template error")
return
}
data := map[string]interface{}{
"ExpiresAt": session.expiresAt.Format(time.RFC3339),
}
c.Header("Content-Type", "text/html; charset=utf-8")
if err := tmpl.Execute(c.Writer, data); err != nil {
log.Errorf("OAuth Web: failed to render success template: %v", err)
}
}
func (h *OAuthWebHandler) CleanupExpiredSessions() {
h.mu.Lock()
defer h.mu.Unlock()
now := time.Now()
for id, session := range h.sessions {
if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute {
delete(h.sessions, id)
} else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry {
session.cancelFunc()
delete(h.sessions, id)
}
}
}
func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) {
h.mu.RLock()
defer h.mu.RUnlock()
session, exists := h.sessions[stateID]
return session, exists
}
// ImportTokenRequest represents the request body for token import
type ImportTokenRequest struct {
RefreshToken string `json:"refreshToken"`
}
// handleImportToken handles manual refresh token import from Kiro IDE
func (h *OAuthWebHandler) handleImportToken(c *gin.Context) {
var req ImportTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": "Invalid request body",
})
return
}
refreshToken := strings.TrimSpace(req.RefreshToken)
if refreshToken == "" {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": "Refresh token is required",
})
return
}
// Validate token format
if !strings.HasPrefix(refreshToken, "aorAAAAAG") {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": "Invalid token format. Token should start with aorAAAAAG...",
})
return
}
// Create social auth client to refresh and validate the token
socialClient := NewSocialAuthClient(h.cfg)
// Refresh the token to validate it and get access token
tokenData, err := socialClient.RefreshSocialToken(c.Request.Context(), refreshToken)
if err != nil {
log.Errorf("OAuth Web: token refresh failed during import: %v", err)
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": fmt.Sprintf("Token validation failed: %v", err),
})
return
}
// Set the original refresh token (the refreshed one might be empty)
if tokenData.RefreshToken == "" {
tokenData.RefreshToken = refreshToken
}
tokenData.AuthMethod = "social"
tokenData.Provider = "imported"
// Notify callback if set
if h.onTokenObtained != nil {
h.onTokenObtained(tokenData)
}
// Save token to file
h.saveTokenToFile(tokenData)
// Generate filename for response
fileName := fmt.Sprintf("kiro-%s.json", tokenData.AuthMethod)
if tokenData.Email != "" {
sanitizedEmail := strings.ReplaceAll(tokenData.Email, "@", "-")
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
fileName = fmt.Sprintf("kiro-%s-%s.json", tokenData.AuthMethod, sanitizedEmail)
}
log.Infof("OAuth Web: token imported successfully")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Token imported successfully",
"fileName": fileName,
})
}
// handleManualRefresh handles manual token refresh requests from the web UI.
// This allows users to trigger a token refresh when needed, without waiting
// for the automatic 30-second check and 20-minute-before-expiry refresh cycle.
// Uses the same refresh logic as kiro_executor.Refresh for consistency.
func (h *OAuthWebHandler) handleManualRefresh(c *gin.Context) {
authDir := ""
if h.cfg != nil && h.cfg.AuthDir != "" {
var err error
authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
if err != nil {
log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
}
}
if authDir == "" {
home, err := os.UserHomeDir()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"error": "Failed to get home directory",
})
return
}
authDir = filepath.Join(home, ".cli-proxy-api")
}
// Find all kiro token files in the auth directory
files, err := os.ReadDir(authDir)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"error": fmt.Sprintf("Failed to read auth directory: %v", err),
})
return
}
var refreshedCount int
var errors []string
for _, file := range files {
if file.IsDir() {
continue
}
name := file.Name()
if !strings.HasPrefix(name, "kiro-") || !strings.HasSuffix(name, ".json") {
continue
}
filePath := filepath.Join(authDir, name)
data, err := os.ReadFile(filePath)
if err != nil {
errors = append(errors, fmt.Sprintf("%s: read error - %v", name, err))
continue
}
var storage KiroTokenStorage
if err := json.Unmarshal(data, &storage); err != nil {
errors = append(errors, fmt.Sprintf("%s: parse error - %v", name, err))
continue
}
if storage.RefreshToken == "" {
errors = append(errors, fmt.Sprintf("%s: no refresh token", name))
continue
}
// Refresh token using the same logic as kiro_executor.Refresh
tokenData, err := h.refreshTokenData(c.Request.Context(), &storage)
if err != nil {
errors = append(errors, fmt.Sprintf("%s: refresh failed - %v", name, err))
continue
}
// Update storage with new token data
storage.AccessToken = tokenData.AccessToken
if tokenData.RefreshToken != "" {
storage.RefreshToken = tokenData.RefreshToken
}
storage.ExpiresAt = tokenData.ExpiresAt
storage.LastRefresh = time.Now().Format(time.RFC3339)
if tokenData.ProfileArn != "" {
storage.ProfileArn = tokenData.ProfileArn
}
// Write updated token back to file
updatedData, err := json.MarshalIndent(storage, "", " ")
if err != nil {
errors = append(errors, fmt.Sprintf("%s: marshal error - %v", name, err))
continue
}
tmpFile := filePath + ".tmp"
if err := os.WriteFile(tmpFile, updatedData, 0600); err != nil {
errors = append(errors, fmt.Sprintf("%s: write error - %v", name, err))
continue
}
if err := os.Rename(tmpFile, filePath); err != nil {
errors = append(errors, fmt.Sprintf("%s: rename error - %v", name, err))
continue
}
log.Infof("OAuth Web: manually refreshed token in %s, expires at %s", name, tokenData.ExpiresAt)
refreshedCount++
// Notify callback if set
if h.onTokenObtained != nil {
h.onTokenObtained(tokenData)
}
}
if refreshedCount == 0 && len(errors) > 0 {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": fmt.Sprintf("All refresh attempts failed: %v", errors),
})
return
}
response := gin.H{
"success": true,
"message": fmt.Sprintf("Refreshed %d token(s)", refreshedCount),
"refreshedCount": refreshedCount,
}
if len(errors) > 0 {
response["warnings"] = errors
}
c.JSON(http.StatusOK, response)
}
// refreshTokenData refreshes a token using the appropriate method based on auth type.
// This mirrors the logic in kiro_executor.Refresh for consistency.
func (h *OAuthWebHandler) refreshTokenData(ctx context.Context, storage *KiroTokenStorage) (*KiroTokenData, error) {
ssoClient := NewSSOOIDCClient(h.cfg)
switch {
case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "idc" && storage.Region != "":
// IDC refresh with region-specific endpoint
log.Debugf("OAuth Web: using SSO OIDC refresh for IDC (region=%s)", storage.Region)
return ssoClient.RefreshTokenWithRegion(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken, storage.Region, storage.StartURL)
case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "builder-id":
// Builder ID refresh with default endpoint
log.Debugf("OAuth Web: using SSO OIDC refresh for AWS Builder ID")
return ssoClient.RefreshToken(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken)
default:
// Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub)
log.Debugf("OAuth Web: using Kiro OAuth refresh endpoint")
oauth := NewKiroOAuth(h.cfg)
return oauth.RefreshToken(ctx, storage.RefreshToken)
}
}

View File

@@ -0,0 +1,779 @@
// Package kiro provides OAuth Web authentication templates.
package kiro
const (
oauthWebStartPageHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AWS SSO Authentication</title>
<style>
* { box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
margin: 0;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
justify-content: center;
align-items: center;
}
.container {
max-width: 500px;
width: 100%;
background: #fff;
padding: 40px;
border-radius: 12px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
}
h1 {
margin: 0 0 10px;
color: #333;
font-size: 24px;
text-align: center;
}
.subtitle {
text-align: center;
color: #666;
margin-bottom: 30px;
}
.step {
background: #f8f9fa;
padding: 20px;
border-radius: 8px;
margin-bottom: 15px;
}
.step-title {
display: flex;
align-items: center;
font-weight: 600;
color: #333;
margin-bottom: 10px;
}
.step-number {
width: 28px;
height: 28px;
background: #667eea;
color: white;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
font-size: 14px;
margin-right: 12px;
}
.user-code {
background: #e7f3ff;
border: 2px dashed #2196F3;
border-radius: 8px;
padding: 20px;
text-align: center;
margin-top: 10px;
}
.user-code-label {
font-size: 12px;
color: #666;
text-transform: uppercase;
letter-spacing: 1px;
margin-bottom: 8px;
}
.user-code-value {
font-size: 32px;
font-weight: bold;
font-family: monospace;
color: #2196F3;
letter-spacing: 4px;
}
.auth-btn {
display: block;
width: 100%;
padding: 15px;
background: #667eea;
color: white;
text-align: center;
text-decoration: none;
border-radius: 8px;
font-weight: 600;
font-size: 16px;
transition: all 0.3s;
border: none;
cursor: pointer;
margin-top: 20px;
}
.auth-btn:hover {
background: #5568d3;
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
}
.status {
margin-top: 30px;
padding: 20px;
background: #f8f9fa;
border-radius: 8px;
text-align: center;
}
.status-pending { border-left: 4px solid #ffc107; }
.status-success { border-left: 4px solid #28a745; }
.status-failed { border-left: 4px solid #dc3545; }
.spinner {
border: 3px solid #f3f3f3;
border-top: 3px solid #667eea;
border-radius: 50%;
width: 40px;
height: 40px;
animation: spin 1s linear infinite;
margin: 0 auto 15px;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.timer {
font-size: 24px;
font-weight: bold;
color: #667eea;
margin: 10px 0;
}
.timer.warning { color: #ffc107; }
.timer.danger { color: #dc3545; }
.status-message { color: #666; line-height: 1.6; }
.success-icon, .error-icon { font-size: 48px; margin-bottom: 15px; }
.info-box {
background: #e7f3ff;
border-left: 4px solid #2196F3;
padding: 15px;
margin-top: 20px;
border-radius: 4px;
font-size: 14px;
color: #666;
}
</style>
</head>
<body>
<div class="container">
<h1>🔐 AWS SSO Authentication</h1>
<p class="subtitle">Follow the steps below to complete authentication</p>
<div class="step">
<div class="step-title">
<span class="step-number">1</span>
Click the button below to open the authorization page
</div>
<a href="{{.AuthURL}}" target="_blank" class="auth-btn" id="authBtn">
🚀 Open Authorization Page
</a>
</div>
<div class="step">
<div class="step-title">
<span class="step-number">2</span>
Enter the verification code below
</div>
<div class="user-code">
<div class="user-code-label">Verification Code</div>
<div class="user-code-value">{{.UserCode}}</div>
</div>
</div>
<div class="step">
<div class="step-title">
<span class="step-number">3</span>
Complete AWS SSO login
</div>
<p style="color: #666; font-size: 14px; margin-top: 10px;">
Use your AWS SSO account to login and authorize
</p>
</div>
<div class="status status-pending" id="statusBox">
<div class="spinner" id="spinner"></div>
<div class="timer" id="timer">{{.ExpiresIn}}s</div>
<div class="status-message" id="statusMessage">
Waiting for authorization...
</div>
</div>
<div class="info-box">
💡 <strong>Tip:</strong> The authorization page will open in a new tab. This page will automatically update once authorization is complete.
</div>
</div>
<script>
let pollInterval;
let timerInterval;
let remainingSeconds = {{.ExpiresIn}};
const stateID = "{{.StateID}}";
setTimeout(() => {
document.getElementById('authBtn').click();
}, 500);
function pollStatus() {
fetch('/v0/oauth/kiro/status?state=' + stateID)
.then(response => response.json())
.then(data => {
console.log('Status:', data);
if (data.status === 'success') {
clearInterval(pollInterval);
clearInterval(timerInterval);
showSuccess(data);
} else if (data.status === 'failed') {
clearInterval(pollInterval);
clearInterval(timerInterval);
showError(data);
} else {
remainingSeconds = data.remaining_seconds || 0;
}
})
.catch(error => {
console.error('Poll error:', error);
});
}
function updateTimer() {
const timerEl = document.getElementById('timer');
const minutes = Math.floor(remainingSeconds / 60);
const seconds = remainingSeconds % 60;
timerEl.textContent = minutes + ':' + seconds.toString().padStart(2, '0');
if (remainingSeconds < 60) {
timerEl.className = 'timer danger';
} else if (remainingSeconds < 180) {
timerEl.className = 'timer warning';
} else {
timerEl.className = 'timer';
}
remainingSeconds--;
if (remainingSeconds < 0) {
clearInterval(timerInterval);
clearInterval(pollInterval);
showError({ error: 'Authentication timed out. Please refresh and try again.' });
}
}
function showSuccess(data) {
const statusBox = document.getElementById('statusBox');
statusBox.className = 'status status-success';
statusBox.innerHTML = '<div class="success-icon">✅</div>' +
'<div class="status-message">' +
'<strong>Authentication Successful!</strong><br>' +
'Token expires: ' + new Date(data.expires_at).toLocaleString() +
'</div>';
}
function showError(data) {
const statusBox = document.getElementById('statusBox');
statusBox.className = 'status status-failed';
statusBox.innerHTML = '<div class="error-icon">❌</div>' +
'<div class="status-message">' +
'<strong>Authentication Failed</strong><br>' +
(data.error || 'Unknown error') +
'</div>' +
'<button class="auth-btn" onclick="location.reload()" style="margin-top: 15px;">' +
'🔄 Retry' +
'</button>';
}
pollInterval = setInterval(pollStatus, 3000);
timerInterval = setInterval(updateTimer, 1000);
pollStatus();
</script>
</body>
</html>`
oauthWebErrorPageHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authentication Failed</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
max-width: 600px;
margin: 50px auto;
padding: 20px;
background: #f5f5f5;
}
.error {
background: #fff;
padding: 30px;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
border-left: 4px solid #dc3545;
}
h1 { color: #dc3545; margin-top: 0; }
.error-message { color: #666; line-height: 1.6; }
.retry-btn {
display: inline-block;
margin-top: 20px;
padding: 10px 20px;
background: #007bff;
color: white;
text-decoration: none;
border-radius: 4px;
}
.retry-btn:hover { background: #0056b3; }
</style>
</head>
<body>
<div class="error">
<h1>❌ Authentication Failed</h1>
<div class="error-message">
<p><strong>Error:</strong></p>
<p>{{.Error}}</p>
</div>
<a href="/v0/oauth/kiro/start" class="retry-btn">🔄 Retry</a>
</div>
</body>
</html>`
oauthWebSuccessPageHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authentication Successful</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
max-width: 600px;
margin: 50px auto;
padding: 20px;
background: #f5f5f5;
}
.success {
background: #fff;
padding: 30px;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
border-left: 4px solid #28a745;
text-align: center;
}
h1 { color: #28a745; margin-top: 0; }
.success-message { color: #666; line-height: 1.6; }
.icon { font-size: 48px; margin-bottom: 15px; }
.expires { font-size: 14px; color: #999; margin-top: 15px; }
</style>
</head>
<body>
<div class="success">
<div class="icon">✅</div>
<h1>Authentication Successful!</h1>
<div class="success-message">
<p>You can close this window.</p>
</div>
<div class="expires">Token expires: {{.ExpiresAt}}</div>
</div>
</body>
</html>`
oauthWebSelectPageHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Select Authentication Method</title>
<style>
* { box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
margin: 0;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
justify-content: center;
align-items: center;
}
.container {
max-width: 500px;
width: 100%;
background: #fff;
padding: 40px;
border-radius: 12px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
}
h1 {
margin: 0 0 10px;
color: #333;
font-size: 24px;
text-align: center;
}
.subtitle {
text-align: center;
color: #666;
margin-bottom: 30px;
}
.auth-methods {
display: flex;
flex-direction: column;
gap: 15px;
}
.auth-btn {
display: flex;
align-items: center;
width: 100%;
padding: 15px 20px;
background: #667eea;
color: white;
text-decoration: none;
border-radius: 8px;
font-weight: 600;
font-size: 16px;
transition: all 0.3s;
border: none;
cursor: pointer;
}
.auth-btn:hover {
background: #5568d3;
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
}
.auth-btn .icon {
font-size: 24px;
margin-right: 15px;
width: 32px;
text-align: center;
}
.auth-btn.google { background: #4285F4; }
.auth-btn.google:hover { background: #3367D6; }
.auth-btn.github { background: #24292e; }
.auth-btn.github:hover { background: #1a1e22; }
.auth-btn.aws { background: #FF9900; }
.auth-btn.aws:hover { background: #E68A00; }
.auth-btn.idc { background: #232F3E; }
.auth-btn.idc:hover { background: #1a242f; }
.idc-form {
background: #f8f9fa;
padding: 20px;
border-radius: 8px;
margin-top: 15px;
display: none;
}
.idc-form.show {
display: block;
}
.form-group {
margin-bottom: 15px;
}
.form-group label {
display: block;
font-weight: 600;
color: #333;
margin-bottom: 8px;
font-size: 14px;
}
.form-group input {
width: 100%;
padding: 12px;
border: 2px solid #e0e0e0;
border-radius: 6px;
font-size: 14px;
transition: border-color 0.3s;
}
.form-group input:focus {
outline: none;
border-color: #667eea;
}
.form-group .hint {
font-size: 12px;
color: #999;
margin-top: 5px;
}
.submit-btn {
display: block;
width: 100%;
padding: 15px;
background: #232F3E;
color: white;
text-align: center;
text-decoration: none;
border-radius: 8px;
font-weight: 600;
font-size: 16px;
transition: all 0.3s;
border: none;
cursor: pointer;
}
.submit-btn:hover {
background: #1a242f;
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(35, 47, 62, 0.4);
}
.divider {
display: flex;
align-items: center;
margin: 20px 0;
}
.divider::before,
.divider::after {
content: "";
flex: 1;
border-bottom: 1px solid #e0e0e0;
}
.divider span {
padding: 0 15px;
color: #999;
font-size: 14px;
}
.info-box {
background: #e7f3ff;
border-left: 4px solid #2196F3;
padding: 15px;
margin-top: 20px;
border-radius: 4px;
font-size: 14px;
color: #666;
}
.warning-box {
background: #fff3cd;
border-left: 4px solid #ffc107;
padding: 15px;
margin-top: 20px;
border-radius: 4px;
font-size: 14px;
color: #856404;
}
.auth-btn.manual { background: #6c757d; }
.auth-btn.manual:hover { background: #5a6268; }
.auth-btn.refresh { background: #17a2b8; }
.auth-btn.refresh:hover { background: #138496; }
.auth-btn.refresh:disabled { background: #7fb3bd; cursor: not-allowed; }
.manual-form {
background: #f8f9fa;
padding: 20px;
border-radius: 8px;
margin-top: 15px;
display: none;
}
.manual-form.show {
display: block;
}
.form-group textarea {
width: 100%;
padding: 12px;
border: 2px solid #e0e0e0;
border-radius: 6px;
font-size: 14px;
font-family: monospace;
transition: border-color 0.3s;
resize: vertical;
min-height: 80px;
}
.form-group textarea:focus {
outline: none;
border-color: #667eea;
}
.status-message {
padding: 15px;
border-radius: 6px;
margin-top: 15px;
display: none;
}
.status-message.success {
background: #d4edda;
color: #155724;
display: block;
}
.status-message.error {
background: #f8d7da;
color: #721c24;
display: block;
}
</style>
</head>
<body>
<div class="container">
<h1>🔐 Select Authentication Method</h1>
<p class="subtitle">Choose how you want to authenticate with Kiro</p>
<div class="auth-methods">
<a href="/v0/oauth/kiro/start?method=builder-id" class="auth-btn aws">
<span class="icon">🔶</span>
AWS Builder ID (Recommended)
</a>
<button type="button" class="auth-btn idc" onclick="toggleIdcForm()">
<span class="icon">🏢</span>
AWS Identity Center (IDC)
</button>
<div class="divider"><span>or</span></div>
<button type="button" class="auth-btn manual" onclick="toggleManualForm()">
<span class="icon">📋</span>
Import RefreshToken from Kiro IDE
</button>
<button type="button" class="auth-btn refresh" onclick="manualRefresh()" id="refreshBtn">
<span class="icon">🔄</span>
Manual Refresh All Tokens
</button>
<div class="status-message" id="refreshStatus"></div>
</div>
<div class="idc-form" id="idcForm">
<form action="/v0/oauth/kiro/start" method="get">
<input type="hidden" name="method" value="idc">
<div class="form-group">
<label for="startUrl">Start URL</label>
<input type="url" id="startUrl" name="startUrl" placeholder="https://your-org.awsapps.com/start" required>
<div class="hint">Your AWS Identity Center Start URL</div>
</div>
<div class="form-group">
<label for="region">Region</label>
<input type="text" id="region" name="region" value="us-east-1" placeholder="us-east-1">
<div class="hint">AWS Region for your Identity Center</div>
</div>
<button type="submit" class="submit-btn">
🚀 Continue with IDC
</button>
</form>
</div>
<div class="manual-form" id="manualForm">
<form id="importForm" onsubmit="submitImport(event)">
<div class="form-group">
<label for="refreshToken">Refresh Token</label>
<textarea id="refreshToken" name="refreshToken" placeholder="Paste your refreshToken here (starts with aorAAAAAG...)" required></textarea>
<div class="hint">Copy from Kiro IDE: ~/.kiro/kiro-auth-token.json → refreshToken field</div>
</div>
<button type="submit" class="submit-btn" id="importBtn">
📥 Import Token
</button>
<div class="status-message" id="importStatus"></div>
</form>
</div>
<div class="warning-box">
⚠️ <strong>Note:</strong> Google and GitHub login are not available for third-party applications due to AWS Cognito restrictions. Please use AWS Builder ID or import your token from Kiro IDE.
</div>
<div class="info-box">
💡 <strong>How to get RefreshToken:</strong><br>
1. Open Kiro IDE and login with Google/GitHub<br>
2. Find the token file: <code>~/.kiro/kiro-auth-token.json</code><br>
3. Copy the <code>refreshToken</code> value and paste it above
</div>
</div>
<script>
function toggleIdcForm() {
const idcForm = document.getElementById('idcForm');
const manualForm = document.getElementById('manualForm');
manualForm.classList.remove('show');
idcForm.classList.toggle('show');
if (idcForm.classList.contains('show')) {
document.getElementById('startUrl').focus();
}
}
function toggleManualForm() {
const idcForm = document.getElementById('idcForm');
const manualForm = document.getElementById('manualForm');
idcForm.classList.remove('show');
manualForm.classList.toggle('show');
if (manualForm.classList.contains('show')) {
document.getElementById('refreshToken').focus();
}
}
async function submitImport(event) {
event.preventDefault();
const refreshToken = document.getElementById('refreshToken').value.trim();
const statusEl = document.getElementById('importStatus');
const btn = document.getElementById('importBtn');
if (!refreshToken) {
statusEl.className = 'status-message error';
statusEl.textContent = 'Please enter a refresh token';
return;
}
if (!refreshToken.startsWith('aorAAAAAG')) {
statusEl.className = 'status-message error';
statusEl.textContent = 'Invalid token format. Token should start with aorAAAAAG...';
return;
}
btn.disabled = true;
btn.textContent = '⏳ Importing...';
statusEl.className = 'status-message';
statusEl.style.display = 'none';
try {
const response = await fetch('/v0/oauth/kiro/import', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ refreshToken: refreshToken })
});
const data = await response.json();
if (response.ok && data.success) {
statusEl.className = 'status-message success';
statusEl.textContent = '✅ Token imported successfully! File: ' + (data.fileName || 'kiro-token.json');
} else {
statusEl.className = 'status-message error';
statusEl.textContent = '❌ ' + (data.error || data.message || 'Import failed');
}
} catch (error) {
statusEl.className = 'status-message error';
statusEl.textContent = '❌ Network error: ' + error.message;
} finally {
btn.disabled = false;
btn.textContent = '📥 Import Token';
}
}
async function manualRefresh() {
const btn = document.getElementById('refreshBtn');
const statusEl = document.getElementById('refreshStatus');
btn.disabled = true;
btn.innerHTML = '<span class="icon">⏳</span> Refreshing...';
statusEl.className = 'status-message';
statusEl.style.display = 'none';
try {
const response = await fetch('/v0/oauth/kiro/refresh', {
method: 'POST',
headers: { 'Content-Type': 'application/json' }
});
const data = await response.json();
if (response.ok && data.success) {
statusEl.className = 'status-message success';
let msg = '✅ ' + data.message;
if (data.warnings && data.warnings.length > 0) {
msg += ' (Warnings: ' + data.warnings.join('; ') + ')';
}
statusEl.textContent = msg;
} else {
statusEl.className = 'status-message error';
statusEl.textContent = '❌ ' + (data.error || data.message || 'Refresh failed');
}
} catch (error) {
statusEl.className = 'status-message error';
statusEl.textContent = '❌ Network error: ' + error.message;
} finally {
btn.disabled = false;
btn.innerHTML = '<span class="icon">🔄</span> Manual Refresh All Tokens';
}
}
</script>
</body>
</html>`
)

View File

@@ -0,0 +1,316 @@
package kiro
import (
"math"
"math/rand"
"strings"
"sync"
"time"
)
const (
DefaultMinTokenInterval = 10 * time.Second
DefaultMaxTokenInterval = 30 * time.Second
DefaultDailyMaxRequests = 500
DefaultJitterPercent = 0.3
DefaultBackoffBase = 2 * time.Minute
DefaultBackoffMax = 60 * time.Minute
DefaultBackoffMultiplier = 2.0
DefaultSuspendCooldown = 24 * time.Hour
)
// TokenState Token 状态
type TokenState struct {
LastRequest time.Time
RequestCount int
CooldownEnd time.Time
FailCount int
DailyRequests int
DailyResetTime time.Time
IsSuspended bool
SuspendedAt time.Time
SuspendReason string
}
// RateLimiter 频率限制器
type RateLimiter struct {
mu sync.RWMutex
states map[string]*TokenState
minTokenInterval time.Duration
maxTokenInterval time.Duration
dailyMaxRequests int
jitterPercent float64
backoffBase time.Duration
backoffMax time.Duration
backoffMultiplier float64
suspendCooldown time.Duration
rng *rand.Rand
}
// NewRateLimiter 创建默认配置的频率限制器
func NewRateLimiter() *RateLimiter {
return &RateLimiter{
states: make(map[string]*TokenState),
minTokenInterval: DefaultMinTokenInterval,
maxTokenInterval: DefaultMaxTokenInterval,
dailyMaxRequests: DefaultDailyMaxRequests,
jitterPercent: DefaultJitterPercent,
backoffBase: DefaultBackoffBase,
backoffMax: DefaultBackoffMax,
backoffMultiplier: DefaultBackoffMultiplier,
suspendCooldown: DefaultSuspendCooldown,
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
// RateLimiterConfig 频率限制器配置
type RateLimiterConfig struct {
MinTokenInterval time.Duration
MaxTokenInterval time.Duration
DailyMaxRequests int
JitterPercent float64
BackoffBase time.Duration
BackoffMax time.Duration
BackoffMultiplier float64
SuspendCooldown time.Duration
}
// NewRateLimiterWithConfig 使用自定义配置创建频率限制器
func NewRateLimiterWithConfig(cfg RateLimiterConfig) *RateLimiter {
rl := NewRateLimiter()
if cfg.MinTokenInterval > 0 {
rl.minTokenInterval = cfg.MinTokenInterval
}
if cfg.MaxTokenInterval > 0 {
rl.maxTokenInterval = cfg.MaxTokenInterval
}
if cfg.DailyMaxRequests > 0 {
rl.dailyMaxRequests = cfg.DailyMaxRequests
}
if cfg.JitterPercent > 0 {
rl.jitterPercent = cfg.JitterPercent
}
if cfg.BackoffBase > 0 {
rl.backoffBase = cfg.BackoffBase
}
if cfg.BackoffMax > 0 {
rl.backoffMax = cfg.BackoffMax
}
if cfg.BackoffMultiplier > 0 {
rl.backoffMultiplier = cfg.BackoffMultiplier
}
if cfg.SuspendCooldown > 0 {
rl.suspendCooldown = cfg.SuspendCooldown
}
return rl
}
// getOrCreateState 获取或创建 Token 状态
func (rl *RateLimiter) getOrCreateState(tokenKey string) *TokenState {
state, exists := rl.states[tokenKey]
if !exists {
state = &TokenState{
DailyResetTime: time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour),
}
rl.states[tokenKey] = state
}
return state
}
// resetDailyIfNeeded 如果需要则重置每日计数
func (rl *RateLimiter) resetDailyIfNeeded(state *TokenState) {
now := time.Now()
if now.After(state.DailyResetTime) {
state.DailyRequests = 0
state.DailyResetTime = now.Truncate(24 * time.Hour).Add(24 * time.Hour)
}
}
// calculateInterval 计算带抖动的随机间隔
func (rl *RateLimiter) calculateInterval() time.Duration {
baseInterval := rl.minTokenInterval + time.Duration(rl.rng.Int63n(int64(rl.maxTokenInterval-rl.minTokenInterval)))
jitter := time.Duration(float64(baseInterval) * rl.jitterPercent * (rl.rng.Float64()*2 - 1))
return baseInterval + jitter
}
// WaitForToken 等待 Token 可用(带抖动的随机间隔)
func (rl *RateLimiter) WaitForToken(tokenKey string) {
rl.mu.Lock()
state := rl.getOrCreateState(tokenKey)
rl.resetDailyIfNeeded(state)
now := time.Now()
// 检查是否在冷却期
if now.Before(state.CooldownEnd) {
waitTime := state.CooldownEnd.Sub(now)
rl.mu.Unlock()
time.Sleep(waitTime)
rl.mu.Lock()
state = rl.getOrCreateState(tokenKey)
now = time.Now()
}
// 计算距离上次请求的间隔
interval := rl.calculateInterval()
nextAllowedTime := state.LastRequest.Add(interval)
if now.Before(nextAllowedTime) {
waitTime := nextAllowedTime.Sub(now)
rl.mu.Unlock()
time.Sleep(waitTime)
rl.mu.Lock()
state = rl.getOrCreateState(tokenKey)
}
state.LastRequest = time.Now()
state.RequestCount++
state.DailyRequests++
rl.mu.Unlock()
}
// MarkTokenFailed 标记 Token 失败
func (rl *RateLimiter) MarkTokenFailed(tokenKey string) {
rl.mu.Lock()
defer rl.mu.Unlock()
state := rl.getOrCreateState(tokenKey)
state.FailCount++
state.CooldownEnd = time.Now().Add(rl.calculateBackoff(state.FailCount))
}
// MarkTokenSuccess 标记 Token 成功
func (rl *RateLimiter) MarkTokenSuccess(tokenKey string) {
rl.mu.Lock()
defer rl.mu.Unlock()
state := rl.getOrCreateState(tokenKey)
state.FailCount = 0
state.CooldownEnd = time.Time{}
}
// CheckAndMarkSuspended 检测暂停错误并标记
func (rl *RateLimiter) CheckAndMarkSuspended(tokenKey string, errorMsg string) bool {
suspendKeywords := []string{
"suspended",
"banned",
"disabled",
"account has been",
"access denied",
"rate limit exceeded",
"too many requests",
"quota exceeded",
}
lowerMsg := strings.ToLower(errorMsg)
for _, keyword := range suspendKeywords {
if strings.Contains(lowerMsg, keyword) {
rl.mu.Lock()
defer rl.mu.Unlock()
state := rl.getOrCreateState(tokenKey)
state.IsSuspended = true
state.SuspendedAt = time.Now()
state.SuspendReason = errorMsg
state.CooldownEnd = time.Now().Add(rl.suspendCooldown)
return true
}
}
return false
}
// IsTokenAvailable 检查 Token 是否可用
func (rl *RateLimiter) IsTokenAvailable(tokenKey string) bool {
rl.mu.RLock()
defer rl.mu.RUnlock()
state, exists := rl.states[tokenKey]
if !exists {
return true
}
now := time.Now()
// 检查是否被暂停
if state.IsSuspended {
if now.After(state.SuspendedAt.Add(rl.suspendCooldown)) {
return true
}
return false
}
// 检查是否在冷却期
if now.Before(state.CooldownEnd) {
return false
}
// 检查每日请求限制
rl.mu.RUnlock()
rl.mu.Lock()
rl.resetDailyIfNeeded(state)
dailyRequests := state.DailyRequests
dailyMax := rl.dailyMaxRequests
rl.mu.Unlock()
rl.mu.RLock()
if dailyRequests >= dailyMax {
return false
}
return true
}
// calculateBackoff 计算指数退避时间
func (rl *RateLimiter) calculateBackoff(failCount int) time.Duration {
if failCount <= 0 {
return 0
}
backoff := float64(rl.backoffBase) * math.Pow(rl.backoffMultiplier, float64(failCount-1))
// 添加抖动
jitter := backoff * rl.jitterPercent * (rl.rng.Float64()*2 - 1)
backoff += jitter
if time.Duration(backoff) > rl.backoffMax {
return rl.backoffMax
}
return time.Duration(backoff)
}
// GetTokenState 获取 Token 状态(只读)
func (rl *RateLimiter) GetTokenState(tokenKey string) *TokenState {
rl.mu.RLock()
defer rl.mu.RUnlock()
state, exists := rl.states[tokenKey]
if !exists {
return nil
}
// 返回副本以防止外部修改
stateCopy := *state
return &stateCopy
}
// ClearTokenState 清除 Token 状态
func (rl *RateLimiter) ClearTokenState(tokenKey string) {
rl.mu.Lock()
defer rl.mu.Unlock()
delete(rl.states, tokenKey)
}
// ResetSuspension 重置暂停状态
func (rl *RateLimiter) ResetSuspension(tokenKey string) {
rl.mu.Lock()
defer rl.mu.Unlock()
state, exists := rl.states[tokenKey]
if exists {
state.IsSuspended = false
state.SuspendedAt = time.Time{}
state.SuspendReason = ""
state.CooldownEnd = time.Time{}
state.FailCount = 0
}
}

View File

@@ -0,0 +1,46 @@
package kiro
import (
"sync"
"time"
log "github.com/sirupsen/logrus"
)
var (
globalRateLimiter *RateLimiter
globalRateLimiterOnce sync.Once
globalCooldownManager *CooldownManager
globalCooldownManagerOnce sync.Once
cooldownStopCh chan struct{}
)
// GetGlobalRateLimiter returns the singleton RateLimiter instance.
func GetGlobalRateLimiter() *RateLimiter {
globalRateLimiterOnce.Do(func() {
globalRateLimiter = NewRateLimiter()
log.Info("kiro: global RateLimiter initialized")
})
return globalRateLimiter
}
// GetGlobalCooldownManager returns the singleton CooldownManager instance.
func GetGlobalCooldownManager() *CooldownManager {
globalCooldownManagerOnce.Do(func() {
globalCooldownManager = NewCooldownManager()
cooldownStopCh = make(chan struct{})
go globalCooldownManager.StartCleanupRoutine(5*time.Minute, cooldownStopCh)
log.Info("kiro: global CooldownManager initialized with cleanup routine")
})
return globalCooldownManager
}
// ShutdownRateLimiters stops the cooldown cleanup routine.
// Should be called during application shutdown.
func ShutdownRateLimiters() {
if cooldownStopCh != nil {
close(cooldownStopCh)
log.Info("kiro: rate limiter cleanup routine stopped")
}
}

View File

@@ -0,0 +1,304 @@
package kiro
import (
"sync"
"testing"
"time"
)
func TestNewRateLimiter(t *testing.T) {
rl := NewRateLimiter()
if rl == nil {
t.Fatal("expected non-nil RateLimiter")
}
if rl.states == nil {
t.Error("expected non-nil states map")
}
if rl.minTokenInterval != DefaultMinTokenInterval {
t.Errorf("expected minTokenInterval %v, got %v", DefaultMinTokenInterval, rl.minTokenInterval)
}
if rl.maxTokenInterval != DefaultMaxTokenInterval {
t.Errorf("expected maxTokenInterval %v, got %v", DefaultMaxTokenInterval, rl.maxTokenInterval)
}
if rl.dailyMaxRequests != DefaultDailyMaxRequests {
t.Errorf("expected dailyMaxRequests %d, got %d", DefaultDailyMaxRequests, rl.dailyMaxRequests)
}
}
func TestNewRateLimiterWithConfig(t *testing.T) {
cfg := RateLimiterConfig{
MinTokenInterval: 5 * time.Second,
MaxTokenInterval: 15 * time.Second,
DailyMaxRequests: 100,
JitterPercent: 0.2,
BackoffBase: 1 * time.Minute,
BackoffMax: 30 * time.Minute,
BackoffMultiplier: 1.5,
SuspendCooldown: 12 * time.Hour,
}
rl := NewRateLimiterWithConfig(cfg)
if rl.minTokenInterval != 5*time.Second {
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
}
if rl.maxTokenInterval != 15*time.Second {
t.Errorf("expected maxTokenInterval 15s, got %v", rl.maxTokenInterval)
}
if rl.dailyMaxRequests != 100 {
t.Errorf("expected dailyMaxRequests 100, got %d", rl.dailyMaxRequests)
}
}
func TestNewRateLimiterWithConfig_PartialConfig(t *testing.T) {
cfg := RateLimiterConfig{
MinTokenInterval: 5 * time.Second,
}
rl := NewRateLimiterWithConfig(cfg)
if rl.minTokenInterval != 5*time.Second {
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
}
if rl.maxTokenInterval != DefaultMaxTokenInterval {
t.Errorf("expected default maxTokenInterval, got %v", rl.maxTokenInterval)
}
}
func TestGetTokenState_NonExistent(t *testing.T) {
rl := NewRateLimiter()
state := rl.GetTokenState("nonexistent")
if state != nil {
t.Error("expected nil state for non-existent token")
}
}
func TestIsTokenAvailable_NewToken(t *testing.T) {
rl := NewRateLimiter()
if !rl.IsTokenAvailable("newtoken") {
t.Error("expected new token to be available")
}
}
func TestMarkTokenFailed(t *testing.T) {
rl := NewRateLimiter()
rl.MarkTokenFailed("token1")
state := rl.GetTokenState("token1")
if state == nil {
t.Fatal("expected non-nil state")
}
if state.FailCount != 1 {
t.Errorf("expected FailCount 1, got %d", state.FailCount)
}
if state.CooldownEnd.IsZero() {
t.Error("expected non-zero CooldownEnd")
}
}
func TestMarkTokenSuccess(t *testing.T) {
rl := NewRateLimiter()
rl.MarkTokenFailed("token1")
rl.MarkTokenFailed("token1")
rl.MarkTokenSuccess("token1")
state := rl.GetTokenState("token1")
if state == nil {
t.Fatal("expected non-nil state")
}
if state.FailCount != 0 {
t.Errorf("expected FailCount 0, got %d", state.FailCount)
}
if !state.CooldownEnd.IsZero() {
t.Error("expected zero CooldownEnd after success")
}
}
func TestCheckAndMarkSuspended_Suspended(t *testing.T) {
rl := NewRateLimiter()
testCases := []string{
"Account has been suspended",
"You are banned from this service",
"Account disabled",
"Access denied permanently",
"Rate limit exceeded",
"Too many requests",
"Quota exceeded for today",
}
for i, msg := range testCases {
tokenKey := "token" + string(rune('a'+i))
if !rl.CheckAndMarkSuspended(tokenKey, msg) {
t.Errorf("expected suspension detected for: %s", msg)
}
state := rl.GetTokenState(tokenKey)
if !state.IsSuspended {
t.Errorf("expected IsSuspended true for: %s", msg)
}
}
}
func TestCheckAndMarkSuspended_NotSuspended(t *testing.T) {
rl := NewRateLimiter()
normalErrors := []string{
"connection timeout",
"internal server error",
"bad request",
"invalid token format",
}
for i, msg := range normalErrors {
tokenKey := "token" + string(rune('a'+i))
if rl.CheckAndMarkSuspended(tokenKey, msg) {
t.Errorf("unexpected suspension for: %s", msg)
}
}
}
func TestIsTokenAvailable_Suspended(t *testing.T) {
rl := NewRateLimiter()
rl.CheckAndMarkSuspended("token1", "Account suspended")
if rl.IsTokenAvailable("token1") {
t.Error("expected suspended token to be unavailable")
}
}
func TestClearTokenState(t *testing.T) {
rl := NewRateLimiter()
rl.MarkTokenFailed("token1")
rl.ClearTokenState("token1")
state := rl.GetTokenState("token1")
if state != nil {
t.Error("expected nil state after clear")
}
}
func TestResetSuspension(t *testing.T) {
rl := NewRateLimiter()
rl.CheckAndMarkSuspended("token1", "Account suspended")
rl.ResetSuspension("token1")
state := rl.GetTokenState("token1")
if state.IsSuspended {
t.Error("expected IsSuspended false after reset")
}
if state.FailCount != 0 {
t.Errorf("expected FailCount 0, got %d", state.FailCount)
}
}
func TestResetSuspension_NonExistent(t *testing.T) {
rl := NewRateLimiter()
rl.ResetSuspension("nonexistent")
}
func TestCalculateBackoff_ZeroFailCount(t *testing.T) {
rl := NewRateLimiter()
backoff := rl.calculateBackoff(0)
if backoff != 0 {
t.Errorf("expected 0 backoff for 0 fails, got %v", backoff)
}
}
func TestCalculateBackoff_Exponential(t *testing.T) {
cfg := RateLimiterConfig{
BackoffBase: 1 * time.Minute,
BackoffMax: 60 * time.Minute,
BackoffMultiplier: 2.0,
JitterPercent: 0.3,
}
rl := NewRateLimiterWithConfig(cfg)
backoff1 := rl.calculateBackoff(1)
if backoff1 < 40*time.Second || backoff1 > 80*time.Second {
t.Errorf("expected ~1min (with jitter) for fail 1, got %v", backoff1)
}
backoff2 := rl.calculateBackoff(2)
if backoff2 < 80*time.Second || backoff2 > 160*time.Second {
t.Errorf("expected ~2min (with jitter) for fail 2, got %v", backoff2)
}
}
func TestCalculateBackoff_MaxCap(t *testing.T) {
cfg := RateLimiterConfig{
BackoffBase: 1 * time.Minute,
BackoffMax: 10 * time.Minute,
BackoffMultiplier: 2.0,
JitterPercent: 0,
}
rl := NewRateLimiterWithConfig(cfg)
backoff := rl.calculateBackoff(10)
if backoff > 10*time.Minute {
t.Errorf("expected backoff capped at 10min, got %v", backoff)
}
}
func TestGetTokenState_ReturnsCopy(t *testing.T) {
rl := NewRateLimiter()
rl.MarkTokenFailed("token1")
state1 := rl.GetTokenState("token1")
state1.FailCount = 999
state2 := rl.GetTokenState("token1")
if state2.FailCount == 999 {
t.Error("GetTokenState should return a copy")
}
}
func TestRateLimiter_ConcurrentAccess(t *testing.T) {
rl := NewRateLimiter()
const numGoroutines = 50
const numOperations = 50
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
tokenKey := "token" + string(rune('a'+id%10))
for j := 0; j < numOperations; j++ {
switch j % 6 {
case 0:
rl.IsTokenAvailable(tokenKey)
case 1:
rl.MarkTokenFailed(tokenKey)
case 2:
rl.MarkTokenSuccess(tokenKey)
case 3:
rl.GetTokenState(tokenKey)
case 4:
rl.CheckAndMarkSuspended(tokenKey, "test error")
case 5:
rl.ResetSuspension(tokenKey)
}
}
}(i)
}
wg.Wait()
}
func TestCalculateInterval_WithinRange(t *testing.T) {
cfg := RateLimiterConfig{
MinTokenInterval: 10 * time.Second,
MaxTokenInterval: 30 * time.Second,
JitterPercent: 0.3,
}
rl := NewRateLimiterWithConfig(cfg)
minAllowed := 7 * time.Second
maxAllowed := 40 * time.Second
for i := 0; i < 100; i++ {
interval := rl.calculateInterval()
if interval < minAllowed || interval > maxAllowed {
t.Errorf("interval %v outside expected range [%v, %v]", interval, minAllowed, maxAllowed)
}
}
}

View File

@@ -0,0 +1,171 @@
package kiro
import (
"context"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
)
// RefreshManager 是后台刷新器的单例管理器
type RefreshManager struct {
mu sync.Mutex
refresher *BackgroundRefresher
ctx context.Context
cancel context.CancelFunc
started bool
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
}
var (
globalRefreshManager *RefreshManager
managerOnce sync.Once
)
// GetRefreshManager 获取全局刷新管理器实例
func GetRefreshManager() *RefreshManager {
managerOnce.Do(func() {
globalRefreshManager = &RefreshManager{}
})
return globalRefreshManager
}
// Initialize 初始化后台刷新器
// baseDir: token 文件所在的目录
// cfg: 应用配置
func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.started {
log.Debug("refresh manager: already initialized")
return nil
}
if baseDir == "" {
log.Warn("refresh manager: base directory not provided, skipping initialization")
return nil
}
// 创建 token 存储库
repo := NewFileTokenRepository(baseDir)
// 创建后台刷新器,配置参数
opts := []RefresherOption{
WithInterval(time.Minute), // 每分钟检查一次
WithBatchSize(50), // 每批最多处理 50 个 token
WithConcurrency(10), // 最多 10 个并发刷新
WithConfig(cfg), // 设置 OAuth 和 SSO 客户端
}
// 如果已设置回调,传递给 BackgroundRefresher
if m.onTokenRefreshed != nil {
opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed))
}
m.refresher = NewBackgroundRefresher(repo, opts...)
log.Infof("refresh manager: initialized with base directory %s", baseDir)
return nil
}
// Start 启动后台刷新
func (m *RefreshManager) Start() {
m.mu.Lock()
defer m.mu.Unlock()
if m.started {
log.Debug("refresh manager: already started")
return
}
if m.refresher == nil {
log.Warn("refresh manager: not initialized, cannot start")
return
}
m.ctx, m.cancel = context.WithCancel(context.Background())
m.refresher.Start(m.ctx)
m.started = true
log.Info("refresh manager: background refresh started")
}
// Stop 停止后台刷新
func (m *RefreshManager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.started {
return
}
if m.cancel != nil {
m.cancel()
}
if m.refresher != nil {
m.refresher.Stop()
}
m.started = false
log.Info("refresh manager: background refresh stopped")
}
// IsRunning 检查后台刷新是否正在运行
func (m *RefreshManager) IsRunning() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.started
}
// UpdateBaseDir 更新 token 目录(用于运行时配置更改)
func (m *RefreshManager) UpdateBaseDir(baseDir string) {
m.mu.Lock()
defer m.mu.Unlock()
if m.refresher != nil && m.refresher.tokenRepo != nil {
if repo, ok := m.refresher.tokenRepo.(*FileTokenRepository); ok {
repo.SetBaseDir(baseDir)
log.Infof("refresh manager: updated base directory to %s", baseDir)
}
}
}
// SetOnTokenRefreshed 设置 token 刷新成功后的回调函数
// 可以在任何时候调用,支持运行时更新回调
// callback: 回调函数,接收 tokenID文件名和新的 token 数据
func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) {
m.mu.Lock()
defer m.mu.Unlock()
m.onTokenRefreshed = callback
// 如果 refresher 已经创建,使用并发安全的方式更新它的回调
if m.refresher != nil {
m.refresher.callbackMu.Lock()
m.refresher.onTokenRefreshed = callback
m.refresher.callbackMu.Unlock()
}
log.Debug("refresh manager: token refresh callback registered")
}
// InitializeAndStart 初始化并启动后台刷新(便捷方法)
func InitializeAndStart(baseDir string, cfg *config.Config) {
manager := GetRefreshManager()
if err := manager.Initialize(baseDir, cfg); err != nil {
log.Errorf("refresh manager: initialization failed: %v", err)
return
}
manager.Start()
}
// StopGlobalRefreshManager 停止全局刷新管理器
func StopGlobalRefreshManager() {
if globalRefreshManager != nil {
globalRefreshManager.Stop()
}
}

View File

@@ -9,7 +9,9 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"html"
"io"
"net"
"net/http"
"net/url"
"os"
@@ -31,6 +33,9 @@ const (
// OAuth timeout
socialAuthTimeout = 10 * time.Minute
// Default callback port for social auth HTTP server
socialAuthCallbackPort = 9876
)
// SocialProvider represents the social login provider.
@@ -67,6 +72,13 @@ type RefreshTokenRequest struct {
RefreshToken string `json:"refreshToken"`
}
// WebCallbackResult contains the OAuth callback result from HTTP server.
type WebCallbackResult struct {
Code string
State string
Error string
}
// SocialAuthClient handles social authentication with Kiro.
type SocialAuthClient struct {
httpClient *http.Client
@@ -87,6 +99,83 @@ func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
}
}
// startWebCallbackServer starts a local HTTP server to receive the OAuth callback.
// This is used instead of the kiro:// protocol handler to avoid redirect_mismatch errors.
func (c *SocialAuthClient) startWebCallbackServer(ctx context.Context, expectedState string) (string, <-chan WebCallbackResult, error) {
// Try to find an available port - use localhost like Kiro does
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", socialAuthCallbackPort))
if err != nil {
// Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
log.Warnf("kiro social auth: default port %d is busy, falling back to dynamic port", socialAuthCallbackPort)
listener, err = net.Listen("tcp", "localhost:0")
if err != nil {
return "", nil, fmt.Errorf("failed to start callback server: %w", err)
}
}
port := listener.Addr().(*net.TCPAddr).Port
// Use http scheme for local callback server
redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
resultChan := make(chan WebCallbackResult, 1)
server := &http.Server{
ReadHeaderTimeout: 10 * time.Second,
}
mux := http.NewServeMux()
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
errParam := r.URL.Query().Get("error")
if errParam != "" {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, `<!DOCTYPE html>
<html><head><title>Login Failed</title></head>
<body><h1>Login Failed</h1><p>%s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
resultChan <- WebCallbackResult{Error: errParam}
return
}
if state != expectedState {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, `<!DOCTYPE html>
<html><head><title>Login Failed</title></head>
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
resultChan <- WebCallbackResult{Error: "state mismatch"}
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprint(w, `<!DOCTYPE html>
<html><head><title>Login Successful</title></head>
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
<script>window.close();</script></body></html>`)
resultChan <- WebCallbackResult{Code: code, State: state}
})
server.Handler = mux
go func() {
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
log.Debugf("kiro social auth callback server error: %v", err)
}
}()
go func() {
select {
case <-ctx.Done():
case <-time.After(socialAuthTimeout):
case <-resultChan:
}
_ = server.Shutdown(context.Background())
}()
return redirectURI, resultChan, nil
}
// generatePKCE generates PKCE code verifier and challenge.
func generatePKCE() (verifier, challenge string, err error) {
// Generate 32 bytes of random data for verifier
@@ -217,10 +306,12 @@ func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: "", // Caller should preserve original provider
Region: "us-east-1",
}, nil
}
// LoginWithSocial performs OAuth login with Google.
// LoginWithSocial performs OAuth login with Google or GitHub.
// Uses local HTTP callback server instead of custom protocol handler to avoid redirect_mismatch errors.
func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) {
providerName := string(provider)
@@ -228,28 +319,10 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName)
fmt.Println("╚══════════════════════════════════════════════════════════╝")
// Step 1: Setup protocol handler
// Step 1: Start local HTTP callback server (instead of kiro:// protocol handler)
// This avoids redirect_mismatch errors with AWS Cognito
fmt.Println("\nSetting up authentication...")
// Start the local callback server
handlerPort, err := c.protocolHandler.Start(ctx)
if err != nil {
return nil, fmt.Errorf("failed to start callback server: %w", err)
}
defer c.protocolHandler.Stop()
// Ensure protocol handler is installed and set as default
if err := SetupProtocolHandlerIfNeeded(handlerPort); err != nil {
fmt.Println("\n⚠ Protocol handler setup failed. Trying alternative method...")
fmt.Println(" If you see a browser 'Open with' dialog, select your default browser.")
fmt.Println(" For manual setup instructions, run: cliproxy kiro --help-protocol")
log.Debugf("kiro: protocol handler setup error: %v", err)
// Continue anyway - user might have set it up manually or select browser manually
} else {
// Force set our handler as default (prevents "Open with" dialog)
forceDefaultProtocolHandler()
}
// Step 2: Generate PKCE codes
codeVerifier, codeChallenge, err := generatePKCE()
if err != nil {
@@ -262,8 +335,15 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
return nil, fmt.Errorf("failed to generate state: %w", err)
}
// Step 4: Build the login URL (Kiro uses GET request with query params)
authURL := c.buildLoginURL(providerName, KiroRedirectURI, codeChallenge, state)
// Step 4: Start local HTTP callback server
redirectURI, resultChan, err := c.startWebCallbackServer(ctx, state)
if err != nil {
return nil, fmt.Errorf("failed to start callback server: %w", err)
}
log.Debugf("kiro social auth: callback server started at %s", redirectURI)
// Step 5: Build the login URL using HTTP redirect URI
authURL := c.buildLoginURL(providerName, redirectURI, codeChallenge, state)
// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
// Incognito mode enables multi-account support by bypassing cached sessions
@@ -279,7 +359,7 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
log.Debug("kiro: using incognito mode for multi-account support (default)")
}
// Step 5: Open browser for user authentication
// Step 6: Open browser for user authentication
fmt.Println("\n════════════════════════════════════════════════════════════")
fmt.Printf(" Opening browser for %s authentication...\n", providerName)
fmt.Println("════════════════════════════════════════════════════════════")
@@ -295,80 +375,78 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
fmt.Println("\n Waiting for authentication callback...")
// Step 6: Wait for callback
callback, err := c.protocolHandler.WaitForCallback(ctx)
if err != nil {
return nil, fmt.Errorf("failed to receive callback: %w", err)
}
if callback.Error != "" {
return nil, fmt.Errorf("authentication error: %s", callback.Error)
}
if callback.State != state {
// Log state values for debugging, but don't expose in user-facing error
log.Debugf("kiro: OAuth state mismatch - expected %s, got %s", state, callback.State)
return nil, fmt.Errorf("OAuth state validation failed - please try again")
}
if callback.Code == "" {
return nil, fmt.Errorf("no authorization code received")
}
fmt.Println("\n✓ Authorization received!")
// Step 7: Exchange code for tokens
fmt.Println("Exchanging code for tokens...")
tokenReq := &CreateTokenRequest{
Code: callback.Code,
CodeVerifier: codeVerifier,
RedirectURI: KiroRedirectURI,
}
tokenResp, err := c.CreateToken(ctx, tokenReq)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
}
fmt.Println("\n✓ Authentication successful!")
// Close the browser window
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser: %v", err)
}
// Validate ExpiresIn - use default 1 hour if invalid
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
// Try to extract email from JWT access token first
email := ExtractEmailFromJWT(tokenResp.AccessToken)
// If no email in JWT, ask user for account label (only in interactive mode)
if email == "" && isInteractiveTerminal() {
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
reader := bufio.NewReader(os.Stdin)
var err error
email, err = reader.ReadString('\n')
if err != nil {
log.Debugf("Failed to read account label: %v", err)
// Step 7: Wait for callback from HTTP server
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(socialAuthTimeout):
return nil, fmt.Errorf("authentication timed out")
case callback := <-resultChan:
if callback.Error != "" {
return nil, fmt.Errorf("authentication error: %s", callback.Error)
}
email = strings.TrimSpace(email)
}
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: providerName,
Email: email, // JWT email or user-provided label
}, nil
// State is already validated by the callback server
if callback.Code == "" {
return nil, fmt.Errorf("no authorization code received")
}
fmt.Println("\n✓ Authorization received!")
// Step 8: Exchange code for tokens
fmt.Println("Exchanging code for tokens...")
tokenReq := &CreateTokenRequest{
Code: callback.Code,
CodeVerifier: codeVerifier,
RedirectURI: redirectURI, // Use HTTP redirect URI, not kiro:// protocol
}
tokenResp, err := c.CreateToken(ctx, tokenReq)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
}
fmt.Println("\n✓ Authentication successful!")
// Close the browser window
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser: %v", err)
}
// Validate ExpiresIn - use default 1 hour if invalid
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
// Try to extract email from JWT access token first
email := ExtractEmailFromJWT(tokenResp.AccessToken)
// If no email in JWT, ask user for account label (only in interactive mode)
if email == "" && isInteractiveTerminal() {
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
reader := bufio.NewReader(os.Stdin)
var err error
email, err = reader.ReadString('\n')
if err != nil {
log.Debugf("Failed to read account label: %v", err)
}
email = strings.TrimSpace(email)
}
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: providerName,
Email: email, // JWT email or user-provided label
Region: "us-east-1",
}, nil
}
}
// LoginWithGoogle performs OAuth login with Google.

View File

@@ -735,6 +735,7 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
Provider: "AWS",
ClientID: clientID,
ClientSecret: clientSecret,
Region: defaultIDCRegion,
}, nil
}
@@ -850,16 +851,17 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
ClientID: regResp.ClientID,
ClientSecret: regResp.ClientSecret,
Email: email,
Region: defaultIDCRegion,
}, nil
}
}
}
}
// Close browser on timeout for better UX
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser on timeout: %v", err)
}
return nil, fmt.Errorf("authorization timed out")
}
// Close browser on timeout for better UX
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser on timeout: %v", err)
}
return nil, fmt.Errorf("authorization timed out")
}
// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
// Falls back to JWT parsing if userinfo fails.
@@ -1366,6 +1368,7 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo
ClientID: regResp.ClientID,
ClientSecret: regResp.ClientSecret,
Email: email,
Region: defaultIDCRegion,
}, nil
}
}

View File

@@ -9,6 +9,8 @@ import (
// KiroTokenStorage holds the persistent token data for Kiro authentication.
type KiroTokenStorage struct {
// Type is the provider type for management UI recognition (must be "kiro")
Type string `json:"type"`
// AccessToken is the OAuth2 access token for API access
AccessToken string `json:"access_token"`
// RefreshToken is used to obtain new access tokens
@@ -23,6 +25,16 @@ type KiroTokenStorage struct {
Provider string `json:"provider"`
// LastRefresh is the timestamp of the last token refresh
LastRefresh string `json:"last_refresh"`
// ClientID is the OAuth client ID (required for token refresh)
ClientID string `json:"client_id,omitempty"`
// ClientSecret is the OAuth client secret (required for token refresh)
ClientSecret string `json:"client_secret,omitempty"`
// Region is the AWS region
Region string `json:"region,omitempty"`
// StartURL is the AWS Identity Center start URL (for IDC auth)
StartURL string `json:"start_url,omitempty"`
// Email is the user's email address
Email string `json:"email,omitempty"`
}
// SaveTokenToFile persists the token storage to the specified file path.
@@ -68,5 +80,10 @@ func (s *KiroTokenStorage) ToTokenData() *KiroTokenData {
ExpiresAt: s.ExpiresAt,
AuthMethod: s.AuthMethod,
Provider: s.Provider,
ClientID: s.ClientID,
ClientSecret: s.ClientSecret,
Region: s.Region,
StartURL: s.StartURL,
Email: s.Email,
}
}

View File

@@ -0,0 +1,273 @@
package kiro
import (
"context"
"encoding/json"
"fmt"
"io/fs"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// FileTokenRepository 实现 TokenRepository 接口,基于文件系统存储
type FileTokenRepository struct {
mu sync.RWMutex
baseDir string
}
// NewFileTokenRepository 创建一个新的文件 token 存储库
func NewFileTokenRepository(baseDir string) *FileTokenRepository {
return &FileTokenRepository{
baseDir: baseDir,
}
}
// SetBaseDir 设置基础目录
func (r *FileTokenRepository) SetBaseDir(dir string) {
r.mu.Lock()
r.baseDir = strings.TrimSpace(dir)
r.mu.Unlock()
}
// FindOldestUnverified 查找需要刷新的 token按最后验证时间排序
func (r *FileTokenRepository) FindOldestUnverified(limit int) []*Token {
r.mu.RLock()
baseDir := r.baseDir
r.mu.RUnlock()
if baseDir == "" {
log.Debug("token repository: base directory not configured")
return nil
}
var tokens []*Token
err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
return nil // 忽略错误,继续遍历
}
if d.IsDir() {
return nil
}
if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") {
return nil
}
// 只处理 kiro 相关的 token 文件
if !strings.HasPrefix(d.Name(), "kiro-") {
return nil
}
token, err := r.readTokenFile(path)
if err != nil {
log.Debugf("token repository: failed to read token file %s: %v", path, err)
return nil
}
if token != nil && token.RefreshToken != "" {
// 检查 token 是否需要刷新(过期前 5 分钟)
if token.ExpiresAt.IsZero() || time.Until(token.ExpiresAt) < 5*time.Minute {
tokens = append(tokens, token)
}
}
return nil
})
if err != nil {
log.Warnf("token repository: error walking directory: %v", err)
}
// 按最后验证时间排序(最旧的优先)
sort.Slice(tokens, func(i, j int) bool {
return tokens[i].LastVerified.Before(tokens[j].LastVerified)
})
// 限制返回数量
if limit > 0 && len(tokens) > limit {
tokens = tokens[:limit]
}
return tokens
}
// UpdateToken 更新 token 并持久化到文件
func (r *FileTokenRepository) UpdateToken(token *Token) error {
if token == nil {
return fmt.Errorf("token repository: token is nil")
}
r.mu.RLock()
baseDir := r.baseDir
r.mu.RUnlock()
if baseDir == "" {
return fmt.Errorf("token repository: base directory not configured")
}
// 构建文件路径
filePath := filepath.Join(baseDir, token.ID)
if !strings.HasSuffix(filePath, ".json") {
filePath += ".json"
}
// 读取现有文件内容
existingData := make(map[string]any)
if data, err := os.ReadFile(filePath); err == nil {
_ = json.Unmarshal(data, &existingData)
}
// 更新字段
existingData["access_token"] = token.AccessToken
existingData["refresh_token"] = token.RefreshToken
existingData["last_refresh"] = time.Now().Format(time.RFC3339)
if !token.ExpiresAt.IsZero() {
existingData["expires_at"] = token.ExpiresAt.Format(time.RFC3339)
}
// 保持原有的关键字段
if token.ClientID != "" {
existingData["client_id"] = token.ClientID
}
if token.ClientSecret != "" {
existingData["client_secret"] = token.ClientSecret
}
if token.AuthMethod != "" {
existingData["auth_method"] = token.AuthMethod
}
if token.Region != "" {
existingData["region"] = token.Region
}
if token.StartURL != "" {
existingData["start_url"] = token.StartURL
}
// 序列化并写入文件
raw, err := json.MarshalIndent(existingData, "", " ")
if err != nil {
return fmt.Errorf("token repository: marshal failed: %w", err)
}
// 原子写入:先写入临时文件,再重命名
tmpPath := filePath + ".tmp"
if err := os.WriteFile(tmpPath, raw, 0o600); err != nil {
return fmt.Errorf("token repository: write temp file failed: %w", err)
}
if err := os.Rename(tmpPath, filePath); err != nil {
_ = os.Remove(tmpPath)
return fmt.Errorf("token repository: rename failed: %w", err)
}
log.Debugf("token repository: updated token %s", token.ID)
return nil
}
// readTokenFile 从文件读取 token
func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var metadata map[string]any
if err := json.Unmarshal(data, &metadata); err != nil {
return nil, err
}
// 检查是否是 kiro token
tokenType, _ := metadata["type"].(string)
if tokenType != "kiro" {
return nil, nil
}
// 检查 auth_method
authMethod, _ := metadata["auth_method"].(string)
if authMethod != "idc" && authMethod != "builder-id" {
return nil, nil // 只处理 IDC 和 Builder ID token
}
token := &Token{
ID: filepath.Base(path),
AuthMethod: authMethod,
}
// 解析各字段
if v, ok := metadata["access_token"].(string); ok {
token.AccessToken = v
}
if v, ok := metadata["refresh_token"].(string); ok {
token.RefreshToken = v
}
if v, ok := metadata["client_id"].(string); ok {
token.ClientID = v
}
if v, ok := metadata["client_secret"].(string); ok {
token.ClientSecret = v
}
if v, ok := metadata["region"].(string); ok {
token.Region = v
}
if v, ok := metadata["start_url"].(string); ok {
token.StartURL = v
}
if v, ok := metadata["provider"].(string); ok {
token.Provider = v
}
// 解析时间字段
if v, ok := metadata["expires_at"].(string); ok {
if t, err := time.Parse(time.RFC3339, v); err == nil {
token.ExpiresAt = t
}
}
if v, ok := metadata["last_refresh"].(string); ok {
if t, err := time.Parse(time.RFC3339, v); err == nil {
token.LastVerified = t
}
}
return token, nil
}
// ListKiroTokens 列出所有 Kiro token用于调试
func (r *FileTokenRepository) ListKiroTokens(ctx context.Context) ([]*Token, error) {
r.mu.RLock()
baseDir := r.baseDir
r.mu.RUnlock()
if baseDir == "" {
return nil, fmt.Errorf("token repository: base directory not configured")
}
var tokens []*Token
err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
return nil
}
if d.IsDir() {
return nil
}
if !strings.HasPrefix(d.Name(), "kiro-") || !strings.HasSuffix(d.Name(), ".json") {
return nil
}
token, err := r.readTokenFile(path)
if err != nil {
return nil
}
if token != nil {
tokens = append(tokens, token)
}
return nil
})
return tokens, err
}

View File

@@ -0,0 +1,243 @@
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
// This file implements usage quota checking and monitoring.
package kiro
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
// UsageQuotaResponse represents the API response structure for usage quota checking.
type UsageQuotaResponse struct {
UsageBreakdownList []UsageBreakdownExtended `json:"usageBreakdownList"`
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
NextDateReset float64 `json:"nextDateReset,omitempty"`
}
// UsageBreakdownExtended represents detailed usage information for quota checking.
// Note: UsageBreakdown is already defined in codewhisperer_client.go
type UsageBreakdownExtended struct {
ResourceType string `json:"resourceType"`
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
FreeTrialInfo *FreeTrialInfoExtended `json:"freeTrialInfo,omitempty"`
}
// FreeTrialInfoExtended represents free trial usage information.
type FreeTrialInfoExtended struct {
FreeTrialStatus string `json:"freeTrialStatus"`
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
}
// QuotaStatus represents the quota status for a token.
type QuotaStatus struct {
TotalLimit float64
CurrentUsage float64
RemainingQuota float64
IsExhausted bool
ResourceType string
NextReset time.Time
}
// UsageChecker provides methods for checking token quota usage.
type UsageChecker struct {
httpClient *http.Client
endpoint string
}
// NewUsageChecker creates a new UsageChecker instance.
func NewUsageChecker(cfg *config.Config) *UsageChecker {
return &UsageChecker{
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
endpoint: awsKiroEndpoint,
}
}
// NewUsageCheckerWithClient creates a UsageChecker with a custom HTTP client.
func NewUsageCheckerWithClient(client *http.Client) *UsageChecker {
return &UsageChecker{
httpClient: client,
endpoint: awsKiroEndpoint,
}
}
// CheckUsage retrieves usage limits for the given token.
func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData) (*UsageQuotaResponse, error) {
if tokenData == nil {
return nil, fmt.Errorf("token data is nil")
}
if tokenData.AccessToken == "" {
return nil, fmt.Errorf("access token is empty")
}
payload := map[string]interface{}{
"origin": "AI_EDITOR",
"profileArn": tokenData.ProfileArn,
"resourceType": "AGENTIC_REQUEST",
}
jsonBody, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(string(jsonBody)))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
req.Header.Set("x-amz-target", targetGetUsage)
req.Header.Set("Authorization", "Bearer "+tokenData.AccessToken)
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result UsageQuotaResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse usage response: %w", err)
}
return &result, nil
}
// CheckUsageByAccessToken retrieves usage limits using an access token and profile ARN directly.
func (c *UsageChecker) CheckUsageByAccessToken(ctx context.Context, accessToken, profileArn string) (*UsageQuotaResponse, error) {
tokenData := &KiroTokenData{
AccessToken: accessToken,
ProfileArn: profileArn,
}
return c.CheckUsage(ctx, tokenData)
}
// GetRemainingQuota calculates the remaining quota from usage limits.
func GetRemainingQuota(usage *UsageQuotaResponse) float64 {
if usage == nil || len(usage.UsageBreakdownList) == 0 {
return 0
}
var totalRemaining float64
for _, breakdown := range usage.UsageBreakdownList {
remaining := breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
if remaining > 0 {
totalRemaining += remaining
}
if breakdown.FreeTrialInfo != nil {
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
if freeRemaining > 0 {
totalRemaining += freeRemaining
}
}
}
return totalRemaining
}
// IsQuotaExhausted checks if the quota is exhausted based on usage limits.
func IsQuotaExhausted(usage *UsageQuotaResponse) bool {
if usage == nil || len(usage.UsageBreakdownList) == 0 {
return true
}
for _, breakdown := range usage.UsageBreakdownList {
if breakdown.CurrentUsageWithPrecision < breakdown.UsageLimitWithPrecision {
return false
}
if breakdown.FreeTrialInfo != nil {
if breakdown.FreeTrialInfo.CurrentUsageWithPrecision < breakdown.FreeTrialInfo.UsageLimitWithPrecision {
return false
}
}
}
return true
}
// GetQuotaStatus retrieves a comprehensive quota status for a token.
func (c *UsageChecker) GetQuotaStatus(ctx context.Context, tokenData *KiroTokenData) (*QuotaStatus, error) {
usage, err := c.CheckUsage(ctx, tokenData)
if err != nil {
return nil, err
}
status := &QuotaStatus{
IsExhausted: IsQuotaExhausted(usage),
}
if len(usage.UsageBreakdownList) > 0 {
breakdown := usage.UsageBreakdownList[0]
status.TotalLimit = breakdown.UsageLimitWithPrecision
status.CurrentUsage = breakdown.CurrentUsageWithPrecision
status.RemainingQuota = breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
status.ResourceType = breakdown.ResourceType
if breakdown.FreeTrialInfo != nil {
status.TotalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
status.CurrentUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
if freeRemaining > 0 {
status.RemainingQuota += freeRemaining
}
}
}
if usage.NextDateReset > 0 {
status.NextReset = time.Unix(int64(usage.NextDateReset/1000), 0)
}
return status, nil
}
// CalculateAvailableCount calculates the available request count based on usage limits.
func CalculateAvailableCount(usage *UsageQuotaResponse) float64 {
return GetRemainingQuota(usage)
}
// GetUsagePercentage calculates the usage percentage.
func GetUsagePercentage(usage *UsageQuotaResponse) float64 {
if usage == nil || len(usage.UsageBreakdownList) == 0 {
return 100.0
}
var totalLimit, totalUsage float64
for _, breakdown := range usage.UsageBreakdownList {
totalLimit += breakdown.UsageLimitWithPrecision
totalUsage += breakdown.CurrentUsageWithPrecision
if breakdown.FreeTrialInfo != nil {
totalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
totalUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
}
}
if totalLimit == 0 {
return 100.0
}
return (totalUsage / totalLimit) * 100
}

View File

@@ -3,6 +3,8 @@ package cache
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync"
"time"
)
@@ -94,17 +96,17 @@ func purgeExpiredSessions() {
// CacheSignature stores a thinking signature for a given session and text.
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
func CacheSignature(sessionID, text, signature string) {
if sessionID == "" || text == "" || signature == "" {
func CacheSignature(modelName, text, signature string) {
if text == "" || signature == "" {
return
}
if len(signature) < MinValidSignatureLen {
return
}
sc := getOrCreateSession(sessionID)
text = fmt.Sprintf("%s#%s", GetModelGroup(modelName), text)
textHash := hashText(text)
sc := getOrCreateSession(textHash)
sc.mu.Lock()
defer sc.mu.Unlock()
@@ -116,13 +118,21 @@ func CacheSignature(sessionID, text, signature string) {
// GetCachedSignature retrieves a cached signature for a given session and text.
// Returns empty string if not found or expired.
func GetCachedSignature(sessionID, text string) string {
if sessionID == "" || text == "" {
func GetCachedSignature(modelName, text string) string {
family := GetModelGroup(modelName)
if text == "" {
if family == "gemini" {
return "skip_thought_signature_validator"
}
return ""
}
val, ok := signatureCache.Load(sessionID)
text = fmt.Sprintf("%s#%s", GetModelGroup(modelName), text)
val, ok := signatureCache.Load(hashText(text))
if !ok {
if family == "gemini" {
return "skip_thought_signature_validator"
}
return ""
}
sc := val.(*sessionCache)
@@ -135,11 +145,17 @@ func GetCachedSignature(sessionID, text string) string {
entry, exists := sc.entries[textHash]
if !exists {
sc.mu.Unlock()
if family == "gemini" {
return "skip_thought_signature_validator"
}
return ""
}
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
delete(sc.entries, textHash)
sc.mu.Unlock()
if family == "gemini" {
return "skip_thought_signature_validator"
}
return ""
}
@@ -154,7 +170,13 @@ func GetCachedSignature(sessionID, text string) string {
// ClearSignatureCache clears signature cache for a specific session or all sessions.
func ClearSignatureCache(sessionID string) {
if sessionID != "" {
signatureCache.Delete(sessionID)
signatureCache.Range(func(key, _ any) bool {
kStr, ok := key.(string)
if ok && strings.HasSuffix(kStr, "#"+sessionID) {
signatureCache.Delete(key)
}
return true
})
} else {
signatureCache.Range(func(key, _ any) bool {
signatureCache.Delete(key)
@@ -164,6 +186,17 @@ func ClearSignatureCache(sessionID string) {
}
// HasValidSignature checks if a signature is valid (non-empty and long enough)
func HasValidSignature(signature string) bool {
return signature != "" && len(signature) >= MinValidSignatureLen
func HasValidSignature(modelName, signature string) bool {
return (signature != "" && len(signature) >= MinValidSignatureLen) || (signature == "skip_thought_signature_validator" && GetModelGroup(modelName) == "gemini")
}
func GetModelGroup(modelName string) string {
if strings.Contains(modelName, "gpt") {
return "gpt"
} else if strings.Contains(modelName, "claude") {
return "claude"
} else if strings.Contains(modelName, "gemini") {
return "gemini"
}
return modelName
}

View File

@@ -8,15 +8,14 @@ import (
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
ClearSignatureCache("")
sessionID := "test-session-1"
text := "This is some thinking text content"
signature := "abc123validSignature1234567890123456789012345678901234567890"
// Store signature
CacheSignature(sessionID, text, signature)
CacheSignature("test-model", text, signature)
// Retrieve signature
retrieved := GetCachedSignature(sessionID, text)
retrieved := GetCachedSignature("test-model", text)
if retrieved != signature {
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
}
@@ -29,13 +28,13 @@ func TestCacheSignature_DifferentSessions(t *testing.T) {
sig1 := "signature1_1234567890123456789012345678901234567890123456"
sig2 := "signature2_1234567890123456789012345678901234567890123456"
CacheSignature("session-a", text, sig1)
CacheSignature("session-b", text, sig2)
CacheSignature("test-model", text, sig1)
CacheSignature("test-model", text, sig2)
if GetCachedSignature("session-a", text) != sig1 {
if GetCachedSignature("test-model", text) != sig1 {
t.Error("Session-a signature mismatch")
}
if GetCachedSignature("session-b", text) != sig2 {
if GetCachedSignature("test-model", text) != sig2 {
t.Error("Session-b signature mismatch")
}
}
@@ -44,13 +43,13 @@ func TestCacheSignature_NotFound(t *testing.T) {
ClearSignatureCache("")
// Non-existent session
if got := GetCachedSignature("nonexistent", "some text"); got != "" {
if got := GetCachedSignature("test-model", "some text"); got != "" {
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
}
// Existing session but different text
CacheSignature("session-x", "text-a", "sigA12345678901234567890123456789012345678901234567890")
if got := GetCachedSignature("session-x", "text-b"); got != "" {
CacheSignature("test-model", "text-a", "sigA12345678901234567890123456789012345678901234567890")
if got := GetCachedSignature("test-model", "text-b"); got != "" {
t.Errorf("Expected empty string for different text, got '%s'", got)
}
}
@@ -59,12 +58,12 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
ClearSignatureCache("")
// All empty/invalid inputs should be no-ops
CacheSignature("", "text", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("session", "", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("session", "text", "")
CacheSignature("session", "text", "short") // Too short
CacheSignature("test-model", "text", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("test-model", "", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("test-model", "text", "")
CacheSignature("test-model", "text", "short") // Too short
if got := GetCachedSignature("session", "text"); got != "" {
if got := GetCachedSignature("test-model", "text"); got != "" {
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
}
}
@@ -72,13 +71,12 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
ClearSignatureCache("")
sessionID := "test-short-sig"
text := "Some text"
shortSig := "abc123" // Less than 50 chars
CacheSignature(sessionID, text, shortSig)
CacheSignature("test-model", text, shortSig)
if got := GetCachedSignature(sessionID, text); got != "" {
if got := GetCachedSignature("test-model", text); got != "" {
t.Errorf("Short signature should be rejected, got '%s'", got)
}
}
@@ -87,15 +85,15 @@ func TestClearSignatureCache_SpecificSession(t *testing.T) {
ClearSignatureCache("")
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature("session-1", "text", sig)
CacheSignature("session-2", "text", sig)
CacheSignature("test-model", "text", sig)
CacheSignature("test-model", "text", sig)
ClearSignatureCache("session-1")
if got := GetCachedSignature("session-1", "text"); got != "" {
if got := GetCachedSignature("test-model", "text"); got != "" {
t.Error("session-1 should be cleared")
}
if got := GetCachedSignature("session-2", "text"); got != sig {
if got := GetCachedSignature("test-model", "text"); got != sig {
t.Error("session-2 should still exist")
}
}
@@ -104,15 +102,15 @@ func TestClearSignatureCache_AllSessions(t *testing.T) {
ClearSignatureCache("")
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature("session-1", "text", sig)
CacheSignature("session-2", "text", sig)
CacheSignature("test-model", "text", sig)
CacheSignature("test-model", "text", sig)
ClearSignatureCache("")
if got := GetCachedSignature("session-1", "text"); got != "" {
if got := GetCachedSignature("test-model", "text"); got != "" {
t.Error("session-1 should be cleared")
}
if got := GetCachedSignature("session-2", "text"); got != "" {
if got := GetCachedSignature("test-model", "text"); got != "" {
t.Error("session-2 should be cleared")
}
}
@@ -132,7 +130,7 @@ func TestHasValidSignature(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := HasValidSignature(tt.signature)
result := HasValidSignature("claude-sonnet-4-5-thinking", tt.signature)
if result != tt.expected {
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
}
@@ -143,21 +141,19 @@ func TestHasValidSignature(t *testing.T) {
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
ClearSignatureCache("")
sessionID := "hash-test-session"
// Different texts should produce different hashes
text1 := "First thinking text"
text2 := "Second thinking text"
sig1 := "signature1_1234567890123456789012345678901234567890123456"
sig2 := "signature2_1234567890123456789012345678901234567890123456"
CacheSignature(sessionID, text1, sig1)
CacheSignature(sessionID, text2, sig2)
CacheSignature("test-model", text1, sig1)
CacheSignature("test-model", text2, sig2)
if GetCachedSignature(sessionID, text1) != sig1 {
if GetCachedSignature("test-model", text1) != sig1 {
t.Error("text1 signature mismatch")
}
if GetCachedSignature(sessionID, text2) != sig2 {
if GetCachedSignature("test-model", text2) != sig2 {
t.Error("text2 signature mismatch")
}
}
@@ -165,13 +161,12 @@ func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
func TestCacheSignature_UnicodeText(t *testing.T) {
ClearSignatureCache("")
sessionID := "unicode-session"
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
sig := "unicodeSig123456789012345678901234567890123456789012345"
CacheSignature(sessionID, text, sig)
CacheSignature("test-model", text, sig)
if got := GetCachedSignature(sessionID, text); got != sig {
if got := GetCachedSignature("test-model", text); got != sig {
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
}
}
@@ -179,15 +174,14 @@ func TestCacheSignature_UnicodeText(t *testing.T) {
func TestCacheSignature_Overwrite(t *testing.T) {
ClearSignatureCache("")
sessionID := "overwrite-session"
text := "Same text"
sig1 := "firstSignature12345678901234567890123456789012345678901"
sig2 := "secondSignature1234567890123456789012345678901234567890"
CacheSignature(sessionID, text, sig1)
CacheSignature(sessionID, text, sig2) // Overwrite
CacheSignature("test-model", text, sig1)
CacheSignature("test-model", text, sig2) // Overwrite
if got := GetCachedSignature(sessionID, text); got != sig2 {
if got := GetCachedSignature("test-model", text); got != sig2 {
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
}
}
@@ -199,14 +193,13 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) {
// This test verifies the expiration check exists
// In a real scenario, we'd mock time.Now()
sessionID := "expiration-test"
text := "text"
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature(sessionID, text, sig)
CacheSignature("test-model", text, sig)
// Fresh entry should be retrievable
if got := GetCachedSignature(sessionID, text); got != sig {
if got := GetCachedSignature("test-model", text); got != sig {
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
}

View File

@@ -32,9 +32,10 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: promptFn,
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)

View File

@@ -22,9 +22,10 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: promptFn,
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)

View File

@@ -24,9 +24,10 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
}
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: promptFn,
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)

View File

@@ -67,10 +67,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
}
loginOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
ProjectID: trimmedProjectID,
Metadata: map[string]string{},
Prompt: callbackPrompt,
NoBrowser: options.NoBrowser,
ProjectID: trimmedProjectID,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: callbackPrompt,
}
authenticator := sdkAuth.NewGeminiAuthenticator()
@@ -88,8 +89,9 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
geminiAuth := gemini.NewGeminiAuth()
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
NoBrowser: options.NoBrowser,
Prompt: callbackPrompt,
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Prompt: callbackPrompt,
})
if errClient != nil {
log.Errorf("Gemini authentication failed: %v", errClient)
@@ -116,6 +118,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
}
activatedProjects := make([]string, 0, len(projectSelections))
seenProjects := make(map[string]bool)
for _, candidateID := range projectSelections {
log.Infof("Activating project %s", candidateID)
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
@@ -132,6 +135,13 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
if finalID == "" {
finalID = candidateID
}
// Skip duplicates
if seenProjects[finalID] {
log.Infof("Project %s already activated, skipping", finalID)
continue
}
seenProjects[finalID] = true
activatedProjects = append(activatedProjects, finalID)
}
@@ -259,8 +269,39 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
finalProjectID := projectID
if responseProjectID != "" {
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
log.Warnf("Gemini onboarding returned project %s instead of requested %s; using response project ID.", responseProjectID, projectID)
finalProjectID = responseProjectID
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
strings.EqualFold(tierID, "FREE") ||
strings.EqualFold(tierID, "LEGACY")
if isFreeUser {
// Interactive prompt for free users
fmt.Printf("\nGoogle returned a different project ID:\n")
fmt.Printf(" Requested (frontend): %s\n", projectID)
fmt.Printf(" Returned (backend): %s\n\n", responseProjectID)
fmt.Printf(" Backend project IDs have access to preview models (gemini-3-*).\n")
fmt.Printf(" This is normal for free tier users.\n\n")
fmt.Printf("Which project ID would you like to use?\n")
fmt.Printf(" [1] Backend (recommended): %s\n", responseProjectID)
fmt.Printf(" [2] Frontend: %s\n\n", projectID)
fmt.Printf("Enter choice [1]: ")
reader := bufio.NewReader(os.Stdin)
choice, _ := reader.ReadString('\n')
choice = strings.TrimSpace(choice)
if choice == "2" {
log.Infof("Using frontend project ID: %s", projectID)
fmt.Println(". Warning: Frontend project IDs may not have access to preview models.")
finalProjectID = projectID
} else {
log.Infof("Using backend project ID: %s (recommended)", responseProjectID)
finalProjectID = responseProjectID
}
} else {
// Pro users: keep requested project ID (original behavior)
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
}
} else {
finalProjectID = responseProjectID
}

View File

@@ -19,6 +19,9 @@ type LoginOptions struct {
// NoBrowser indicates whether to skip opening the browser automatically.
NoBrowser bool
// CallbackPort overrides the local OAuth callback port when set (>0).
CallbackPort int
// Prompt allows the caller to provide interactive input when needed.
Prompt func(prompt string) (string, error)
}
@@ -43,9 +46,10 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: promptFn,
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)

View File

@@ -36,9 +36,10 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
}
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: promptFn,
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)

View File

@@ -6,12 +6,14 @@ package config
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"gopkg.in/yaml.v3"
)
@@ -69,6 +71,11 @@ type Config struct {
// WebsocketAuth enables or disables authentication for the WebSocket API.
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
// CodexInstructionsEnabled controls whether official Codex instructions are injected.
// When false (default), CodexInstructionsForModel returns immediately without modification.
// When true, the original instruction injection logic is used.
CodexInstructionsEnabled bool `yaml:"codex-instructions-enabled" json:"codex-instructions-enabled"`
// GeminiKey defines Gemini API key configurations with optional routing overrides.
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
@@ -99,13 +106,13 @@ type Config struct {
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
// OAuthModelMappings defines global model name mappings for OAuth/file-backed auth channels.
// These mappings affect both model listing and model routing for supported channels:
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
// These aliases affect both model listing and model routing for supported channels:
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
//
// NOTE: This does not apply to existing per-credential model alias features under:
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
OAuthModelMappings map[string][]ModelNameMapping `yaml:"oauth-model-mappings,omitempty" json:"oauth-model-mappings,omitempty"`
OAuthModelAlias map[string][]OAuthModelAlias `yaml:"oauth-model-alias,omitempty" json:"oauth-model-alias,omitempty"`
// Payload defines default and override rules for provider payload parameters.
Payload PayloadConfig `yaml:"payload" json:"payload"`
@@ -158,11 +165,11 @@ type RoutingConfig struct {
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
}
// ModelNameMapping defines a model ID mapping for a specific channel.
// OAuthModelAlias defines a model ID alias for a specific channel.
// It maps the upstream model name (Name) to the client-visible alias (Alias).
// When Fork is true, the alias is added as an additional model in listings while
// keeping the original model ID available.
type ModelNameMapping struct {
type OAuthModelAlias struct {
Name string `yaml:"name" json:"name"`
Alias string `yaml:"alias" json:"alias"`
Fork bool `yaml:"fork,omitempty" json:"fork,omitempty"`
@@ -229,8 +236,12 @@ type AmpUpstreamAPIKeyEntry struct {
type PayloadConfig struct {
// Default defines rules that only set parameters when they are missing in the payload.
Default []PayloadRule `yaml:"default" json:"default"`
// DefaultRaw defines rules that set raw JSON values only when they are missing.
DefaultRaw []PayloadRule `yaml:"default-raw" json:"default-raw"`
// Override defines rules that always set parameters, overwriting any existing values.
Override []PayloadRule `yaml:"override" json:"override"`
// OverrideRaw defines rules that always set raw JSON values, overwriting any existing values.
OverrideRaw []PayloadRule `yaml:"override-raw" json:"override-raw"`
}
// PayloadRule describes a single rule targeting a list of models with parameter updates.
@@ -238,6 +249,7 @@ type PayloadRule struct {
// Models lists model entries with name pattern and protocol constraint.
Models []PayloadModelRule `yaml:"models" json:"models"`
// Params maps JSON paths (gjson/sjson syntax) to values written into the payload.
// For *-raw rules, values are treated as raw JSON fragments (strings are used as-is).
Params map[string]any `yaml:"params" json:"params"`
}
@@ -249,12 +261,35 @@ type PayloadModelRule struct {
Protocol string `yaml:"protocol" json:"protocol"`
}
// CloakConfig configures request cloaking for non-Claude-Code clients.
// Cloaking disguises API requests to appear as originating from the official Claude Code CLI.
type CloakConfig struct {
// Mode controls cloaking behavior: "auto" (default), "always", or "never".
// - "auto": cloak only when client is not Claude Code (based on User-Agent)
// - "always": always apply cloaking regardless of client
// - "never": never apply cloaking
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
// StrictMode controls how system prompts are handled when cloaking.
// - false (default): prepend Claude Code prompt to user system messages
// - true: strip all user system messages, keep only Claude Code prompt
StrictMode bool `yaml:"strict-mode,omitempty" json:"strict-mode,omitempty"`
// SensitiveWords is a list of words to obfuscate with zero-width characters.
// This can help bypass certain content filters.
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
}
// ClaudeKey represents the configuration for a Claude API key,
// including the API key itself and an optional base URL for the API endpoint.
type ClaudeKey struct {
// APIKey is the authentication key for accessing Claude API services.
APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -273,8 +308,14 @@ type ClaudeKey struct {
// ExcludedModels lists model IDs that should be excluded for this provider.
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
// Cloak configures request cloaking for non-Claude-Code clients.
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
}
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
func (k ClaudeKey) GetBaseURL() string { return k.BaseURL }
// ClaudeModel describes a mapping between an alias and the actual upstream model name.
type ClaudeModel struct {
// Name is the upstream model identifier used when issuing requests.
@@ -293,6 +334,10 @@ type CodexKey struct {
// APIKey is the authentication key for accessing Codex API services.
APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -313,6 +358,9 @@ type CodexKey struct {
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}
func (k CodexKey) GetAPIKey() string { return k.APIKey }
func (k CodexKey) GetBaseURL() string { return k.BaseURL }
// CodexModel describes a mapping between an alias and the actual upstream model name.
type CodexModel struct {
// Name is the upstream model identifier used when issuing requests.
@@ -331,6 +379,10 @@ type GeminiKey struct {
// APIKey is the authentication key for accessing Gemini API services.
APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -350,6 +402,9 @@ type GeminiKey struct {
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}
func (k GeminiKey) GetAPIKey() string { return k.APIKey }
func (k GeminiKey) GetBaseURL() string { return k.BaseURL }
// GeminiModel describes a mapping between an alias and the actual upstream model name.
type GeminiModel struct {
// Name is the upstream model identifier used when issuing requests.
@@ -397,6 +452,10 @@ type OpenAICompatibility struct {
// Name is the identifier for this OpenAI compatibility configuration.
Name string `yaml:"name" json:"name"`
// Priority controls selection preference when multiple providers or credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -432,6 +491,9 @@ type OpenAICompatibilityModel struct {
Alias string `yaml:"alias" json:"alias"`
}
func (m OpenAICompatibilityModel) GetName() string { return m.Name }
func (m OpenAICompatibilityModel) GetAlias() string { return m.Alias }
// LoadConfig reads a YAML configuration file from the given path,
// unmarshals it into a Config struct, applies environment variable overrides,
// and returns it.
@@ -450,6 +512,15 @@ func LoadConfig(configFile string) (*Config, error) {
// If optional is true and the file is missing, it returns an empty Config.
// If optional is true and the file is empty or invalid, it returns an empty Config.
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Perform oauth-model-alias migration before loading config.
// This migrates oauth-model-mappings to oauth-model-alias if needed.
if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
// Log warning but don't fail - config loading should still work
fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
} else if migrated {
fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
}
// Read the entire configuration file into memory.
data, err := os.ReadFile(configFile)
if err != nil {
@@ -546,8 +617,11 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Normalize OAuth provider model exclusion map.
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
// Normalize global OAuth model name mappings.
cfg.SanitizeOAuthModelMappings()
// Normalize global OAuth model name aliases.
cfg.SanitizeOAuthModelAlias()
// Validate raw payload rules and drop invalid entries.
cfg.SanitizePayloadRules()
if cfg.legacyMigrationPending {
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
@@ -565,24 +639,79 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
return &cfg, nil
}
// SanitizeOAuthModelMappings normalizes and deduplicates global OAuth model name mappings.
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
func (cfg *Config) SanitizeOAuthModelMappings() {
if cfg == nil || len(cfg.OAuthModelMappings) == 0 {
// SanitizePayloadRules validates raw JSON payload rule params and drops invalid rules.
func (cfg *Config) SanitizePayloadRules() {
if cfg == nil {
return
}
out := make(map[string][]ModelNameMapping, len(cfg.OAuthModelMappings))
for rawChannel, mappings := range cfg.OAuthModelMappings {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(mappings) == 0 {
cfg.Payload.DefaultRaw = sanitizePayloadRawRules(cfg.Payload.DefaultRaw, "default-raw")
cfg.Payload.OverrideRaw = sanitizePayloadRawRules(cfg.Payload.OverrideRaw, "override-raw")
}
func sanitizePayloadRawRules(rules []PayloadRule, section string) []PayloadRule {
if len(rules) == 0 {
return rules
}
out := make([]PayloadRule, 0, len(rules))
for i := range rules {
rule := rules[i]
if len(rule.Params) == 0 {
continue
}
seenAlias := make(map[string]struct{}, len(mappings))
clean := make([]ModelNameMapping, 0, len(mappings))
for _, mapping := range mappings {
name := strings.TrimSpace(mapping.Name)
alias := strings.TrimSpace(mapping.Alias)
invalid := false
for path, value := range rule.Params {
raw, ok := payloadRawString(value)
if !ok {
continue
}
trimmed := bytes.TrimSpace(raw)
if len(trimmed) == 0 || !json.Valid(trimmed) {
log.WithFields(log.Fields{
"section": section,
"rule_index": i + 1,
"param": path,
}).Warn("payload rule dropped: invalid raw JSON")
invalid = true
break
}
}
if invalid {
continue
}
out = append(out, rule)
}
return out
}
func payloadRawString(value any) ([]byte, bool) {
switch typed := value.(type) {
case string:
return []byte(typed), true
case []byte:
return typed, true
default:
return nil, false
}
}
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
func (cfg *Config) SanitizeOAuthModelAlias() {
if cfg == nil || len(cfg.OAuthModelAlias) == 0 {
return
}
out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias))
for rawChannel, aliases := range cfg.OAuthModelAlias {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(aliases) == 0 {
continue
}
seenAlias := make(map[string]struct{}, len(aliases))
clean := make([]OAuthModelAlias, 0, len(aliases))
for _, entry := range aliases {
name := strings.TrimSpace(entry.Name)
alias := strings.TrimSpace(entry.Alias)
if name == "" || alias == "" {
continue
}
@@ -594,13 +723,13 @@ func (cfg *Config) SanitizeOAuthModelMappings() {
continue
}
seenAlias[aliasKey] = struct{}{}
clean = append(clean, ModelNameMapping{Name: name, Alias: alias, Fork: mapping.Fork})
clean = append(clean, OAuthModelAlias{Name: name, Alias: alias, Fork: entry.Fork})
}
if len(clean) > 0 {
out[channel] = clean
}
}
cfg.OAuthModelMappings = out
cfg.OAuthModelAlias = out
}
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are

View File

@@ -0,0 +1,275 @@
package config
import (
"os"
"strings"
"gopkg.in/yaml.v3"
)
// antigravityModelConversionTable maps old built-in aliases to actual model names
// for the antigravity channel during migration.
var antigravityModelConversionTable = map[string]string{
"gemini-2.5-computer-use-preview-10-2025": "rev19-uic3-1p",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-claude-sonnet-4-5": "claude-sonnet-4-5",
"gemini-claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
"gemini-claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
}
// defaultAntigravityAliases returns the default oauth-model-alias configuration
// for the antigravity channel when neither field exists.
func defaultAntigravityAliases() []OAuthModelAlias {
return []OAuthModelAlias{
{Name: "rev19-uic3-1p", Alias: "gemini-2.5-computer-use-preview-10-2025"},
{Name: "gemini-3-pro-image", Alias: "gemini-3-pro-image-preview"},
{Name: "gemini-3-pro-high", Alias: "gemini-3-pro-preview"},
{Name: "gemini-3-flash", Alias: "gemini-3-flash-preview"},
{Name: "claude-sonnet-4-5", Alias: "gemini-claude-sonnet-4-5"},
{Name: "claude-sonnet-4-5-thinking", Alias: "gemini-claude-sonnet-4-5-thinking"},
{Name: "claude-opus-4-5-thinking", Alias: "gemini-claude-opus-4-5-thinking"},
}
}
// MigrateOAuthModelAlias checks for and performs migration from oauth-model-mappings
// to oauth-model-alias at startup. Returns true if migration was performed.
//
// Migration flow:
// 1. Check if oauth-model-alias exists -> skip migration
// 2. Check if oauth-model-mappings exists -> convert and migrate
// - For antigravity channel, convert old built-in aliases to actual model names
//
// 3. Neither exists -> add default antigravity config
func MigrateOAuthModelAlias(configFile string) (bool, error) {
data, err := os.ReadFile(configFile)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
if len(data) == 0 {
return false, nil
}
// Parse YAML into node tree to preserve structure
var root yaml.Node
if err := yaml.Unmarshal(data, &root); err != nil {
return false, nil
}
if root.Kind != yaml.DocumentNode || len(root.Content) == 0 {
return false, nil
}
rootMap := root.Content[0]
if rootMap == nil || rootMap.Kind != yaml.MappingNode {
return false, nil
}
// Check if oauth-model-alias already exists
if findMapKeyIndex(rootMap, "oauth-model-alias") >= 0 {
return false, nil
}
// Check if oauth-model-mappings exists
oldIdx := findMapKeyIndex(rootMap, "oauth-model-mappings")
if oldIdx >= 0 {
// Migrate from old field
return migrateFromOldField(configFile, &root, rootMap, oldIdx)
}
// Neither field exists - add default antigravity config
return addDefaultAntigravityConfig(configFile, &root, rootMap)
}
// migrateFromOldField converts oauth-model-mappings to oauth-model-alias
func migrateFromOldField(configFile string, root *yaml.Node, rootMap *yaml.Node, oldIdx int) (bool, error) {
if oldIdx+1 >= len(rootMap.Content) {
return false, nil
}
oldValue := rootMap.Content[oldIdx+1]
if oldValue == nil || oldValue.Kind != yaml.MappingNode {
return false, nil
}
// Parse the old aliases
oldAliases := parseOldAliasNode(oldValue)
if len(oldAliases) == 0 {
// Remove the old field and write
removeMapKeyByIndex(rootMap, oldIdx)
return writeYAMLNode(configFile, root)
}
// Convert model names for antigravity channel
newAliases := make(map[string][]OAuthModelAlias, len(oldAliases))
for channel, entries := range oldAliases {
converted := make([]OAuthModelAlias, 0, len(entries))
for _, entry := range entries {
newEntry := OAuthModelAlias{
Name: entry.Name,
Alias: entry.Alias,
Fork: entry.Fork,
}
// Convert model names for antigravity channel
if strings.EqualFold(channel, "antigravity") {
if actual, ok := antigravityModelConversionTable[entry.Name]; ok {
newEntry.Name = actual
}
}
converted = append(converted, newEntry)
}
newAliases[channel] = converted
}
// For antigravity channel, supplement missing default aliases
if antigravityEntries, exists := newAliases["antigravity"]; exists {
// Build a set of already configured model names (upstream names)
configuredModels := make(map[string]bool, len(antigravityEntries))
for _, entry := range antigravityEntries {
configuredModels[entry.Name] = true
}
// Add missing default aliases
for _, defaultAlias := range defaultAntigravityAliases() {
if !configuredModels[defaultAlias.Name] {
antigravityEntries = append(antigravityEntries, defaultAlias)
}
}
newAliases["antigravity"] = antigravityEntries
}
// Build new node
newNode := buildOAuthModelAliasNode(newAliases)
// Replace old key with new key and value
rootMap.Content[oldIdx].Value = "oauth-model-alias"
rootMap.Content[oldIdx+1] = newNode
return writeYAMLNode(configFile, root)
}
// addDefaultAntigravityConfig adds the default antigravity configuration
func addDefaultAntigravityConfig(configFile string, root *yaml.Node, rootMap *yaml.Node) (bool, error) {
defaults := map[string][]OAuthModelAlias{
"antigravity": defaultAntigravityAliases(),
}
newNode := buildOAuthModelAliasNode(defaults)
// Add new key-value pair
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "oauth-model-alias"}
rootMap.Content = append(rootMap.Content, keyNode, newNode)
return writeYAMLNode(configFile, root)
}
// parseOldAliasNode parses the old oauth-model-mappings node structure
func parseOldAliasNode(node *yaml.Node) map[string][]OAuthModelAlias {
if node == nil || node.Kind != yaml.MappingNode {
return nil
}
result := make(map[string][]OAuthModelAlias)
for i := 0; i+1 < len(node.Content); i += 2 {
channelNode := node.Content[i]
entriesNode := node.Content[i+1]
if channelNode == nil || entriesNode == nil {
continue
}
channel := strings.ToLower(strings.TrimSpace(channelNode.Value))
if channel == "" || entriesNode.Kind != yaml.SequenceNode {
continue
}
entries := make([]OAuthModelAlias, 0, len(entriesNode.Content))
for _, entryNode := range entriesNode.Content {
if entryNode == nil || entryNode.Kind != yaml.MappingNode {
continue
}
entry := parseAliasEntry(entryNode)
if entry.Name != "" && entry.Alias != "" {
entries = append(entries, entry)
}
}
if len(entries) > 0 {
result[channel] = entries
}
}
return result
}
// parseAliasEntry parses a single alias entry node
func parseAliasEntry(node *yaml.Node) OAuthModelAlias {
var entry OAuthModelAlias
for i := 0; i+1 < len(node.Content); i += 2 {
keyNode := node.Content[i]
valNode := node.Content[i+1]
if keyNode == nil || valNode == nil {
continue
}
switch strings.ToLower(strings.TrimSpace(keyNode.Value)) {
case "name":
entry.Name = strings.TrimSpace(valNode.Value)
case "alias":
entry.Alias = strings.TrimSpace(valNode.Value)
case "fork":
entry.Fork = strings.ToLower(strings.TrimSpace(valNode.Value)) == "true"
}
}
return entry
}
// buildOAuthModelAliasNode creates a YAML node for oauth-model-alias
func buildOAuthModelAliasNode(aliases map[string][]OAuthModelAlias) *yaml.Node {
node := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
for channel, entries := range aliases {
channelNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: channel}
entriesNode := &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"}
for _, entry := range entries {
entryNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
entryNode.Content = append(entryNode.Content,
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "name"},
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Name},
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "alias"},
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Alias},
)
if entry.Fork {
entryNode.Content = append(entryNode.Content,
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "fork"},
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: "true"},
)
}
entriesNode.Content = append(entriesNode.Content, entryNode)
}
node.Content = append(node.Content, channelNode, entriesNode)
}
return node
}
// removeMapKeyByIndex removes a key-value pair from a mapping node by index
func removeMapKeyByIndex(mapNode *yaml.Node, keyIdx int) {
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
return
}
if keyIdx < 0 || keyIdx+1 >= len(mapNode.Content) {
return
}
mapNode.Content = append(mapNode.Content[:keyIdx], mapNode.Content[keyIdx+2:]...)
}
// writeYAMLNode writes the YAML node tree back to file
func writeYAMLNode(configFile string, root *yaml.Node) (bool, error) {
f, err := os.Create(configFile)
if err != nil {
return false, err
}
defer f.Close()
enc := yaml.NewEncoder(f)
enc.SetIndent(2)
if err := enc.Encode(root); err != nil {
return false, err
}
if err := enc.Close(); err != nil {
return false, err
}
return true, nil
}

View File

@@ -0,0 +1,242 @@
package config
import (
"os"
"path/filepath"
"strings"
"testing"
"gopkg.in/yaml.v3"
)
func TestMigrateOAuthModelAlias_SkipsIfNewFieldExists(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
content := `oauth-model-alias:
gemini-cli:
- name: "gemini-2.5-pro"
alias: "g2.5p"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if migrated {
t.Fatal("expected no migration when oauth-model-alias already exists")
}
// Verify file unchanged
data, _ := os.ReadFile(configFile)
if !strings.Contains(string(data), "oauth-model-alias:") {
t.Fatal("file should still contain oauth-model-alias")
}
}
func TestMigrateOAuthModelAlias_MigratesOldField(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
content := `oauth-model-mappings:
gemini-cli:
- name: "gemini-2.5-pro"
alias: "g2.5p"
fork: true
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !migrated {
t.Fatal("expected migration to occur")
}
// Verify new field exists and old field removed
data, _ := os.ReadFile(configFile)
if strings.Contains(string(data), "oauth-model-mappings:") {
t.Fatal("old field should be removed")
}
if !strings.Contains(string(data), "oauth-model-alias:") {
t.Fatal("new field should exist")
}
// Parse and verify structure
var root yaml.Node
if err := yaml.Unmarshal(data, &root); err != nil {
t.Fatal(err)
}
}
func TestMigrateOAuthModelAlias_ConvertsAntigravityModels(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
// Use old model names that should be converted
content := `oauth-model-mappings:
antigravity:
- name: "gemini-2.5-computer-use-preview-10-2025"
alias: "computer-use"
- name: "gemini-3-pro-preview"
alias: "g3p"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !migrated {
t.Fatal("expected migration to occur")
}
// Verify model names were converted
data, _ := os.ReadFile(configFile)
content = string(data)
if !strings.Contains(content, "rev19-uic3-1p") {
t.Fatal("expected gemini-2.5-computer-use-preview-10-2025 to be converted to rev19-uic3-1p")
}
if !strings.Contains(content, "gemini-3-pro-high") {
t.Fatal("expected gemini-3-pro-preview to be converted to gemini-3-pro-high")
}
// Verify missing default aliases were supplemented
if !strings.Contains(content, "gemini-3-pro-image") {
t.Fatal("expected missing default alias gemini-3-pro-image to be added")
}
if !strings.Contains(content, "gemini-3-flash") {
t.Fatal("expected missing default alias gemini-3-flash to be added")
}
if !strings.Contains(content, "claude-sonnet-4-5") {
t.Fatal("expected missing default alias claude-sonnet-4-5 to be added")
}
if !strings.Contains(content, "claude-sonnet-4-5-thinking") {
t.Fatal("expected missing default alias claude-sonnet-4-5-thinking to be added")
}
if !strings.Contains(content, "claude-opus-4-5-thinking") {
t.Fatal("expected missing default alias claude-opus-4-5-thinking to be added")
}
}
func TestMigrateOAuthModelAlias_AddsDefaultIfNeitherExists(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
content := `debug: true
port: 8080
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !migrated {
t.Fatal("expected migration to add default config")
}
// Verify default antigravity config was added
data, _ := os.ReadFile(configFile)
content = string(data)
if !strings.Contains(content, "oauth-model-alias:") {
t.Fatal("expected oauth-model-alias to be added")
}
if !strings.Contains(content, "antigravity:") {
t.Fatal("expected antigravity channel to be added")
}
if !strings.Contains(content, "rev19-uic3-1p") {
t.Fatal("expected default antigravity aliases to include rev19-uic3-1p")
}
}
func TestMigrateOAuthModelAlias_PreservesOtherConfig(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
content := `debug: true
port: 8080
oauth-model-mappings:
gemini-cli:
- name: "test"
alias: "t"
api-keys:
- "key1"
- "key2"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !migrated {
t.Fatal("expected migration to occur")
}
// Verify other config preserved
data, _ := os.ReadFile(configFile)
content = string(data)
if !strings.Contains(content, "debug: true") {
t.Fatal("expected debug field to be preserved")
}
if !strings.Contains(content, "port: 8080") {
t.Fatal("expected port field to be preserved")
}
if !strings.Contains(content, "api-keys:") {
t.Fatal("expected api-keys field to be preserved")
}
}
func TestMigrateOAuthModelAlias_NonexistentFile(t *testing.T) {
t.Parallel()
migrated, err := MigrateOAuthModelAlias("/nonexistent/path/config.yaml")
if err != nil {
t.Fatalf("unexpected error for nonexistent file: %v", err)
}
if migrated {
t.Fatal("expected no migration for nonexistent file")
}
}
func TestMigrateOAuthModelAlias_EmptyFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(configFile, []byte(""), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if migrated {
t.Fatal("expected no migration for empty file")
}
}

View File

@@ -0,0 +1,56 @@
package config
import "testing"
func TestSanitizeOAuthModelAlias_PreservesForkFlag(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
" CoDeX ": {
{Name: " gpt-5 ", Alias: " g5 ", Fork: true},
{Name: "gpt-6", Alias: "g6"},
},
},
}
cfg.SanitizeOAuthModelAlias()
aliases := cfg.OAuthModelAlias["codex"]
if len(aliases) != 2 {
t.Fatalf("expected 2 sanitized aliases, got %d", len(aliases))
}
if aliases[0].Name != "gpt-5" || aliases[0].Alias != "g5" || !aliases[0].Fork {
t.Fatalf("expected first alias to be gpt-5->g5 fork=true, got name=%q alias=%q fork=%v", aliases[0].Name, aliases[0].Alias, aliases[0].Fork)
}
if aliases[1].Name != "gpt-6" || aliases[1].Alias != "g6" || aliases[1].Fork {
t.Fatalf("expected second alias to be gpt-6->g6 fork=false, got name=%q alias=%q fork=%v", aliases[1].Name, aliases[1].Alias, aliases[1].Fork)
}
}
func TestSanitizeOAuthModelAlias_AllowsMultipleAliasesForSameName(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"antigravity": {
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
},
},
}
cfg.SanitizeOAuthModelAlias()
aliases := cfg.OAuthModelAlias["antigravity"]
expected := []OAuthModelAlias{
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
}
if len(aliases) != len(expected) {
t.Fatalf("expected %d sanitized aliases, got %d", len(expected), len(aliases))
}
for i, exp := range expected {
if aliases[i].Name != exp.Name || aliases[i].Alias != exp.Alias || aliases[i].Fork != exp.Fork {
t.Fatalf("expected alias %d to be name=%q alias=%q fork=%v, got name=%q alias=%q fork=%v", i, exp.Name, exp.Alias, exp.Fork, aliases[i].Name, aliases[i].Alias, aliases[i].Fork)
}
}
}

View File

@@ -1,56 +0,0 @@
package config
import "testing"
func TestSanitizeOAuthModelMappings_PreservesForkFlag(t *testing.T) {
cfg := &Config{
OAuthModelMappings: map[string][]ModelNameMapping{
" CoDeX ": {
{Name: " gpt-5 ", Alias: " g5 ", Fork: true},
{Name: "gpt-6", Alias: "g6"},
},
},
}
cfg.SanitizeOAuthModelMappings()
mappings := cfg.OAuthModelMappings["codex"]
if len(mappings) != 2 {
t.Fatalf("expected 2 sanitized mappings, got %d", len(mappings))
}
if mappings[0].Name != "gpt-5" || mappings[0].Alias != "g5" || !mappings[0].Fork {
t.Fatalf("expected first mapping to be gpt-5->g5 fork=true, got name=%q alias=%q fork=%v", mappings[0].Name, mappings[0].Alias, mappings[0].Fork)
}
if mappings[1].Name != "gpt-6" || mappings[1].Alias != "g6" || mappings[1].Fork {
t.Fatalf("expected second mapping to be gpt-6->g6 fork=false, got name=%q alias=%q fork=%v", mappings[1].Name, mappings[1].Alias, mappings[1].Fork)
}
}
func TestSanitizeOAuthModelMappings_AllowsMultipleAliasesForSameName(t *testing.T) {
cfg := &Config{
OAuthModelMappings: map[string][]ModelNameMapping{
"antigravity": {
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
},
},
}
cfg.SanitizeOAuthModelMappings()
mappings := cfg.OAuthModelMappings["antigravity"]
expected := []ModelNameMapping{
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
}
if len(mappings) != len(expected) {
t.Fatalf("expected %d sanitized mappings, got %d", len(expected), len(mappings))
}
for i, exp := range expected {
if mappings[i].Name != exp.Name || mappings[i].Alias != exp.Alias || mappings[i].Fork != exp.Fork {
t.Fatalf("expected mapping %d to be name=%q alias=%q fork=%v, got name=%q alias=%q fork=%v", i, exp.Name, exp.Alias, exp.Fork, mappings[i].Name, mappings[i].Alias, mappings[i].Fork)
}
}
}

View File

@@ -13,6 +13,10 @@ type VertexCompatKey struct {
// Maps to the x-goog-api-key header.
APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -32,6 +36,9 @@ type VertexCompatKey struct {
Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"`
}
func (k VertexCompatKey) GetAPIKey() string { return k.APIKey }
func (k VertexCompatKey) GetBaseURL() string { return k.BaseURL }
// VertexCompatModel represents a model configuration for Vertex compatibility,
// including the actual model name and its alias for API routing.
type VertexCompatModel struct {

View File

@@ -4,6 +4,7 @@
package logging
import (
"errors"
"fmt"
"net/http"
"runtime/debug"
@@ -112,6 +113,11 @@ func isAIAPIPath(path string) bool {
// - gin.HandlerFunc: A middleware handler for panic recovery
func GinLogrusRecovery() gin.HandlerFunc {
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
if err, ok := recovered.(error); ok && errors.Is(err, http.ErrAbortHandler) {
// Let net/http handle ErrAbortHandler so the connection is aborted without noisy stack logs.
panic(http.ErrAbortHandler)
}
log.WithFields(log.Fields{
"panic": recovered,
"stack": string(debug.Stack()),

View File

@@ -0,0 +1,60 @@
package logging
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestGinLogrusRecoveryRepanicsErrAbortHandler(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
engine.Use(GinLogrusRecovery())
engine.GET("/abort", func(c *gin.Context) {
panic(http.ErrAbortHandler)
})
req := httptest.NewRequest(http.MethodGet, "/abort", nil)
recorder := httptest.NewRecorder()
defer func() {
recovered := recover()
if recovered == nil {
t.Fatalf("expected panic, got nil")
}
err, ok := recovered.(error)
if !ok {
t.Fatalf("expected error panic, got %T", recovered)
}
if !errors.Is(err, http.ErrAbortHandler) {
t.Fatalf("expected ErrAbortHandler, got %v", err)
}
if err != http.ErrAbortHandler {
t.Fatalf("expected exact ErrAbortHandler sentinel, got %v", err)
}
}()
engine.ServeHTTP(recorder, req)
}
func TestGinLogrusRecoveryHandlesRegularPanic(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
engine.Use(GinLogrusRecovery())
engine.GET("/panic", func(c *gin.Context) {
panic("boom")
})
req := httptest.NewRequest(http.MethodGet, "/panic", nil)
recorder := httptest.NewRecorder()
engine.ServeHTTP(recorder, req)
if recorder.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", recorder.Code)
}
}

View File

@@ -29,6 +29,9 @@ var (
// Format: [2025-12-23 20:14:04] [debug] [manager.go:524] | a1b2c3d4 | Use API key sk-9...0RHO for model gpt-5.2
type LogFormatter struct{}
// logFieldOrder defines the display order for common log fields.
var logFieldOrder = []string{"provider", "model", "mode", "budget", "level", "original_mode", "original_value", "min", "max", "clamped_to", "error"}
// Format renders a single log entry with custom formatting.
func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
var buffer *bytes.Buffer
@@ -52,11 +55,25 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
}
levelStr := fmt.Sprintf("%-5s", level)
// Build fields string (only print fields in logFieldOrder)
var fieldsStr string
if len(entry.Data) > 0 {
var fields []string
for _, k := range logFieldOrder {
if v, ok := entry.Data[k]; ok {
fields = append(fields, fmt.Sprintf("%s=%v", k, v))
}
}
if len(fields) > 0 {
fieldsStr = " " + strings.Join(fields, " ")
}
}
var formatted string
if entry.Caller != nil {
formatted = fmt.Sprintf("[%s] [%s] [%s] [%s:%d] %s\n", timestamp, reqID, levelStr, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
formatted = fmt.Sprintf("[%s] [%s] [%s] [%s:%d] %s%s\n", timestamp, reqID, levelStr, filepath.Base(entry.Caller.File), entry.Caller.Line, message, fieldsStr)
} else {
formatted = fmt.Sprintf("[%s] [%s] [%s] %s\n", timestamp, reqID, levelStr, message)
formatted = fmt.Sprintf("[%s] [%s] [%s] %s%s\n", timestamp, reqID, levelStr, message, fieldsStr)
}
buffer.WriteString(formatted)
@@ -105,6 +122,24 @@ func isDirWritable(dir string) bool {
return true
}
// ResolveLogDirectory determines the directory used for application logs.
func ResolveLogDirectory(cfg *config.Config) string {
logDir := "logs"
if base := util.WritablePath(); base != "" {
return filepath.Join(base, "logs")
}
if cfg == nil {
return logDir
}
if !isDirWritable(logDir) {
authDir := strings.TrimSpace(cfg.AuthDir)
if authDir != "" {
logDir = filepath.Join(authDir, "logs")
}
}
return logDir
}
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
// until the total size is within the limit.
@@ -114,12 +149,7 @@ func ConfigureLogOutput(cfg *config.Config) error {
writerMu.Lock()
defer writerMu.Unlock()
logDir := "logs"
if base := util.WritablePath(); base != "" {
logDir = filepath.Join(base, "logs")
} else if !isDirWritable(logDir) {
logDir = filepath.Join(cfg.AuthDir, "logs")
}
logDir := ResolveLogDirectory(cfg)
protectedPath := ""
if cfg.LoggingToFile {

View File

@@ -7,11 +7,27 @@ import (
"embed"
_ "embed"
"strings"
"sync/atomic"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// codexInstructionsEnabled controls whether CodexInstructionsForModel returns official instructions.
// When false (default), CodexInstructionsForModel returns (true, "") immediately.
// Set via SetCodexInstructionsEnabled from config.
var codexInstructionsEnabled atomic.Bool
// SetCodexInstructionsEnabled sets whether codex instructions processing is enabled.
func SetCodexInstructionsEnabled(enabled bool) {
codexInstructionsEnabled.Store(enabled)
}
// GetCodexInstructionsEnabled returns whether codex instructions processing is enabled.
func GetCodexInstructionsEnabled() bool {
return codexInstructionsEnabled.Load()
}
//go:embed codex_instructions
var codexInstructionsDir embed.FS
@@ -124,6 +140,9 @@ func codexInstructionsForCodex(modelName, systemInstructions string) (bool, stri
}
func CodexInstructionsForModel(modelName, systemInstructions, userAgent string) (bool, string) {
if !GetCodexInstructionsEnabled() {
return true, ""
}
if IsOpenCodeUserAgent(userAgent) {
return codexInstructionsForOpenCode(systemInstructions)
}

View File

@@ -0,0 +1,303 @@
// Package registry provides Kiro model conversion utilities.
// This file handles converting dynamic Kiro API model lists to the internal ModelInfo format,
// and merging with static metadata for thinking support and other capabilities.
package registry
import (
"strings"
"time"
)
// KiroAPIModel represents a model from Kiro API response.
// This is a local copy to avoid import cycles with the kiro package.
// The structure mirrors kiro.KiroModel for easy data conversion.
type KiroAPIModel struct {
// ModelID is the unique identifier for the model (e.g., "claude-sonnet-4.5")
ModelID string
// ModelName is the human-readable name
ModelName string
// Description is the model description
Description string
// RateMultiplier is the credit multiplier for this model
RateMultiplier float64
// RateUnit is the unit for rate calculation (e.g., "credit")
RateUnit string
// MaxInputTokens is the maximum input token limit
MaxInputTokens int
}
// DefaultKiroThinkingSupport defines the default thinking configuration for Kiro models.
// All Kiro models support thinking with the following budget range.
var DefaultKiroThinkingSupport = &ThinkingSupport{
Min: 1024, // Minimum thinking budget tokens
Max: 32000, // Maximum thinking budget tokens
ZeroAllowed: true, // Allow disabling thinking with 0
DynamicAllowed: true, // Allow dynamic thinking budget (-1)
}
// DefaultKiroContextLength is the default context window size for Kiro models.
const DefaultKiroContextLength = 200000
// DefaultKiroMaxCompletionTokens is the default max completion tokens for Kiro models.
const DefaultKiroMaxCompletionTokens = 64000
// ConvertKiroAPIModels converts Kiro API models to internal ModelInfo format.
// It performs the following transformations:
// - Normalizes model ID (e.g., claude-sonnet-4.5 → kiro-claude-sonnet-4-5)
// - Adds default thinking support metadata
// - Sets default context length and max completion tokens if not provided
//
// Parameters:
// - kiroModels: List of models from Kiro API response
//
// Returns:
// - []*ModelInfo: Converted model information list
func ConvertKiroAPIModels(kiroModels []*KiroAPIModel) []*ModelInfo {
if len(kiroModels) == 0 {
return nil
}
now := time.Now().Unix()
result := make([]*ModelInfo, 0, len(kiroModels))
for _, km := range kiroModels {
// Skip nil models
if km == nil {
continue
}
// Skip models without valid ID
if km.ModelID == "" {
continue
}
// Normalize the model ID to kiro-* format
normalizedID := normalizeKiroModelID(km.ModelID)
// Create ModelInfo with converted data
info := &ModelInfo{
ID: normalizedID,
Object: "model",
Created: now,
OwnedBy: "aws",
Type: "kiro",
DisplayName: generateKiroDisplayName(km.ModelName, normalizedID),
Description: km.Description,
// Use MaxInputTokens from API if available, otherwise use default
ContextLength: getContextLength(km.MaxInputTokens),
MaxCompletionTokens: DefaultKiroMaxCompletionTokens,
// All Kiro models support thinking
Thinking: cloneThinkingSupport(DefaultKiroThinkingSupport),
}
result = append(result, info)
}
return result
}
// GenerateAgenticVariants creates -agentic variants for each model.
// Agentic variants are optimized for coding agents with chunked writes.
//
// Parameters:
// - models: Base models to generate variants for
//
// Returns:
// - []*ModelInfo: Combined list of base models and their agentic variants
func GenerateAgenticVariants(models []*ModelInfo) []*ModelInfo {
if len(models) == 0 {
return nil
}
// Pre-allocate result with capacity for both base models and variants
result := make([]*ModelInfo, 0, len(models)*2)
for _, model := range models {
if model == nil {
continue
}
// Add the base model first
result = append(result, model)
// Skip if model already has -agentic suffix
if strings.HasSuffix(model.ID, "-agentic") {
continue
}
// Skip special models that shouldn't have agentic variants
if model.ID == "kiro-auto" {
continue
}
// Create agentic variant
agenticModel := &ModelInfo{
ID: model.ID + "-agentic",
Object: model.Object,
Created: model.Created,
OwnedBy: model.OwnedBy,
Type: model.Type,
DisplayName: model.DisplayName + " (Agentic)",
Description: generateAgenticDescription(model.Description),
ContextLength: model.ContextLength,
MaxCompletionTokens: model.MaxCompletionTokens,
Thinking: cloneThinkingSupport(model.Thinking),
}
result = append(result, agenticModel)
}
return result
}
// MergeWithStaticMetadata merges dynamic models with static metadata.
// Static metadata takes priority for any overlapping fields.
// This allows manual overrides for specific models while keeping dynamic discovery.
//
// Parameters:
// - dynamicModels: Models from Kiro API (converted to ModelInfo)
// - staticModels: Predefined model metadata (from GetKiroModels())
//
// Returns:
// - []*ModelInfo: Merged model list with static metadata taking priority
func MergeWithStaticMetadata(dynamicModels, staticModels []*ModelInfo) []*ModelInfo {
if len(dynamicModels) == 0 && len(staticModels) == 0 {
return nil
}
// Build a map of static models for quick lookup
staticMap := make(map[string]*ModelInfo, len(staticModels))
for _, sm := range staticModels {
if sm != nil && sm.ID != "" {
staticMap[sm.ID] = sm
}
}
// Build result, preferring static metadata where available
seenIDs := make(map[string]struct{})
result := make([]*ModelInfo, 0, len(dynamicModels)+len(staticModels))
// First, process dynamic models and merge with static if available
for _, dm := range dynamicModels {
if dm == nil || dm.ID == "" {
continue
}
// Skip duplicates
if _, seen := seenIDs[dm.ID]; seen {
continue
}
seenIDs[dm.ID] = struct{}{}
// Check if static metadata exists for this model
if sm, exists := staticMap[dm.ID]; exists {
// Static metadata takes priority - use static model
result = append(result, sm)
} else {
// No static metadata - use dynamic model
result = append(result, dm)
}
}
// Add any static models not in dynamic list
for _, sm := range staticModels {
if sm == nil || sm.ID == "" {
continue
}
if _, seen := seenIDs[sm.ID]; seen {
continue
}
seenIDs[sm.ID] = struct{}{}
result = append(result, sm)
}
return result
}
// normalizeKiroModelID converts Kiro API model IDs to internal format.
// Transformation rules:
// - Adds "kiro-" prefix if not present
// - Replaces dots with hyphens (e.g., 4.5 → 4-5)
// - Handles special cases like "auto" → "kiro-auto"
//
// Examples:
// - "claude-sonnet-4.5" → "kiro-claude-sonnet-4-5"
// - "claude-opus-4.5" → "kiro-claude-opus-4-5"
// - "auto" → "kiro-auto"
// - "kiro-claude-sonnet-4-5" → "kiro-claude-sonnet-4-5" (unchanged)
func normalizeKiroModelID(modelID string) string {
if modelID == "" {
return ""
}
// Trim whitespace
modelID = strings.TrimSpace(modelID)
// Replace dots with hyphens (e.g., 4.5 → 4-5)
normalized := strings.ReplaceAll(modelID, ".", "-")
// Add kiro- prefix if not present
if !strings.HasPrefix(normalized, "kiro-") {
normalized = "kiro-" + normalized
}
return normalized
}
// generateKiroDisplayName creates a human-readable display name.
// Uses the API-provided model name if available, otherwise generates from ID.
func generateKiroDisplayName(modelName, normalizedID string) string {
if modelName != "" {
return "Kiro " + modelName
}
// Generate from normalized ID by removing kiro- prefix and formatting
displayID := strings.TrimPrefix(normalizedID, "kiro-")
// Capitalize first letter of each word
words := strings.Split(displayID, "-")
for i, word := range words {
if len(word) > 0 {
words[i] = strings.ToUpper(word[:1]) + word[1:]
}
}
return "Kiro " + strings.Join(words, " ")
}
// generateAgenticDescription creates description for agentic variants.
func generateAgenticDescription(baseDescription string) string {
if baseDescription == "" {
return "Optimized for coding agents with chunked writes"
}
return baseDescription + " (Agentic mode: chunked writes)"
}
// getContextLength returns the context length, using default if not provided.
func getContextLength(maxInputTokens int) int {
if maxInputTokens > 0 {
return maxInputTokens
}
return DefaultKiroContextLength
}
// cloneThinkingSupport creates a deep copy of ThinkingSupport.
// Returns nil if input is nil.
func cloneThinkingSupport(ts *ThinkingSupport) *ThinkingSupport {
if ts == nil {
return nil
}
clone := &ThinkingSupport{
Min: ts.Min,
Max: ts.Max,
ZeroAllowed: ts.ZeroAllowed,
DynamicAllowed: ts.DynamicAllowed,
}
// Deep copy Levels slice if present
if len(ts.Levels) > 0 {
clone.Levels = make([]string, len(ts.Levels))
copy(clone.Levels, ts.Levels)
}
return clone
}

View File

@@ -27,7 +27,7 @@ func GetClaudeModels() []*ModelInfo {
DisplayName: "Claude 4.5 Sonnet",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-opus-4-5-20251101",
@@ -39,7 +39,7 @@ func GetClaudeModels() []*ModelInfo {
Description: "Premium model combining maximum intelligence with practical performance",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-opus-4-1-20250805",
@@ -50,7 +50,7 @@ func GetClaudeModels() []*ModelInfo {
DisplayName: "Claude 4.1 Opus",
ContextLength: 200000,
MaxCompletionTokens: 32000,
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
},
{
ID: "claude-opus-4-20250514",
@@ -61,7 +61,7 @@ func GetClaudeModels() []*ModelInfo {
DisplayName: "Claude 4 Opus",
ContextLength: 200000,
MaxCompletionTokens: 32000,
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
},
{
ID: "claude-sonnet-4-20250514",
@@ -72,7 +72,7 @@ func GetClaudeModels() []*ModelInfo {
DisplayName: "Claude 4 Sonnet",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
},
{
ID: "claude-3-7-sonnet-20250219",
@@ -83,7 +83,7 @@ func GetClaudeModels() []*ModelInfo {
DisplayName: "Claude 3.7 Sonnet",
ContextLength: 128000,
MaxCompletionTokens: 8192,
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
},
{
ID: "claude-3-5-haiku-20241022",
@@ -287,6 +287,67 @@ func GetGeminiVertexModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
// Imagen image generation models - use :predict action
{
ID: "imagen-4.0-generate-001",
Object: "model",
Created: 1750000000,
OwnedBy: "google",
Type: "gemini",
Name: "models/imagen-4.0-generate-001",
Version: "4.0",
DisplayName: "Imagen 4.0 Generate",
Description: "Imagen 4.0 image generation model",
SupportedGenerationMethods: []string{"predict"},
},
{
ID: "imagen-4.0-ultra-generate-001",
Object: "model",
Created: 1750000000,
OwnedBy: "google",
Type: "gemini",
Name: "models/imagen-4.0-ultra-generate-001",
Version: "4.0",
DisplayName: "Imagen 4.0 Ultra Generate",
Description: "Imagen 4.0 Ultra high-quality image generation model",
SupportedGenerationMethods: []string{"predict"},
},
{
ID: "imagen-3.0-generate-002",
Object: "model",
Created: 1740000000,
OwnedBy: "google",
Type: "gemini",
Name: "models/imagen-3.0-generate-002",
Version: "3.0",
DisplayName: "Imagen 3.0 Generate",
Description: "Imagen 3.0 image generation model",
SupportedGenerationMethods: []string{"predict"},
},
{
ID: "imagen-3.0-fast-generate-001",
Object: "model",
Created: 1740000000,
OwnedBy: "google",
Type: "gemini",
Name: "models/imagen-3.0-fast-generate-001",
Version: "3.0",
DisplayName: "Imagen 3.0 Fast Generate",
Description: "Imagen 3.0 fast image generation model",
SupportedGenerationMethods: []string{"predict"},
},
{
ID: "imagen-4.0-fast-generate-001",
Object: "model",
Created: 1750000000,
OwnedBy: "google",
Type: "gemini",
Name: "models/imagen-4.0-fast-generate-001",
Version: "4.0",
DisplayName: "Imagen 4.0 Fast Generate",
Description: "Imagen 4.0 fast image generation model",
SupportedGenerationMethods: []string{"predict"},
},
}
}
@@ -432,7 +493,7 @@ func GetAIStudioModels() []*ModelInfo {
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-3-flash-preview",
@@ -447,7 +508,7 @@ func GetAIStudioModels() []*ModelInfo {
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-pro-latest",
@@ -742,6 +803,7 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
}
models := make([]*ModelInfo, 0, len(entries))
for _, entry := range entries {
@@ -764,21 +826,23 @@ func GetIFlowModels() []*ModelInfo {
type AntigravityModelConfig struct {
Thinking *ThinkingSupport
MaxCompletionTokens int
Name string
}
// GetAntigravityModelConfig returns static configuration for antigravity models.
// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup.
// Keys use upstream model names returned by the Antigravity models endpoint.
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
return map[string]*AntigravityModelConfig{
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
"gemini-2.5-computer-use-preview-10-2025": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-2.5-computer-use-preview-10-2025"},
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-preview"},
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-image-preview"},
"gemini-3-flash-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, Name: "models/gemini-3-flash-preview"},
"gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
"rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}},
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
"gpt-oss-120b-medium": {},
"tab_flash_lite_preview": {},
}
}
@@ -788,6 +852,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
if modelID == "" {
return nil
}
allModels := [][]*ModelInfo{
GetClaudeModels(),
GetGeminiModels(),
@@ -805,6 +870,16 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
}
}
}
// Check Antigravity static config
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
return &ModelInfo{
ID: modelID,
Thinking: cfg.Thinking,
MaxCompletionTokens: cfg.MaxCompletionTokens,
}
}
return nil
}
@@ -834,6 +909,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "OpenAI GPT-5 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
{
ID: "gpt-5-mini",
@@ -845,6 +921,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "OpenAI GPT-5 Mini via GitHub Copilot",
ContextLength: 128000,
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
{
ID: "gpt-5-codex",
@@ -856,6 +933,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "OpenAI GPT-5 Codex via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
},
{
ID: "gpt-5.1",
@@ -867,6 +945,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "OpenAI GPT-5.1 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
{
ID: "gpt-5.1-codex",
@@ -878,6 +957,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "OpenAI GPT-5.1 Codex via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
},
{
ID: "gpt-5.1-codex-mini",
@@ -889,6 +969,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "OpenAI GPT-5.1 Codex Mini via GitHub Copilot",
ContextLength: 128000,
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/responses"},
},
{
ID: "gpt-5.1-codex-max",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.1 Codex Max",
Description: "OpenAI GPT-5.1 Codex Max via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
},
{
ID: "gpt-5.2",
@@ -900,6 +993,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "OpenAI GPT-5.2 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
{
ID: "gpt-5.2-codex",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.2 Codex",
Description: "OpenAI GPT-5.2 Codex via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
},
{
ID: "claude-haiku-4.5",
@@ -911,6 +1017,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-opus-4.1",
@@ -922,6 +1029,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "Anthropic Claude Opus 4.1 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-opus-4.5",
@@ -933,6 +1041,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "Anthropic Claude Opus 4.5 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-sonnet-4",
@@ -944,6 +1053,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "Anthropic Claude Sonnet 4 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-sonnet-4.5",
@@ -955,6 +1065,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "gemini-2.5-pro",
@@ -968,13 +1079,24 @@ func GetGitHubCopilotModels() []*ModelInfo {
MaxCompletionTokens: 65536,
},
{
ID: "gemini-3-pro",
ID: "gemini-3-pro-preview",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Gemini 3 Pro",
Description: "Google Gemini 3 Pro via GitHub Copilot",
DisplayName: "Gemini 3 Pro (Preview)",
Description: "Google Gemini 3 Pro Preview via GitHub Copilot",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
},
{
ID: "gemini-3-flash-preview",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Gemini 3 Flash (Preview)",
Description: "Google Gemini 3 Flash Preview via GitHub Copilot",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
},
@@ -990,15 +1112,16 @@ func GetGitHubCopilotModels() []*ModelInfo {
MaxCompletionTokens: 16384,
},
{
ID: "raptor-mini",
ID: "oswe-vscode-prime",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Raptor Mini",
Description: "Raptor Mini via GitHub Copilot",
DisplayName: "Raptor mini (Preview)",
Description: "Raptor mini via GitHub Copilot",
ContextLength: 128000,
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
}
}
@@ -1007,6 +1130,18 @@ func GetGitHubCopilotModels() []*ModelInfo {
func GetKiroModels() []*ModelInfo {
return []*ModelInfo{
// --- Base Models ---
{
ID: "kiro-auto",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Auto",
Description: "Automatic model selection by Kiro",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-opus-4-5",
Object: "model",

View File

@@ -47,10 +47,17 @@ type ModelInfo struct {
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
// SupportedParameters lists supported parameters
SupportedParameters []string `json:"supported_parameters,omitempty"`
// SupportedEndpoints lists supported API endpoints (e.g., "/chat/completions", "/responses").
SupportedEndpoints []string `json:"supported_endpoints,omitempty"`
// Thinking holds provider-specific reasoning/thinking budget capabilities.
// This is optional and currently used for Gemini thinking budget normalization.
Thinking *ThinkingSupport `json:"thinking,omitempty"`
// UserDefined indicates this model was defined through config file's models[]
// array (e.g., openai-compatibility.*.models[], *-api-key.models[]).
// UserDefined models have thinking configuration passed through without validation.
UserDefined bool `json:"-"`
}
// ThinkingSupport describes a model family's supported internal reasoning budget range.
@@ -73,6 +80,8 @@ type ThinkingSupport struct {
type ModelRegistration struct {
// Info contains the model metadata
Info *ModelInfo
// InfoByProvider maps provider identifiers to specific ModelInfo to support differing capabilities.
InfoByProvider map[string]*ModelInfo
// Count is the number of active clients that can provide this model
Count int
// LastUpdated tracks when this registration was last modified
@@ -127,6 +136,24 @@ func GetGlobalRegistry() *ModelRegistry {
return globalRegistry
}
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
return nil
}
p := ""
if len(provider) > 0 {
p = strings.ToLower(strings.TrimSpace(provider[0]))
}
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
return info
}
return LookupStaticModelInfo(modelID)
}
// SetHook sets an optional hook for observing model registration changes.
func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
if r == nil {
@@ -277,6 +304,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
if count, okProv := reg.Providers[oldProvider]; okProv {
if count <= toRemove {
delete(reg.Providers, oldProvider)
if reg.InfoByProvider != nil {
delete(reg.InfoByProvider, oldProvider)
}
} else {
reg.Providers[oldProvider] = count - toRemove
}
@@ -326,6 +356,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
model := newModels[id]
if reg, ok := r.models[id]; ok {
reg.Info = cloneModelInfo(model)
if provider != "" {
if reg.InfoByProvider == nil {
reg.InfoByProvider = make(map[string]*ModelInfo)
}
reg.InfoByProvider[provider] = cloneModelInfo(model)
}
reg.LastUpdated = now
if reg.QuotaExceededClients != nil {
delete(reg.QuotaExceededClients, clientID)
@@ -389,11 +425,15 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
if existing.SuspendedClients == nil {
existing.SuspendedClients = make(map[string]string)
}
if existing.InfoByProvider == nil {
existing.InfoByProvider = make(map[string]*ModelInfo)
}
if provider != "" {
if existing.Providers == nil {
existing.Providers = make(map[string]int)
}
existing.Providers[provider]++
existing.InfoByProvider[provider] = cloneModelInfo(model)
}
log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count)
return
@@ -401,6 +441,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
registration := &ModelRegistration{
Info: cloneModelInfo(model),
InfoByProvider: make(map[string]*ModelInfo),
Count: 1,
LastUpdated: now,
QuotaExceededClients: make(map[string]*time.Time),
@@ -408,6 +449,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
}
if provider != "" {
registration.Providers = map[string]int{provider: 1}
registration.InfoByProvider[provider] = cloneModelInfo(model)
}
r.models[modelID] = registration
log.Debugf("Registered new model %s from provider %s", modelID, provider)
@@ -433,6 +475,9 @@ func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider stri
if count, ok := registration.Providers[provider]; ok {
if count <= 1 {
delete(registration.Providers, provider)
if registration.InfoByProvider != nil {
delete(registration.InfoByProvider, provider)
}
} else {
registration.Providers[provider] = count - 1
}
@@ -456,6 +501,9 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
if len(model.SupportedParameters) > 0 {
copyModel.SupportedParameters = append([]string(nil), model.SupportedParameters...)
}
if len(model.SupportedEndpoints) > 0 {
copyModel.SupportedEndpoints = append([]string(nil), model.SupportedEndpoints...)
}
return &copyModel
}
@@ -514,6 +562,9 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
if count, ok := registration.Providers[provider]; ok {
if count <= 1 {
delete(registration.Providers, provider)
if registration.InfoByProvider != nil {
delete(registration.InfoByProvider, provider)
}
} else {
registration.Providers[provider] = count - 1
}
@@ -920,12 +971,22 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string {
return result
}
// GetModelInfo returns the registered ModelInfo for the given model ID, if present.
// Returns nil if the model is unknown to the registry.
func (r *ModelRegistry) GetModelInfo(modelID string) *ModelInfo {
// GetModelInfo returns ModelInfo, prioritizing provider-specific definition if available.
func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
r.mutex.RLock()
defer r.mutex.RUnlock()
if reg, ok := r.models[modelID]; ok && reg != nil {
// Try provider specific definition first
if provider != "" && reg.InfoByProvider != nil {
if reg.Providers != nil {
if count, ok := reg.Providers[provider]; ok && count > 0 {
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
return info
}
}
}
}
// Fallback to global info (last registered)
return reg.Info
}
return nil
@@ -968,6 +1029,9 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
if len(model.SupportedParameters) > 0 {
result["supported_parameters"] = model.SupportedParameters
}
if len(model.SupportedEndpoints) > 0 {
result["supported_endpoints"] = model.SupportedEndpoints
}
return result
case "claude", "kiro", "antigravity":

View File

@@ -14,7 +14,7 @@ import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -111,7 +111,8 @@ func (e *AIStudioExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.A
// Execute performs a non-streaming request to the AI Studio API.
func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
translatedReq, body, err := e.translateRequest(req, opts, false)
@@ -119,7 +120,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
return resp, err
}
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
endpoint := e.buildEndpoint(baseModel, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
URL: endpoint,
@@ -166,7 +167,8 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
// ExecuteStream performs a streaming request to the AI Studio API.
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
translatedReq, body, err := e.translateRequest(req, opts, true)
@@ -174,7 +176,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
return nil, err
}
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
endpoint := e.buildEndpoint(baseModel, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
URL: endpoint,
@@ -315,6 +317,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
// CountTokens counts tokens for the given request using the AI Studio API.
func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
_, body, err := e.translateRequest(req, opts, false)
if err != nil {
return cliproxyexecutor.Response{}, err
@@ -324,7 +327,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
body.payload, _ = sjson.DeleteBytes(body.payload, "tools")
body.payload, _ = sjson.DeleteBytes(body.payload, "safetySettings")
endpoint := e.buildEndpoint(req.Model, "countTokens", "")
endpoint := e.buildEndpoint(baseModel, "countTokens", "")
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
URL: endpoint,
@@ -380,22 +383,22 @@ type translatedPayload struct {
}
func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, stream)
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload)
payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload)
payload = util.ConvertThinkingLevelToBudget(payload, req.Model, true)
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload, true)
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiImageAspectRatio(req.Model, payload)
payload = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", payload, originalTranslated)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, translatedPayload{}, err
}
payload = fixGeminiImageAspectRatio(baseModel, payload)
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated)
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")

View File

@@ -24,7 +24,9 @@ import (
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -107,8 +109,10 @@ func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyaut
// Execute performs a non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
if isClaude || strings.Contains(req.Model, "gemini-3-pro") {
baseModel := thinking.ParseSuffix(req.Model).ModelName
isClaude := strings.Contains(strings.ToLower(baseModel), "claude")
if isClaude || strings.Contains(baseModel, "gemini-3-pro") {
return e.executeClaudeNonStream(ctx, auth, req, opts)
}
@@ -120,23 +124,25 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
auth = updatedAuth
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -146,7 +152,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
var lastErr error
for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt, baseURL)
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, false, opts.Alt, baseURL)
if errReq != nil {
err = errReq
return resp, err
@@ -227,6 +233,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
// executeClaudeNonStream performs a claude non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
if errToken != nil {
return resp, errToken
@@ -235,23 +243,25 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
auth = updatedAuth
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, true)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -261,7 +271,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
var lastErr error
for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
if errReq != nil {
err = errReq
return resp, err
@@ -507,8 +517,8 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
}
if usageResult := responseNode.Get("usageMetadata"); usageResult.Exists() {
usageRaw = usageResult.Raw
} else if usageResult := root.Get("usageMetadata"); usageResult.Exists() {
usageRaw = usageResult.Raw
} else if usageMetadataResult := root.Get("usageMetadata"); usageMetadataResult.Exists() {
usageRaw = usageMetadataResult.Raw
}
if partsResult := responseNode.Get("candidates.0.content.parts"); partsResult.IsArray() {
@@ -587,6 +597,8 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
// ExecuteStream performs a streaming request to the Antigravity API.
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
ctx = context.WithValue(ctx, "alt", "")
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
@@ -597,25 +609,25 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
auth = updatedAuth
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
translated = ApplyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, translated)
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated, originalTranslated)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -625,12 +637,11 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
var lastErr error
for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
if errReq != nil {
err = errReq
return nil, err
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
@@ -771,6 +782,8 @@ func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Au
// CountTokens counts tokens for the given request using the Antigravity API.
func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
if errToken != nil {
return cliproxyexecutor.Response{}, errToken
@@ -786,7 +799,17 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
to := sdktranslator.FromString("antigravity")
respCtx := context.WithValue(ctx, "alt", opts.Alt)
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
// Prepare payload once (doesn't depend on baseURL)
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return cliproxyexecutor.Response{}, err
}
payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings")
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -803,14 +826,6 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
var lastErr error
for idx, baseURL := range baseURLs {
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, payload)
payload = normalizeAntigravityThinking(req.Model, payload, isClaude)
payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings")
base := strings.TrimSuffix(baseURL, "/")
if base == "" {
base = buildBaseURL(auth)
@@ -980,35 +995,37 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
modelConfig := registry.GetAntigravityModelConfig()
models := make([]*registry.ModelInfo, 0, len(result.Map()))
for originalName := range result.Map() {
aliasName := modelName2Alias(originalName)
if aliasName != "" {
cfg := modelConfig[aliasName]
modelName := aliasName
if cfg != nil && cfg.Name != "" {
modelName = cfg.Name
}
modelInfo := &registry.ModelInfo{
ID: aliasName,
Name: modelName,
Description: aliasName,
DisplayName: aliasName,
Version: aliasName,
Object: "model",
Created: now,
OwnedBy: antigravityAuthType,
Type: antigravityAuthType,
}
// Look up Thinking support from static config using alias name
if cfg != nil {
if cfg.Thinking != nil {
modelInfo.Thinking = cfg.Thinking
}
if cfg.MaxCompletionTokens > 0 {
modelInfo.MaxCompletionTokens = cfg.MaxCompletionTokens
}
}
models = append(models, modelInfo)
modelID := strings.TrimSpace(originalName)
if modelID == "" {
continue
}
switch modelID {
case "chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-3-pro-low", "gemini-2.5-pro":
continue
}
modelCfg := modelConfig[modelID]
modelName := modelID
modelInfo := &registry.ModelInfo{
ID: modelID,
Name: modelName,
Description: modelID,
DisplayName: modelID,
Version: modelID,
Object: "model",
Created: now,
OwnedBy: antigravityAuthType,
Type: antigravityAuthType,
}
// Look up Thinking support from static config using upstream model name.
if modelCfg != nil {
if modelCfg.Thinking != nil {
modelInfo.Thinking = modelCfg.Thinking
}
if modelCfg.MaxCompletionTokens > 0 {
modelInfo.MaxCompletionTokens = modelCfg.MaxCompletionTokens
}
}
models = append(models, modelInfo)
}
return models
}
@@ -1104,12 +1121,49 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
auth.Metadata["refresh_token"] = tokenResp.RefreshToken
}
auth.Metadata["expires_in"] = tokenResp.ExpiresIn
auth.Metadata["timestamp"] = time.Now().UnixMilli()
auth.Metadata["expired"] = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
now := time.Now()
auth.Metadata["timestamp"] = now.UnixMilli()
auth.Metadata["expired"] = now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
auth.Metadata["type"] = antigravityAuthType
if errProject := e.ensureAntigravityProjectID(ctx, auth, tokenResp.AccessToken); errProject != nil {
log.Warnf("antigravity executor: ensure project id failed: %v", errProject)
}
return auth, nil
}
func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) error {
if auth == nil {
return nil
}
if auth.Metadata["project_id"] != nil {
return nil
}
token := strings.TrimSpace(accessToken)
if token == "" {
token = metaStringValue(auth.Metadata, "access_token")
}
if token == "" {
return nil
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient)
if errFetch != nil {
return errFetch
}
if strings.TrimSpace(projectID) == "" {
return nil
}
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
auth.Metadata["project_id"] = strings.TrimSpace(projectID)
return nil
}
func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt, baseURL string) (*http.Request, error) {
if token == "" {
return nil, statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
@@ -1146,9 +1200,9 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
}
}
payload = geminiToAntigravity(modelName, payload, projectID)
payload, _ = sjson.SetBytes(payload, "model", alias2ModelName(modelName))
payload, _ = sjson.SetBytes(payload, "model", modelName)
if strings.Contains(modelName, "claude") {
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
strJSON := string(payload)
paths := make([]string, 0)
util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths)
@@ -1163,7 +1217,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
payload = []byte(strJSON)
}
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-preview") {
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
systemInstructionPartsResult := gjson.GetBytes(payload, "request.systemInstruction.parts")
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.role", "user")
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.0.text", systemInstruction)
@@ -1351,16 +1405,9 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
template, _ = sjson.Delete(template, "request.safetySettings")
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
// template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
if !strings.HasPrefix(modelName, "gemini-3-") {
if thinkingLevel := gjson.Get(template, "request.generationConfig.thinkingConfig.thinkingLevel"); thinkingLevel.Exists() {
template, _ = sjson.Delete(template, "request.generationConfig.thinkingConfig.thinkingLevel")
template, _ = sjson.Set(template, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
}
}
if strings.Contains(modelName, "claude") {
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
if funcDecl.Get("parametersJsonSchema").Exists() {
@@ -1372,7 +1419,9 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
})
return true
})
} else {
}
if !strings.Contains(modelName, "claude") {
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
}
@@ -1417,108 +1466,3 @@ func generateProjectID() string {
randomPart := strings.ToLower(uuid.NewString())[:5]
return adj + "-" + noun + "-" + randomPart
}
func modelName2Alias(modelName string) string {
switch modelName {
case "rev19-uic3-1p":
return "gemini-2.5-computer-use-preview-10-2025"
case "gemini-3-pro-image":
return "gemini-3-pro-image-preview"
case "gemini-3-pro-high":
return "gemini-3-pro-preview"
case "gemini-3-flash":
return "gemini-3-flash-preview"
case "claude-sonnet-4-5":
return "gemini-claude-sonnet-4-5"
case "claude-sonnet-4-5-thinking":
return "gemini-claude-sonnet-4-5-thinking"
case "claude-opus-4-5-thinking":
return "gemini-claude-opus-4-5-thinking"
case "chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-3-pro-low", "gemini-2.5-pro":
return ""
default:
return modelName
}
}
func alias2ModelName(modelName string) string {
switch modelName {
case "gemini-2.5-computer-use-preview-10-2025":
return "rev19-uic3-1p"
case "gemini-3-pro-image-preview":
return "gemini-3-pro-image"
case "gemini-3-pro-preview":
return "gemini-3-pro-high"
case "gemini-3-flash-preview":
return "gemini-3-flash"
case "gemini-claude-sonnet-4-5":
return "claude-sonnet-4-5"
case "gemini-claude-sonnet-4-5-thinking":
return "claude-sonnet-4-5-thinking"
case "gemini-claude-opus-4-5-thinking":
return "claude-opus-4-5-thinking"
default:
return modelName
}
}
// normalizeAntigravityThinking clamps or removes thinking config based on model support.
// For Claude models, it additionally ensures thinking budget < max_tokens.
func normalizeAntigravityThinking(model string, payload []byte, isClaude bool) []byte {
payload = util.StripThinkingConfigIfUnsupported(model, payload)
if !util.ModelSupportsThinking(model) {
return payload
}
budget := gjson.GetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget")
if !budget.Exists() {
return payload
}
raw := int(budget.Int())
normalized := util.NormalizeThinkingBudget(model, raw)
if isClaude {
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
if effectiveMax > 0 && normalized >= effectiveMax {
normalized = effectiveMax - 1
}
minBudget := antigravityMinThinkingBudget(model)
if minBudget > 0 && normalized >= 0 && normalized < minBudget {
// Budget is below minimum, remove thinking config entirely
payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.thinkingConfig")
return payload
}
if setDefaultMax {
if res, errSet := sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", effectiveMax); errSet == nil {
payload = res
}
}
}
updated, err := sjson.SetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
if err != nil {
return payload
}
return updated
}
// antigravityEffectiveMaxTokens returns the max tokens to cap thinking:
// prefer request-provided maxOutputTokens; otherwise fall back to model default.
// The boolean indicates whether the value came from the model default (and thus should be written back).
func antigravityEffectiveMaxTokens(model string, payload []byte) (max int, fromModel bool) {
if maxTok := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens"); maxTok.Exists() && maxTok.Int() > 0 {
return int(maxTok.Int()), false
}
if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.MaxCompletionTokens > 0 {
return modelInfo.MaxCompletionTokens, true
}
return 0, false
}
// antigravityMinThinkingBudget returns the minimum thinking budget for a model.
// Falls back to -1 if no model info is found.
func antigravityMinThinkingBudget(model string) int {
if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.Thinking != nil {
return modelInfo.Thinking.Min
}
return -1
}

View File

@@ -17,7 +17,7 @@ import (
claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -84,17 +84,15 @@ func (e *ClaudeExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Aut
}
func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
apiKey, baseURL := claudeCreds(auth)
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := claudeCreds(auth)
if baseURL == "" {
baseURL = "https://api.anthropic.com"
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude.
@@ -103,23 +101,24 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, stream)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
body, _ = sjson.SetBytes(body, "model", model)
// Inject thinking config based on model metadata for thinking variants
body = e.injectThinkingConfig(model, req.Metadata, body)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
body, _ = sjson.SetBytes(body, "model", baseModel)
if !strings.HasPrefix(model, "claude-3-5-haiku") {
body = checkSystemInstructions(body)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(model, body)
// Extract betas from body and convert to header
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
@@ -218,37 +217,39 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
apiKey, baseURL := claudeCreds(auth)
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := claudeCreds(auth)
if baseURL == "" {
baseURL = "https://api.anthropic.com"
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body, _ = sjson.SetBytes(body, "model", model)
// Inject thinking config based on model metadata for thinking variants
body = e.injectThinkingConfig(model, req.Metadata, body)
body = checkSystemInstructions(body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(model, body)
// Extract betas from body and convert to header
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
@@ -381,8 +382,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
apiKey, baseURL := claudeCreds(auth)
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := claudeCreds(auth)
if baseURL == "" {
baseURL = "https://api.anthropic.com"
}
@@ -391,14 +393,10 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude.
stream := from != to
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
body, _ = sjson.SetBytes(body, "model", model)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
body, _ = sjson.SetBytes(body, "model", baseModel)
if !strings.HasPrefix(model, "claude-3-5-haiku") {
if !strings.HasPrefix(baseModel, "claude-3-5-haiku") {
body = checkSystemInstructions(body)
}
@@ -527,17 +525,6 @@ func extractAndRemoveBetas(body []byte) ([]string, []byte) {
return betas, body
}
// injectThinkingConfig adds thinking configuration based on metadata using the unified flow.
// It uses util.ResolveClaudeThinkingConfig which internally calls ResolveThinkingConfigFromMetadata
// and NormalizeThinkingBudget, ensuring consistency with other executors like Gemini.
func (e *ClaudeExecutor) injectThinkingConfig(modelName string, metadata map[string]any, body []byte) []byte {
budget, ok := util.ResolveClaudeThinkingConfig(modelName, metadata)
if !ok {
return body
}
return util.ApplyClaudeThinkingConfig(body, budget)
}
// disableThinkingIfToolChoiceForced checks if tool_choice forces tool use and disables thinking.
// Anthropic API does not allow thinking when tool_choice is set to "any" or a specific tool.
// See: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations
@@ -551,126 +538,6 @@ func disableThinkingIfToolChoiceForced(body []byte) []byte {
return body
}
// ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled.
// Anthropic API requires this constraint; violating it returns a 400 error.
// This function should be called after all thinking configuration is finalized.
// It looks up the model's MaxCompletionTokens from the registry to use as the cap.
func ensureMaxTokensForThinking(modelName string, body []byte) []byte {
thinkingType := gjson.GetBytes(body, "thinking.type").String()
if thinkingType != "enabled" {
return body
}
budgetTokens := gjson.GetBytes(body, "thinking.budget_tokens").Int()
if budgetTokens <= 0 {
return body
}
maxTokens := gjson.GetBytes(body, "max_tokens").Int()
// Look up the model's max completion tokens from the registry
maxCompletionTokens := 0
if modelInfo := registry.GetGlobalRegistry().GetModelInfo(modelName); modelInfo != nil {
maxCompletionTokens = modelInfo.MaxCompletionTokens
}
// Fall back to budget + buffer if registry lookup fails or returns 0
const fallbackBuffer = 4000
requiredMaxTokens := budgetTokens + fallbackBuffer
if maxCompletionTokens > 0 {
requiredMaxTokens = int64(maxCompletionTokens)
}
if maxTokens < requiredMaxTokens {
body, _ = sjson.SetBytes(body, "max_tokens", requiredMaxTokens)
}
return body
}
func (e *ClaudeExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveClaudeConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
func (e *ClaudeExecutor) resolveClaudeConfig(auth *cliproxyauth.Auth) *config.ClaudeKey {
if auth == nil || e.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range e.cfg.ClaudeKey {
entry := &e.cfg.ClaudeKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range e.cfg.ClaudeKey {
entry := &e.cfg.ClaudeKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}
type compositeReadCloser struct {
io.Reader
closers []func() error
@@ -956,3 +823,163 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
}
return updated
}
// getClientUserAgent extracts the client User-Agent from the gin context.
func getClientUserAgent(ctx context.Context) string {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
return ginCtx.GetHeader("User-Agent")
}
return ""
}
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
// Returns (cloakMode, strictMode, sensitiveWords).
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
if auth == nil || auth.Attributes == nil {
return "auto", false, nil
}
cloakMode := auth.Attributes["cloak_mode"]
if cloakMode == "" {
cloakMode = "auto"
}
strictMode := strings.ToLower(auth.Attributes["cloak_strict_mode"]) == "true"
var sensitiveWords []string
if wordsStr := auth.Attributes["cloak_sensitive_words"]; wordsStr != "" {
sensitiveWords = strings.Split(wordsStr, ",")
for i := range sensitiveWords {
sensitiveWords[i] = strings.TrimSpace(sensitiveWords[i])
}
}
return cloakMode, strictMode, sensitiveWords
}
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.CloakConfig {
if cfg == nil || auth == nil {
return nil
}
apiKey, baseURL := claudeCreds(auth)
if apiKey == "" {
return nil
}
for i := range cfg.ClaudeKey {
entry := &cfg.ClaudeKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
// Match by API key
if strings.EqualFold(cfgKey, apiKey) {
// If baseURL is specified, also check it
if baseURL != "" && cfgBase != "" && !strings.EqualFold(cfgBase, baseURL) {
continue
}
return entry.Cloak
}
}
return nil
}
// injectFakeUserID generates and injects a fake user ID into the request metadata.
func injectFakeUserID(payload []byte) []byte {
metadata := gjson.GetBytes(payload, "metadata")
if !metadata.Exists() {
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
return payload
}
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
if existingUserID == "" || !isValidUserID(existingUserID) {
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
}
return payload
}
// checkSystemInstructionsWithMode injects Claude Code system prompt.
// In strict mode, it replaces all user system messages.
// In non-strict mode (default), it prepends to existing system messages.
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
system := gjson.GetBytes(payload, "system")
claudeCodeInstructions := `[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."}]`
if strictMode {
// Strict mode: replace all system messages with Claude Code prompt only
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
return payload
}
// Non-strict mode (default): prepend Claude Code prompt to existing system messages
if system.IsArray() {
if gjson.GetBytes(payload, "system.0.text").String() != "You are Claude Code, Anthropic's official CLI for Claude." {
system.ForEach(func(_, part gjson.Result) bool {
if part.Get("type").String() == "text" {
claudeCodeInstructions, _ = sjson.SetRaw(claudeCodeInstructions, "-1", part.Raw)
}
return true
})
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
}
} else {
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
}
return payload
}
// applyCloaking applies cloaking transformations to the payload based on config and client.
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string) []byte {
clientUserAgent := getClientUserAgent(ctx)
// Get cloak config from ClaudeKey configuration
cloakCfg := resolveClaudeKeyCloakConfig(cfg, auth)
// Determine cloak settings
var cloakMode string
var strictMode bool
var sensitiveWords []string
if cloakCfg != nil {
cloakMode = cloakCfg.Mode
strictMode = cloakCfg.StrictMode
sensitiveWords = cloakCfg.SensitiveWords
}
// Fallback to auth attributes if no config found
if cloakMode == "" {
attrMode, attrStrict, attrWords := getCloakConfigFromAuth(auth)
cloakMode = attrMode
if !strictMode {
strictMode = attrStrict
}
if len(sensitiveWords) == 0 {
sensitiveWords = attrWords
}
}
// Determine if cloaking should be applied
if !shouldCloak(cloakMode, clientUserAgent) {
return payload
}
// Skip system instructions for claude-3-5-haiku models
if !strings.HasPrefix(model, "claude-3-5-haiku") {
payload = checkSystemInstructionsWithMode(payload, strictMode)
}
// Inject fake user ID
payload = injectFakeUserID(payload)
// Apply sensitive word obfuscation
if len(sensitiveWords) > 0 {
matcher := buildSensitiveWordMatcher(sensitiveWords)
payload = obfuscateSensitiveWords(payload, matcher)
}
return payload
}

View File

@@ -0,0 +1,176 @@
package executor
import (
"regexp"
"sort"
"strings"
"unicode/utf8"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// zeroWidthSpace is the Unicode zero-width space character used for obfuscation.
const zeroWidthSpace = "\u200B"
// SensitiveWordMatcher holds the compiled regex for matching sensitive words.
type SensitiveWordMatcher struct {
regex *regexp.Regexp
}
// buildSensitiveWordMatcher compiles a regex from the word list.
// Words are sorted by length (longest first) for proper matching.
func buildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
if len(words) == 0 {
return nil
}
// Filter and normalize words
var validWords []string
for _, w := range words {
w = strings.TrimSpace(w)
if utf8.RuneCountInString(w) >= 2 && !strings.Contains(w, zeroWidthSpace) {
validWords = append(validWords, w)
}
}
if len(validWords) == 0 {
return nil
}
// Sort by length (longest first) for proper matching
sort.Slice(validWords, func(i, j int) bool {
return len(validWords[i]) > len(validWords[j])
})
// Escape and join
escaped := make([]string, len(validWords))
for i, w := range validWords {
escaped[i] = regexp.QuoteMeta(w)
}
pattern := "(?i)" + strings.Join(escaped, "|")
re, err := regexp.Compile(pattern)
if err != nil {
return nil
}
return &SensitiveWordMatcher{regex: re}
}
// obfuscateWord inserts a zero-width space after the first grapheme.
func obfuscateWord(word string) string {
if strings.Contains(word, zeroWidthSpace) {
return word
}
// Get first rune
r, size := utf8.DecodeRuneInString(word)
if r == utf8.RuneError || size >= len(word) {
return word
}
return string(r) + zeroWidthSpace + word[size:]
}
// obfuscateText replaces all sensitive words in the text.
func (m *SensitiveWordMatcher) obfuscateText(text string) string {
if m == nil || m.regex == nil {
return text
}
return m.regex.ReplaceAllStringFunc(text, obfuscateWord)
}
// obfuscateSensitiveWords processes the payload and obfuscates sensitive words
// in system blocks and message content.
func obfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
if matcher == nil || matcher.regex == nil {
return payload
}
// Obfuscate in system blocks
payload = obfuscateSystemBlocks(payload, matcher)
// Obfuscate in messages
payload = obfuscateMessages(payload, matcher)
return payload
}
// obfuscateSystemBlocks obfuscates sensitive words in system blocks.
func obfuscateSystemBlocks(payload []byte, matcher *SensitiveWordMatcher) []byte {
system := gjson.GetBytes(payload, "system")
if !system.Exists() {
return payload
}
if system.IsArray() {
modified := false
system.ForEach(func(key, value gjson.Result) bool {
if value.Get("type").String() == "text" {
text := value.Get("text").String()
obfuscated := matcher.obfuscateText(text)
if obfuscated != text {
path := "system." + key.String() + ".text"
payload, _ = sjson.SetBytes(payload, path, obfuscated)
modified = true
}
}
return true
})
if modified {
return payload
}
} else if system.Type == gjson.String {
text := system.String()
obfuscated := matcher.obfuscateText(text)
if obfuscated != text {
payload, _ = sjson.SetBytes(payload, "system", obfuscated)
}
}
return payload
}
// obfuscateMessages obfuscates sensitive words in message content.
func obfuscateMessages(payload []byte, matcher *SensitiveWordMatcher) []byte {
messages := gjson.GetBytes(payload, "messages")
if !messages.Exists() || !messages.IsArray() {
return payload
}
messages.ForEach(func(msgKey, msg gjson.Result) bool {
content := msg.Get("content")
if !content.Exists() {
return true
}
msgPath := "messages." + msgKey.String()
if content.Type == gjson.String {
// Simple string content
text := content.String()
obfuscated := matcher.obfuscateText(text)
if obfuscated != text {
payload, _ = sjson.SetBytes(payload, msgPath+".content", obfuscated)
}
} else if content.IsArray() {
// Array of content blocks
content.ForEach(func(blockKey, block gjson.Result) bool {
if block.Get("type").String() == "text" {
text := block.Get("text").String()
obfuscated := matcher.obfuscateText(text)
if obfuscated != text {
path := msgPath + ".content." + blockKey.String() + ".text"
payload, _ = sjson.SetBytes(payload, path, obfuscated)
}
}
return true
})
}
return true
})
return payload
}

View File

@@ -0,0 +1,47 @@
package executor
import (
"crypto/rand"
"encoding/hex"
"regexp"
"strings"
"github.com/google/uuid"
)
// userIDPattern matches Claude Code format: user_[64-hex]_account__session_[uuid-v4]
var userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
// generateFakeUserID generates a fake user ID in Claude Code format.
// Format: user_[64-hex-chars]_account__session_[UUID-v4]
func generateFakeUserID() string {
hexBytes := make([]byte, 32)
_, _ = rand.Read(hexBytes)
hexPart := hex.EncodeToString(hexBytes)
uuidPart := uuid.New().String()
return "user_" + hexPart + "_account__session_" + uuidPart
}
// isValidUserID checks if a user ID matches Claude Code format.
func isValidUserID(userID string) bool {
return userIDPattern.MatchString(userID)
}
// shouldCloak determines if request should be cloaked based on config and client User-Agent.
// Returns true if cloaking should be applied.
func shouldCloak(cloakMode string, userAgent string) bool {
switch strings.ToLower(cloakMode) {
case "always":
return true
case "never":
return false
default: // "auto" or empty
// If client is Claude Code, don't cloak
return !strings.HasPrefix(userAgent, "claude-cli")
}
}
// isClaudeCodeClient checks if the User-Agent indicates a Claude Code client.
func isClaudeCodeClient(userAgent string) bool {
return strings.HasPrefix(userAgent, "claude-cli")
}

View File

@@ -13,6 +13,7 @@ import (
codexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -72,18 +73,15 @@ func (e *CodexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth
}
func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
apiKey, baseURL := codexCreds(auth)
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := codexCreds(auth)
if baseURL == "" {
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
@@ -93,20 +91,25 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalPayload = misc.InjectCodexUserAgent(originalPayload, userAgent)
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
body := misc.InjectCodexUserAgent(bytes.Clone(req.Payload), userAgent)
body = sdktranslator.TranslateRequest(from, to, model, body, false)
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
body = misc.StripCodexUserAgent(body)
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, model, false)
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
return resp, errValidate
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
}
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -182,18 +185,15 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
}
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
apiKey, baseURL := codexCreds(auth)
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := codexCreds(auth)
if baseURL == "" {
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
@@ -203,20 +203,24 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalPayload = misc.InjectCodexUserAgent(originalPayload, userAgent)
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
body := misc.InjectCodexUserAgent(bytes.Clone(req.Payload), userAgent)
body = sdktranslator.TranslateRequest(from, to, model, body, true)
body = sdktranslator.TranslateRequest(from, to, baseModel, body, true)
body = misc.StripCodexUserAgent(body)
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, model, false)
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
return nil, errValidate
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.SetBytes(body, "model", model)
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.SetBytes(body, "model", baseModel)
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
}
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -303,25 +307,30 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
userAgent := codexUserAgent(ctx)
body := misc.InjectCodexUserAgent(bytes.Clone(req.Payload), userAgent)
body = sdktranslator.TranslateRequest(from, to, model, body, false)
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
body = misc.StripCodexUserAgent(body)
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body, _ = sjson.SetBytes(body, "model", model)
body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return cliproxyexecutor.Response{}, err
}
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.SetBytes(body, "stream", false)
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
}
enc, err := tokenizerForCodexModel(model)
enc, err := tokenizerForCodexModel(baseModel)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
}
@@ -593,51 +602,6 @@ func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
return
}
func (e *CodexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveCodexConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
func (e *CodexExecutor) resolveCodexConfig(auth *cliproxyauth.Auth) *config.CodexKey {
if auth == nil || e.cfg == nil {
return nil

View File

@@ -20,6 +20,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -102,28 +103,33 @@ func (e *GeminiCLIExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.
// Execute performs a non-streaming request to the Gemini CLI API.
func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
if err != nil {
return resp, err
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload, originalTranslated)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated)
action := "generateContent"
if req.Metadata != nil {
@@ -133,9 +139,9 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
}
projectID := resolveGeminiProjectID(auth)
models := cliPreviewFallbackOrder(req.Model)
if len(models) == 0 || models[0] != req.Model {
models = append([]string{req.Model}, models...)
models := cliPreviewFallbackOrder(baseModel)
if len(models) == 0 || models[0] != baseModel {
models = append([]string{baseModel}, models...)
}
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
@@ -246,34 +252,39 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
// ExecuteStream performs a streaming request to the Gemini CLI API.
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
if err != nil {
return nil, err
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
basePayload = ApplyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, req.Metadata, basePayload)
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload, originalTranslated)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated)
projectID := resolveGeminiProjectID(auth)
models := cliPreviewFallbackOrder(req.Model)
if len(models) == 0 || models[0] != req.Model {
models = append([]string{req.Model}, models...)
models := cliPreviewFallbackOrder(baseModel)
if len(models) == 0 || models[0] != baseModel {
models = append([]string{baseModel}, models...)
}
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
@@ -435,6 +446,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
// CountTokens counts tokens for the given request using the Gemini CLI API.
func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
if err != nil {
return cliproxyexecutor.Response{}, err
@@ -443,9 +456,9 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
models := cliPreviewFallbackOrder(req.Model)
if len(models) == 0 || models[0] != req.Model {
models = append([]string{req.Model}, models...)
models := cliPreviewFallbackOrder(baseModel)
if len(models) == 0 || models[0] != baseModel {
models = append([]string{baseModel}, models...)
}
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
@@ -463,15 +476,18 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
// Gemini CLI endpoint when iterating fallback variants.
for _, attemptModel := range models {
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
payload = ApplyThinkingMetadataCLI(payload, req.Metadata, req.Model)
payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload)
for range models {
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return cliproxyexecutor.Response{}, err
}
payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings")
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiCLIImageAspectRatio(req.Model, payload)
payload = fixGeminiCLIImageAspectRatio(baseModel, payload)
tok, errTok := tokenSource.Token()
if errTok != nil {

View File

@@ -13,6 +13,7 @@ import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -102,16 +103,13 @@ func (e *GeminiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Aut
// - cliproxyexecutor.Response: The response from the API
// - error: An error if the request fails
func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, bearer := geminiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
// Official Gemini API via API key or OAuth bearer
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
@@ -119,15 +117,17 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
body = ApplyThinkingMetadata(body, req.Metadata, model)
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
body = fixGeminiImageAspectRatio(baseModel, body)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := "generateContent"
if req.Metadata != nil {
@@ -136,7 +136,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
}
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, baseModel, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
@@ -206,34 +206,33 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
// ExecuteStream performs a streaming request to the Gemini API.
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, bearer := geminiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body = ApplyThinkingMetadata(body, req.Metadata, model)
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
body = fixGeminiImageAspectRatio(baseModel, body)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", baseModel)
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent")
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, baseModel, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
@@ -331,27 +330,28 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
// CountTokens counts tokens for the given request using the Gemini API.
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
apiKey, bearer := geminiCreds(auth)
baseModel := thinking.ParseSuffix(req.Model).ModelName
model := req.Model
if override := e.resolveUpstreamModel(model, auth); override != "" {
model = override
}
apiKey, bearer := geminiCreds(auth)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model)
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return cliproxyexecutor.Response{}, err
}
translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens")
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, baseModel, "countTokens")
requestBody := bytes.NewReader(translatedReq)
@@ -450,51 +450,6 @@ func resolveGeminiBaseURL(auth *cliproxyauth.Auth) string {
return base
}
func (e *GeminiExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveGeminiConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
func (e *GeminiExecutor) resolveGeminiConfig(auth *cliproxyauth.Auth) *config.GeminiKey {
if auth == nil || e.cfg == nil {
return nil

View File

@@ -12,10 +12,11 @@ import (
"io"
"net/http"
"strings"
"time"
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -31,6 +32,143 @@ const (
vertexAPIVersion = "v1"
)
// isImagenModel checks if the model name is an Imagen image generation model.
// Imagen models use the :predict action instead of :generateContent.
func isImagenModel(model string) bool {
lowerModel := strings.ToLower(model)
return strings.Contains(lowerModel, "imagen")
}
// getVertexAction returns the appropriate action for the given model.
// Imagen models use "predict", while Gemini models use "generateContent".
func getVertexAction(model string, isStream bool) string {
if isImagenModel(model) {
return "predict"
}
if isStream {
return "streamGenerateContent"
}
return "generateContent"
}
// convertImagenToGeminiResponse converts Imagen API response to Gemini format
// so it can be processed by the standard translation pipeline.
// This ensures Imagen models return responses in the same format as gemini-3-pro-image-preview.
func convertImagenToGeminiResponse(data []byte, model string) []byte {
predictions := gjson.GetBytes(data, "predictions")
if !predictions.Exists() || !predictions.IsArray() {
return data
}
// Build Gemini-compatible response with inlineData
parts := make([]map[string]any, 0)
for _, pred := range predictions.Array() {
imageData := pred.Get("bytesBase64Encoded").String()
mimeType := pred.Get("mimeType").String()
if mimeType == "" {
mimeType = "image/png"
}
if imageData != "" {
parts = append(parts, map[string]any{
"inlineData": map[string]any{
"mimeType": mimeType,
"data": imageData,
},
})
}
}
// Generate unique response ID using timestamp
responseId := fmt.Sprintf("imagen-%d", time.Now().UnixNano())
response := map[string]any{
"candidates": []map[string]any{{
"content": map[string]any{
"parts": parts,
"role": "model",
},
"finishReason": "STOP",
}},
"responseId": responseId,
"modelVersion": model,
// Imagen API doesn't return token counts, set to 0 for tracking purposes
"usageMetadata": map[string]any{
"promptTokenCount": 0,
"candidatesTokenCount": 0,
"totalTokenCount": 0,
},
}
result, err := json.Marshal(response)
if err != nil {
return data
}
return result
}
// convertToImagenRequest converts a Gemini-style request to Imagen API format.
// Imagen API uses a different structure: instances[].prompt instead of contents[].
func convertToImagenRequest(payload []byte) ([]byte, error) {
// Extract prompt from Gemini-style contents
prompt := ""
// Try to get prompt from contents[0].parts[0].text
contentsText := gjson.GetBytes(payload, "contents.0.parts.0.text")
if contentsText.Exists() {
prompt = contentsText.String()
}
// If no contents, try messages format (OpenAI-compatible)
if prompt == "" {
messagesText := gjson.GetBytes(payload, "messages.#.content")
if messagesText.Exists() && messagesText.IsArray() {
for _, msg := range messagesText.Array() {
if msg.String() != "" {
prompt = msg.String()
break
}
}
}
}
// If still no prompt, try direct prompt field
if prompt == "" {
directPrompt := gjson.GetBytes(payload, "prompt")
if directPrompt.Exists() {
prompt = directPrompt.String()
}
}
if prompt == "" {
return nil, fmt.Errorf("imagen: no prompt found in request")
}
// Build Imagen API request
imagenReq := map[string]any{
"instances": []map[string]any{
{
"prompt": prompt,
},
},
"parameters": map[string]any{
"sampleCount": 1,
},
}
// Extract optional parameters
if aspectRatio := gjson.GetBytes(payload, "aspectRatio"); aspectRatio.Exists() {
imagenReq["parameters"].(map[string]any)["aspectRatio"] = aspectRatio.String()
}
if sampleCount := gjson.GetBytes(payload, "sampleCount"); sampleCount.Exists() {
imagenReq["parameters"].(map[string]any)["sampleCount"] = int(sampleCount.Int())
}
if negativePrompt := gjson.GetBytes(payload, "negativePrompt"); negativePrompt.Exists() {
imagenReq["instances"].([]map[string]any)[0]["negativePrompt"] = negativePrompt.String()
}
return json.Marshal(imagenReq)
}
// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
type GeminiVertexExecutor struct {
cfg *config.Config
@@ -155,39 +293,50 @@ func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Aut
// executeWithServiceAccount handles authentication using service account credentials.
// This method contains the original service account authentication logic.
func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", req.Model)
var body []byte
action := "generateContent"
// Handle Imagen models with special request format
if isImagenModel(baseModel) {
imagenBody, errImagen := convertToImagenRequest(req.Payload)
if errImagen != nil {
return resp, errImagen
}
body = imagenBody
} else {
// Standard Gemini translation flow
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
body = sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
body = fixGeminiImageAspectRatio(baseModel, body)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", baseModel)
}
action := getVertexAction(baseModel, false)
if req.Metadata != nil {
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
action = "countTokens"
}
}
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
@@ -250,6 +399,16 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data))
// For Imagen models, convert response to Gemini format before translation
// This ensures Imagen responses use the same format as gemini-3-pro-image-preview
if isImagenModel(baseModel) {
data = convertImagenToGeminiResponse(data, baseModel)
}
// Standard Gemini translation (works for both Gemini and converted Imagen responses)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
@@ -258,37 +417,31 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
// executeWithAPIKey handles authentication using API key credentials.
func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
baseModel := thinking.ParseSuffix(req.Model).ModelName
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
action := "generateContent"
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
body = fixGeminiImageAspectRatio(baseModel, body)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, false)
if req.Metadata != nil {
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
action = "countTokens"
@@ -299,7 +452,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action)
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
@@ -367,37 +520,40 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", req.Model)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
body = fixGeminiImageAspectRatio(baseModel, body)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, true)
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, action)
// Imagen models don't support streaming, skip SSE params
if !isImagenModel(baseModel) {
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
}
body, _ = sjson.DeleteBytes(body, "session_id")
@@ -487,45 +643,43 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
baseModel := thinking.ParseSuffix(req.Model).ModelName
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfigWithRoot(e.cfg, model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", model)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
body = fixGeminiImageAspectRatio(baseModel, body)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, true)
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
// Imagen models don't support streaming, skip SSE params
if !isImagenModel(baseModel) {
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
}
body, _ = sjson.DeleteBytes(body, "session_id")
@@ -612,26 +766,27 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
// countTokensWithServiceAccount counts tokens using service account credentials.
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return cliproxyexecutor.Response{}, err
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model)
translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
@@ -688,10 +843,6 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
@@ -699,24 +850,20 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
// countTokensWithAPIKey handles token counting using API key credentials.
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm
}
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return cliproxyexecutor.Response{}, err
}
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
@@ -726,7 +873,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens")
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
@@ -780,10 +927,6 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
@@ -870,53 +1013,6 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau
return tok.AccessToken, nil
}
// resolveUpstreamModel resolves the upstream model name from vertex-api-key configuration.
// It matches the requested model alias against configured models and returns the actual upstream name.
func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveVertexConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth.
func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey {
if auth == nil || e.cfg == nil {

View File

@@ -23,6 +23,7 @@ import (
const (
githubCopilotBaseURL = "https://api.githubcopilot.com"
githubCopilotChatPath = "/chat/completions"
githubCopilotResponsesPath = "/responses"
githubCopilotAuthType = "github-copilot"
githubCopilotTokenCacheTTL = 25 * time.Minute
// tokenExpiryBuffer is the time before expiry when we should refresh the token.
@@ -106,7 +107,11 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
useResponses := useGitHubCopilotResponsesEndpoint(from)
to := sdktranslator.FromString("openai")
if useResponses {
to = sdktranslator.FromString("openai-response")
}
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
@@ -117,7 +122,11 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "stream", false)
url := githubCopilotBaseURL + githubCopilotChatPath
path := githubCopilotChatPath
if useResponses {
path = githubCopilotResponsesPath
}
url := githubCopilotBaseURL + path
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return resp, err
@@ -172,6 +181,9 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
appendAPIResponseChunk(ctx, e.cfg, data)
detail := parseOpenAIUsage(data)
if useResponses && detail.TotalTokens == 0 {
detail = parseOpenAIResponsesUsage(data)
}
if detail.TotalTokens > 0 {
reporter.publish(ctx, detail)
}
@@ -194,7 +206,11 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
useResponses := useGitHubCopilotResponsesEndpoint(from)
to := sdktranslator.FromString("openai")
if useResponses {
to = sdktranslator.FromString("openai-response")
}
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
@@ -205,9 +221,15 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body, _ = sjson.SetBytes(body, "stream", true)
// Enable stream options for usage stats in stream
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
if !useResponses {
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
}
url := githubCopilotBaseURL + githubCopilotChatPath
path := githubCopilotChatPath
if useResponses {
path = githubCopilotResponsesPath
}
url := githubCopilotBaseURL + path
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return nil, err
@@ -283,6 +305,10 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
}
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
} else if useResponses {
if detail, ok := parseOpenAIResponsesStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
}
}
@@ -393,6 +419,10 @@ func (e *GitHubCopilotExecutor) normalizeModel(_ string, body []byte) []byte {
return body
}
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format) bool {
return sourceFormat.String() == "openai-response"
}
// isHTTPSuccess checks if the status code indicates success (2xx).
func isHTTPSuccess(statusCode int) bool {
return statusCode >= 200 && statusCode < 300

View File

@@ -12,6 +12,7 @@ import (
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -67,6 +68,8 @@ func (e *IFlowExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth
// Execute performs a non-streaming chat completion request.
func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := iflowCreds(auth)
if strings.TrimSpace(apiKey) == "" {
err = fmt.Errorf("iflow executor: missing api key")
@@ -76,7 +79,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
baseURL = iflowauth.DefaultAPIBaseURL
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
@@ -85,17 +88,17 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return resp, errValidate
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
if err != nil {
return resp, err
}
body = applyIFlowThinkingConfig(body)
body = preserveReasoningContentInMessages(body)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -154,6 +157,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
reporter.ensurePublished(ctx)
var param any
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
return resp, nil
@@ -161,6 +166,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
// ExecuteStream performs a streaming chat completion request.
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := iflowCreds(auth)
if strings.TrimSpace(apiKey) == "" {
err = fmt.Errorf("iflow executor: missing api key")
@@ -170,7 +177,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
baseURL = iflowauth.DefaultAPIBaseURL
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
@@ -179,23 +186,22 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
body, _ = sjson.SetBytes(body, "model", baseModel)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return nil, errValidate
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
if err != nil {
return nil, err
}
body = applyIFlowThinkingConfig(body)
body = preserveReasoningContentInMessages(body)
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
toolsResult := gjson.GetBytes(body, "tools")
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
body = ensureToolsArray(body)
}
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
@@ -278,11 +284,13 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
enc, err := tokenizerForModel(req.Model)
enc, err := tokenizerForModel(baseModel)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", err)
}
@@ -520,41 +528,3 @@ func preserveReasoningContentInMessages(body []byte) []byte {
return body
}
// applyIFlowThinkingConfig converts normalized reasoning_effort to model-specific thinking configurations.
// This should be called after NormalizeThinkingConfig has processed the payload.
//
// Model-specific handling:
// - GLM-4.6/4.7: Uses chat_template_kwargs.enable_thinking (boolean) and chat_template_kwargs.clear_thinking=false
// - MiniMax M2/M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
func applyIFlowThinkingConfig(body []byte) []byte {
effort := gjson.GetBytes(body, "reasoning_effort")
if !effort.Exists() {
return body
}
model := strings.ToLower(gjson.GetBytes(body, "model").String())
val := strings.ToLower(strings.TrimSpace(effort.String()))
enableThinking := val != "none" && val != ""
// Remove reasoning_effort as we'll convert to model-specific format
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
body, _ = sjson.DeleteBytes(body, "thinking")
// GLM-4.6/4.7: Use chat_template_kwargs
if strings.HasPrefix(model, "glm-4") {
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
if enableThinking {
body, _ = sjson.SetBytes(body, "chat_template_kwargs.clear_thinking", false)
}
return body
}
// MiniMax M2/M2.1: Use reasoning_split
if strings.HasPrefix(model, "minimax-m2") {
body, _ = sjson.SetBytes(body, "reasoning_split", enableThinking)
return body
}
return body
}

View File

@@ -0,0 +1,67 @@
package executor
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
)
func TestIFlowExecutorParseSuffix(t *testing.T) {
tests := []struct {
name string
model string
wantBase string
wantLevel string
}{
{"no suffix", "glm-4", "glm-4", ""},
{"glm with suffix", "glm-4.1-flash(high)", "glm-4.1-flash", "high"},
{"minimax no suffix", "minimax-m2", "minimax-m2", ""},
{"minimax with suffix", "minimax-m2.1(medium)", "minimax-m2.1", "medium"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := thinking.ParseSuffix(tt.model)
if result.ModelName != tt.wantBase {
t.Errorf("ParseSuffix(%q).ModelName = %q, want %q", tt.model, result.ModelName, tt.wantBase)
}
})
}
}
func TestPreserveReasoningContentInMessages(t *testing.T) {
tests := []struct {
name string
input []byte
want []byte // nil means output should equal input
}{
{
"non-glm model passthrough",
[]byte(`{"model":"gpt-4","messages":[]}`),
nil,
},
{
"glm model with empty messages",
[]byte(`{"model":"glm-4","messages":[]}`),
nil,
},
{
"glm model preserves existing reasoning_content",
[]byte(`{"model":"glm-4","messages":[{"role":"assistant","content":"hi","reasoning_content":"thinking..."}]}`),
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := preserveReasoningContentInMessages(tt.input)
want := tt.want
if want == nil {
want = tt.input
}
if string(got) != string(want) {
t.Errorf("preserveReasoningContentInMessages() = %s, want %s", got, want)
}
})
}
}

View File

@@ -7,13 +7,16 @@ import (
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
"github.com/google/uuid"
@@ -53,8 +56,28 @@ const (
kiroIDEUserAgent = "aws-sdk-js/1.0.18 ua/2.1 os/darwin#25.0.0 lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"
kiroIDEAmzUserAgent = "aws-sdk-js/1.0.18 KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"
kiroIDEAgentModeSpec = "spec"
// Socket retry configuration constants (based on kiro2Api reference implementation)
// Maximum number of retry attempts for socket/network errors
kiroSocketMaxRetries = 3
// Base delay between retry attempts (uses exponential backoff: delay * 2^attempt)
kiroSocketBaseRetryDelay = 1 * time.Second
// Maximum delay between retry attempts (cap for exponential backoff)
kiroSocketMaxRetryDelay = 30 * time.Second
// First token timeout for streaming responses (how long to wait for first response)
kiroFirstTokenTimeout = 15 * time.Second
// Streaming read timeout (how long to wait between chunks)
kiroStreamingReadTimeout = 300 * time.Second
)
// retryableHTTPStatusCodes defines HTTP status codes that are considered retryable.
// Based on kiro2Api reference: 502 (Bad Gateway), 503 (Service Unavailable), 504 (Gateway Timeout)
var retryableHTTPStatusCodes = map[int]bool{
502: true, // Bad Gateway - upstream server error
503: true, // Service Unavailable - server temporarily overloaded
504: true, // Gateway Timeout - upstream server timeout
}
// Real-time usage estimation configuration
// These control how often usage updates are sent during streaming
var (
@@ -62,6 +85,241 @@ var (
usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first
)
// Global FingerprintManager for dynamic User-Agent generation per token
// Each token gets a unique fingerprint on first use, which is cached for subsequent requests
var (
globalFingerprintManager *kiroauth.FingerprintManager
globalFingerprintManagerOnce sync.Once
)
// getGlobalFingerprintManager returns the global FingerprintManager instance
func getGlobalFingerprintManager() *kiroauth.FingerprintManager {
globalFingerprintManagerOnce.Do(func() {
globalFingerprintManager = kiroauth.NewFingerprintManager()
log.Infof("kiro: initialized global FingerprintManager for dynamic UA generation")
})
return globalFingerprintManager
}
// retryConfig holds configuration for socket retry logic.
// Based on kiro2Api Python implementation patterns.
type retryConfig struct {
MaxRetries int // Maximum number of retry attempts
BaseDelay time.Duration // Base delay between retries (exponential backoff)
MaxDelay time.Duration // Maximum delay cap
RetryableErrors []string // List of retryable error patterns
RetryableStatus map[int]bool // HTTP status codes to retry
FirstTokenTmout time.Duration // Timeout for first token in streaming
StreamReadTmout time.Duration // Timeout between stream chunks
}
// defaultRetryConfig returns the default retry configuration for Kiro socket operations.
func defaultRetryConfig() retryConfig {
return retryConfig{
MaxRetries: kiroSocketMaxRetries,
BaseDelay: kiroSocketBaseRetryDelay,
MaxDelay: kiroSocketMaxRetryDelay,
RetryableStatus: retryableHTTPStatusCodes,
RetryableErrors: []string{
"connection reset",
"connection refused",
"broken pipe",
"EOF",
"timeout",
"temporary failure",
"no such host",
"network is unreachable",
"i/o timeout",
},
FirstTokenTmout: kiroFirstTokenTimeout,
StreamReadTmout: kiroStreamingReadTimeout,
}
}
// isRetryableError checks if an error is retryable based on error type and message.
// Returns true for network timeouts, connection resets, and temporary failures.
// Based on kiro2Api's retry logic patterns.
func isRetryableError(err error) bool {
if err == nil {
return false
}
// Check for context cancellation - not retryable
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}
// Check for net.Error (timeout, temporary)
var netErr net.Error
if errors.As(err, &netErr) {
if netErr.Timeout() {
log.Debugf("kiro: isRetryableError: network timeout detected")
return true
}
// Note: Temporary() is deprecated but still useful for some error types
}
// Check for specific syscall errors (connection reset, broken pipe, etc.)
var syscallErr syscall.Errno
if errors.As(err, &syscallErr) {
switch syscallErr {
case syscall.ECONNRESET: // Connection reset by peer
log.Debugf("kiro: isRetryableError: ECONNRESET detected")
return true
case syscall.ECONNREFUSED: // Connection refused
log.Debugf("kiro: isRetryableError: ECONNREFUSED detected")
return true
case syscall.EPIPE: // Broken pipe
log.Debugf("kiro: isRetryableError: EPIPE (broken pipe) detected")
return true
case syscall.ETIMEDOUT: // Connection timed out
log.Debugf("kiro: isRetryableError: ETIMEDOUT detected")
return true
case syscall.ENETUNREACH: // Network is unreachable
log.Debugf("kiro: isRetryableError: ENETUNREACH detected")
return true
case syscall.EHOSTUNREACH: // No route to host
log.Debugf("kiro: isRetryableError: EHOSTUNREACH detected")
return true
}
}
// Check for net.OpError wrapping other errors
var opErr *net.OpError
if errors.As(err, &opErr) {
log.Debugf("kiro: isRetryableError: net.OpError detected, op=%s", opErr.Op)
// Recursively check the wrapped error
if opErr.Err != nil {
return isRetryableError(opErr.Err)
}
return true
}
// Check error message for retryable patterns
errMsg := strings.ToLower(err.Error())
cfg := defaultRetryConfig()
for _, pattern := range cfg.RetryableErrors {
if strings.Contains(errMsg, pattern) {
log.Debugf("kiro: isRetryableError: pattern '%s' matched in error: %s", pattern, errMsg)
return true
}
}
// Check for EOF which may indicate connection was closed
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
log.Debugf("kiro: isRetryableError: EOF/UnexpectedEOF detected")
return true
}
return false
}
// isRetryableHTTPStatus checks if an HTTP status code is retryable.
// Based on kiro2Api: 502, 503, 504 are retryable server errors.
func isRetryableHTTPStatus(statusCode int) bool {
return retryableHTTPStatusCodes[statusCode]
}
// calculateRetryDelay calculates the delay for the next retry attempt using exponential backoff.
// delay = min(baseDelay * 2^attempt, maxDelay)
// Adds ±30% jitter to prevent thundering herd.
func calculateRetryDelay(attempt int, cfg retryConfig) time.Duration {
return kiroauth.ExponentialBackoffWithJitter(attempt, cfg.BaseDelay, cfg.MaxDelay)
}
// logRetryAttempt logs a retry attempt with relevant context.
func logRetryAttempt(attempt, maxRetries int, reason string, delay time.Duration, endpoint string) {
log.Warnf("kiro: retry attempt %d/%d for %s, waiting %v before next attempt (endpoint: %s)",
attempt+1, maxRetries, reason, delay, endpoint)
}
// kiroHTTPClientPool provides a shared HTTP client with connection pooling for Kiro API.
// This reduces connection overhead and improves performance for concurrent requests.
// Based on kiro2Api's connection pooling pattern.
var (
kiroHTTPClientPool *http.Client
kiroHTTPClientPoolOnce sync.Once
)
// getKiroPooledHTTPClient returns a shared HTTP client with optimized connection pooling.
// The client is lazily initialized on first use and reused across requests.
// This is especially beneficial for:
// - Reducing TCP handshake overhead
// - Enabling HTTP/2 multiplexing
// - Better handling of keep-alive connections
func getKiroPooledHTTPClient() *http.Client {
kiroHTTPClientPoolOnce.Do(func() {
transport := &http.Transport{
// Connection pool settings
MaxIdleConns: 100, // Max idle connections across all hosts
MaxIdleConnsPerHost: 20, // Max idle connections per host
MaxConnsPerHost: 50, // Max total connections per host
IdleConnTimeout: 90 * time.Second, // How long idle connections stay in pool
// Timeouts for connection establishment
DialContext: (&net.Dialer{
Timeout: 30 * time.Second, // TCP connection timeout
KeepAlive: 30 * time.Second, // TCP keep-alive interval
}).DialContext,
// TLS handshake timeout
TLSHandshakeTimeout: 10 * time.Second,
// Response header timeout
ResponseHeaderTimeout: 30 * time.Second,
// Expect 100-continue timeout
ExpectContinueTimeout: 1 * time.Second,
// Enable HTTP/2 when available
ForceAttemptHTTP2: true,
}
kiroHTTPClientPool = &http.Client{
Transport: transport,
// No global timeout - let individual requests set their own timeouts via context
}
log.Debugf("kiro: initialized pooled HTTP client (MaxIdleConns=%d, MaxIdleConnsPerHost=%d, MaxConnsPerHost=%d)",
transport.MaxIdleConns, transport.MaxIdleConnsPerHost, transport.MaxConnsPerHost)
})
return kiroHTTPClientPool
}
// newKiroHTTPClientWithPooling creates an HTTP client that uses connection pooling when appropriate.
// It respects proxy configuration from auth or config, falling back to the pooled client.
// This provides the best of both worlds: custom proxy support + connection reuse.
func newKiroHTTPClientWithPooling(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
// Check if a proxy is configured - if so, we need a custom client
var proxyURL string
if auth != nil {
proxyURL = strings.TrimSpace(auth.ProxyURL)
}
if proxyURL == "" && cfg != nil {
proxyURL = strings.TrimSpace(cfg.ProxyURL)
}
// If proxy is configured, use the existing proxy-aware client (doesn't pool)
if proxyURL != "" {
log.Debugf("kiro: using proxy-aware HTTP client (proxy=%s)", proxyURL)
return newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
}
// No proxy - use pooled client for better performance
pooledClient := getKiroPooledHTTPClient()
// If timeout is specified, we need to wrap the pooled transport with timeout
if timeout > 0 {
return &http.Client{
Transport: pooledClient.Transport,
Timeout: timeout,
}
}
return pooledClient
}
// kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values.
// This solves the "triple mismatch" problem where different endpoints require matching
// Origin and X-Amz-Target header values.
@@ -216,6 +474,29 @@ func NewKiroExecutor(cfg *config.Config) *KiroExecutor {
// Identifier returns the unique identifier for this executor.
func (e *KiroExecutor) Identifier() string { return "kiro" }
// applyDynamicFingerprint applies token-specific fingerprint headers to the request
// For IDC auth, uses dynamic fingerprint-based User-Agent
// For other auth types, uses static Amazon Q CLI style headers
func applyDynamicFingerprint(req *http.Request, auth *cliproxyauth.Auth) {
if isIDCAuth(auth) {
// Get token-specific fingerprint for dynamic UA generation
tokenKey := getTokenKey(auth)
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
// Use fingerprint-generated dynamic User-Agent
req.Header.Set("User-Agent", fp.BuildUserAgent())
req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent())
req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
log.Debugf("kiro: using dynamic fingerprint for token %s (SDK:%s, OS:%s/%s, Kiro:%s)",
tokenKey[:8]+"...", fp.SDKVersion, fp.OSType, fp.OSVersion, fp.KiroVersion)
} else {
// Use static Amazon Q CLI style headers for non-IDC auth
req.Header.Set("User-Agent", kiroUserAgent)
req.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
}
}
// PrepareRequest prepares the HTTP request before execution.
func (e *KiroExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
@@ -225,14 +506,10 @@ func (e *KiroExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth
if strings.TrimSpace(accessToken) == "" {
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
}
if isIDCAuth(auth) {
req.Header.Set("User-Agent", kiroIDEUserAgent)
req.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent)
req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
} else {
req.Header.Set("User-Agent", kiroUserAgent)
req.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
}
// Apply dynamic fingerprint-based headers
applyDynamicFingerprint(req, auth)
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
req.Header.Set("Authorization", "Bearer "+accessToken)
@@ -256,10 +533,23 @@ func (e *KiroExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil {
return nil, errPrepare
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
// getTokenKey returns a unique key for rate limiting based on auth credentials.
// Uses auth ID if available, otherwise falls back to a hash of the access token.
func getTokenKey(auth *cliproxyauth.Auth) string {
if auth != nil && auth.ID != "" {
return auth.ID
}
accessToken, _ := kiroCredentials(auth)
if len(accessToken) > 16 {
return accessToken[:16]
}
return accessToken
}
// Execute sends the request to Kiro API and returns the response.
// Supports automatic token refresh on 401/403 errors.
func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
@@ -268,23 +558,53 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, fmt.Errorf("kiro: access token not found in auth")
}
// Rate limiting: get token key for tracking
tokenKey := getTokenKey(auth)
rateLimiter := kiroauth.GetGlobalRateLimiter()
cooldownMgr := kiroauth.GetGlobalCooldownManager()
// Check if token is in cooldown period
if cooldownMgr.IsInCooldown(tokenKey) {
remaining := cooldownMgr.GetRemainingCooldown(tokenKey)
reason := cooldownMgr.GetCooldownReason(tokenKey)
log.Warnf("kiro: token %s is in cooldown (reason: %s), remaining: %v", tokenKey, reason, remaining)
return resp, fmt.Errorf("kiro: token is in cooldown for %v (reason: %s)", remaining, reason)
}
// Wait for rate limiter before proceeding
log.Debugf("kiro: waiting for rate limiter for token %s", tokenKey)
rateLimiter.WaitForToken(tokenKey)
log.Debugf("kiro: rate limiter cleared for token %s", tokenKey)
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
// Check if token is expired before making request
if e.isTokenExpired(accessToken) {
log.Infof("kiro: access token expired, attempting refresh before request")
refreshedAuth, refreshErr := e.Refresh(ctx, auth)
if refreshErr != nil {
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
} else if refreshedAuth != nil {
auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
}
log.Infof("kiro: access token expired, attempting recovery")
// 方案 B: 先尝试从文件重新加载 token后台刷新器可能已更新文件
reloadedAuth, reloadErr := e.reloadAuthFromFile(auth)
if reloadErr == nil && reloadedAuth != nil {
// 文件中有更新的 token使用它
auth = reloadedAuth
accessToken, profileArn = kiroCredentials(auth)
log.Infof("kiro: token refreshed successfully before request")
log.Infof("kiro: recovered token from file (background refresh), expires_at: %v", auth.Metadata["expires_at"])
} else {
// 文件中的 token 也过期了,执行主动刷新
log.Debugf("kiro: file reload failed (%v), attempting active refresh", reloadErr)
refreshedAuth, refreshErr := e.Refresh(ctx, auth)
if refreshErr != nil {
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
} else if refreshedAuth != nil {
auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
}
accessToken, profileArn = kiroCredentials(auth)
log.Infof("kiro: token refreshed successfully before request")
}
}
}
@@ -300,7 +620,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// Execute with retry on 401/403 and 429 (quota exhausted)
// Note: currentOrigin and kiroPayload are built inside executeWithRetry for each endpoint
resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly)
resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
return resp, err
}
@@ -309,9 +629,12 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// - Amazon Q endpoint (CLI origin) uses Amazon Q Developer quota
// - CodeWhisperer endpoint (AI_EDITOR origin) uses Kiro IDE quota
// Also supports multi-endpoint fallback similar to Antigravity implementation.
func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from, to sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (cliproxyexecutor.Response, error) {
// tokenKey is used for rate limiting and cooldown tracking.
func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from, to sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool, tokenKey string) (cliproxyexecutor.Response, error) {
var resp cliproxyexecutor.Response
maxRetries := 2 // Allow retries for token refresh + endpoint fallback
rateLimiter := kiroauth.GetGlobalRateLimiter()
cooldownMgr := kiroauth.GetGlobalCooldownManager()
endpointConfigs := getKiroEndpointConfigs(auth)
var last429Err error
@@ -329,6 +652,12 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin)
for attempt := 0; attempt <= maxRetries; attempt++ {
// Apply human-like delay before first request (not on retries)
// This mimics natural user behavior patterns
if attempt == 0 && endpointIdx == 0 {
kiroauth.ApplyHumanLikeDelay()
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload))
if err != nil {
return resp, err
@@ -339,18 +668,9 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
// Use different headers based on auth type
// IDC auth uses Kiro IDE style headers (from kiro2api)
// Other auth types use Amazon Q CLI style headers
if isIDCAuth(auth) {
httpReq.Header.Set("User-Agent", kiroIDEUserAgent)
httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent)
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
log.Debugf("kiro: using Kiro IDE headers for IDC auth")
} else {
httpReq.Header.Set("User-Agent", kiroUserAgent)
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
}
// Apply dynamic fingerprint-based headers
applyDynamicFingerprint(httpReq, auth)
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
@@ -381,10 +701,34 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 120*time.Second)
httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 120*time.Second)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
// Check for context cancellation first - client disconnected, not a server error
// Use 499 (Client Closed Request - nginx convention) instead of 500
if errors.Is(err, context.Canceled) {
log.Debugf("kiro: request canceled by client (context.Canceled)")
return resp, statusErr{code: 499, msg: "client canceled request"}
}
// Check for context deadline exceeded - request timed out
// Return 504 Gateway Timeout instead of 500
if errors.Is(err, context.DeadlineExceeded) {
log.Debugf("kiro: request timed out (context.DeadlineExceeded)")
return resp, statusErr{code: http.StatusGatewayTimeout, msg: "upstream request timed out"}
}
recordAPIResponseError(ctx, e.cfg, err)
// Enhanced socket retry: Check if error is retryable (network timeout, connection reset, etc.)
retryCfg := defaultRetryConfig()
if isRetryableError(err) && attempt < retryCfg.MaxRetries {
delay := calculateRetryDelay(attempt, retryCfg)
logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("socket error: %v", err), delay, endpointConfig.Name)
time.Sleep(delay)
continue
}
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
@@ -396,6 +740,12 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
_ = httpResp.Body.Close()
appendAPIResponseChunk(ctx, e.cfg, respBody)
// Record failure and set cooldown for 429
rateLimiter.MarkTokenFailed(tokenKey)
cooldownDuration := kiroauth.CalculateCooldownFor429(attempt)
cooldownMgr.SetCooldown(tokenKey, cooldownDuration, kiroauth.CooldownReason429)
log.Warnf("kiro: rate limit hit (429), token %s set to cooldown for %v", tokenKey, cooldownDuration)
// Preserve last 429 so callers can correctly backoff when all endpoints are exhausted
last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)}
@@ -407,13 +757,21 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
}
// Handle 5xx server errors with exponential backoff retry
// Enhanced: Use retryConfig for consistent retry behavior
if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 {
respBody, _ := io.ReadAll(httpResp.Body)
_ = httpResp.Body.Close()
appendAPIResponseChunk(ctx, e.cfg, respBody)
if attempt < maxRetries {
// Exponential backoff: 1s, 2s, 4s... (max 30s)
retryCfg := defaultRetryConfig()
// Check if this specific 5xx code is retryable (502, 503, 504)
if isRetryableHTTPStatus(httpResp.StatusCode) && attempt < retryCfg.MaxRetries {
delay := calculateRetryDelay(attempt, retryCfg)
logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("HTTP %d", httpResp.StatusCode), delay, endpointConfig.Name)
time.Sleep(delay)
continue
} else if attempt < maxRetries {
// Fallback for other 5xx errors (500, 501, etc.)
backoff := time.Duration(1<<attempt) * time.Second
if backoff > 30*time.Second {
backoff = 30 * time.Second
@@ -487,7 +845,10 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
// Check for SUSPENDED status - return immediately without retry
if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") {
log.Errorf("kiro: account is suspended, cannot proceed")
// Set long cooldown for suspended accounts
rateLimiter.CheckAndMarkSuspended(tokenKey, respBodyStr)
cooldownMgr.SetCooldown(tokenKey, kiroauth.LongCooldown, kiroauth.CooldownReasonSuspended)
log.Errorf("kiro: account is suspended, token %s set to cooldown for %v", tokenKey, kiroauth.LongCooldown)
return resp, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)}
}
@@ -576,6 +937,10 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
appendAPIResponseChunk(ctx, e.cfg, []byte(content))
reporter.publish(ctx, usageInfo)
// Record success for rate limiting
rateLimiter.MarkTokenSuccess(tokenKey)
log.Debugf("kiro: request successful, token %s marked as success", tokenKey)
// Build response in Claude format for Kiro translator
// stopReason is extracted from upstream response by parseEventStream
kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason)
@@ -603,23 +968,53 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, fmt.Errorf("kiro: access token not found in auth")
}
// Rate limiting: get token key for tracking
tokenKey := getTokenKey(auth)
rateLimiter := kiroauth.GetGlobalRateLimiter()
cooldownMgr := kiroauth.GetGlobalCooldownManager()
// Check if token is in cooldown period
if cooldownMgr.IsInCooldown(tokenKey) {
remaining := cooldownMgr.GetRemainingCooldown(tokenKey)
reason := cooldownMgr.GetCooldownReason(tokenKey)
log.Warnf("kiro: token %s is in cooldown (reason: %s), remaining: %v", tokenKey, reason, remaining)
return nil, fmt.Errorf("kiro: token is in cooldown for %v (reason: %s)", remaining, reason)
}
// Wait for rate limiter before proceeding
log.Debugf("kiro: stream waiting for rate limiter for token %s", tokenKey)
rateLimiter.WaitForToken(tokenKey)
log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey)
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
// Check if token is expired before making request
if e.isTokenExpired(accessToken) {
log.Infof("kiro: access token expired, attempting refresh before stream request")
refreshedAuth, refreshErr := e.Refresh(ctx, auth)
if refreshErr != nil {
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
} else if refreshedAuth != nil {
auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
}
log.Infof("kiro: access token expired, attempting recovery before stream request")
// 方案 B: 先尝试从文件重新加载 token后台刷新器可能已更新文件
reloadedAuth, reloadErr := e.reloadAuthFromFile(auth)
if reloadErr == nil && reloadedAuth != nil {
// 文件中有更新的 token使用它
auth = reloadedAuth
accessToken, profileArn = kiroCredentials(auth)
log.Infof("kiro: token refreshed successfully before stream request")
log.Infof("kiro: recovered token from file (background refresh) for stream, expires_at: %v", auth.Metadata["expires_at"])
} else {
// 文件中的 token 也过期了,执行主动刷新
log.Debugf("kiro: file reload failed (%v), attempting active refresh for stream", reloadErr)
refreshedAuth, refreshErr := e.Refresh(ctx, auth)
if refreshErr != nil {
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
} else if refreshedAuth != nil {
auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
}
accessToken, profileArn = kiroCredentials(auth)
log.Infof("kiro: token refreshed successfully before stream request")
}
}
}
@@ -635,7 +1030,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
// Execute stream with retry on 401/403 and 429 (quota exhausted)
// Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint
return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly)
return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
}
// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors.
@@ -643,8 +1038,11 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
// - Amazon Q endpoint (CLI origin) uses Amazon Q Developer quota
// - CodeWhisperer endpoint (AI_EDITOR origin) uses Kiro IDE quota
// Also supports multi-endpoint fallback similar to Antigravity implementation.
func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (<-chan cliproxyexecutor.StreamChunk, error) {
// tokenKey is used for rate limiting and cooldown tracking.
func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool, tokenKey string) (<-chan cliproxyexecutor.StreamChunk, error) {
maxRetries := 2 // Allow retries for token refresh + endpoint fallback
rateLimiter := kiroauth.GetGlobalRateLimiter()
cooldownMgr := kiroauth.GetGlobalCooldownManager()
endpointConfigs := getKiroEndpointConfigs(auth)
var last429Err error
@@ -662,6 +1060,13 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin)
for attempt := 0; attempt <= maxRetries; attempt++ {
// Apply human-like delay before first streaming request (not on retries)
// This mimics natural user behavior patterns
// Note: Delay is NOT applied during streaming response - only before initial request
if attempt == 0 && endpointIdx == 0 {
kiroauth.ApplyHumanLikeDelay()
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload))
if err != nil {
return nil, err
@@ -672,18 +1077,9 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
// Use different headers based on auth type
// IDC auth uses Kiro IDE style headers (from kiro2api)
// Other auth types use Amazon Q CLI style headers
if isIDCAuth(auth) {
httpReq.Header.Set("User-Agent", kiroIDEUserAgent)
httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent)
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
log.Debugf("kiro: using Kiro IDE headers for IDC auth")
} else {
httpReq.Header.Set("User-Agent", kiroUserAgent)
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
}
// Apply dynamic fingerprint-based headers
applyDynamicFingerprint(httpReq, auth)
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
@@ -714,10 +1110,20 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
// Enhanced socket retry for streaming: Check if error is retryable (network timeout, connection reset, etc.)
retryCfg := defaultRetryConfig()
if isRetryableError(err) && attempt < retryCfg.MaxRetries {
delay := calculateRetryDelay(attempt, retryCfg)
logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("stream socket error: %v", err), delay, endpointConfig.Name)
time.Sleep(delay)
continue
}
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
@@ -729,6 +1135,12 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
_ = httpResp.Body.Close()
appendAPIResponseChunk(ctx, e.cfg, respBody)
// Record failure and set cooldown for 429
rateLimiter.MarkTokenFailed(tokenKey)
cooldownDuration := kiroauth.CalculateCooldownFor429(attempt)
cooldownMgr.SetCooldown(tokenKey, cooldownDuration, kiroauth.CooldownReason429)
log.Warnf("kiro: stream rate limit hit (429), token %s set to cooldown for %v", tokenKey, cooldownDuration)
// Preserve last 429 so callers can correctly backoff when all endpoints are exhausted
last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)}
@@ -740,13 +1152,21 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
}
// Handle 5xx server errors with exponential backoff retry
// Enhanced: Use retryConfig for consistent retry behavior
if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 {
respBody, _ := io.ReadAll(httpResp.Body)
_ = httpResp.Body.Close()
appendAPIResponseChunk(ctx, e.cfg, respBody)
if attempt < maxRetries {
// Exponential backoff: 1s, 2s, 4s... (max 30s)
retryCfg := defaultRetryConfig()
// Check if this specific 5xx code is retryable (502, 503, 504)
if isRetryableHTTPStatus(httpResp.StatusCode) && attempt < retryCfg.MaxRetries {
delay := calculateRetryDelay(attempt, retryCfg)
logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("stream HTTP %d", httpResp.StatusCode), delay, endpointConfig.Name)
time.Sleep(delay)
continue
} else if attempt < maxRetries {
// Fallback for other 5xx errors (500, 501, etc.)
backoff := time.Duration(1<<attempt) * time.Second
if backoff > 30*time.Second {
backoff = 30 * time.Second
@@ -833,7 +1253,10 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
// Check for SUSPENDED status - return immediately without retry
if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") {
log.Errorf("kiro: account is suspended, cannot proceed")
// Set long cooldown for suspended accounts
rateLimiter.CheckAndMarkSuspended(tokenKey, respBodyStr)
cooldownMgr.SetCooldown(tokenKey, kiroauth.LongCooldown, kiroauth.CooldownReasonSuspended)
log.Errorf("kiro: stream account is suspended, token %s set to cooldown for %v", tokenKey, kiroauth.LongCooldown)
return nil, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)}
}
@@ -883,6 +1306,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
out := make(chan cliproxyexecutor.StreamChunk)
// Record success immediately since connection was established successfully
// Streaming errors will be handled separately
rateLimiter.MarkTokenSuccess(tokenKey)
log.Debugf("kiro: stream request successful, token %s marked as success", tokenKey)
go func(resp *http.Response, thinkingEnabled bool) {
defer close(out)
defer func() {
@@ -3109,14 +3537,14 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
// Also check if expires_at is now in the future with sufficient buffer
if expiresAt, ok := auth.Metadata["expires_at"].(string); ok {
if expTime, err := time.Parse(time.RFC3339, expiresAt); err == nil {
// If token expires more than 5 minutes from now, it's still valid
if time.Until(expTime) > 5*time.Minute {
// If token expires more than 20 minutes from now, it's still valid
if time.Until(expTime) > 20*time.Minute {
log.Debugf("kiro executor: token is still valid (expires in %v), skipping refresh", time.Until(expTime))
// CRITICAL FIX: Set NextRefreshAfter to prevent frequent refresh checks
// Without this, shouldRefresh() will return true again in 5 seconds
// Without this, shouldRefresh() will return true again in 30 seconds
updated := auth.Clone()
// Set next refresh to 5 minutes before expiry, or at least 30 seconds from now
nextRefresh := expTime.Add(-5 * time.Minute)
// Set next refresh to 20 minutes before expiry, or at least 30 seconds from now
nextRefresh := expTime.Add(-20 * time.Minute)
minNextRefresh := time.Now().Add(30 * time.Second)
if nextRefresh.Before(minNextRefresh) {
nextRefresh = minNextRefresh
@@ -3213,6 +3641,13 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
if tokenData.ClientSecret != "" {
updated.Metadata["client_secret"] = tokenData.ClientSecret
}
// Preserve region and start_url for IDC token refresh
if tokenData.Region != "" {
updated.Metadata["region"] = tokenData.Region
}
if tokenData.StartURL != "" {
updated.Metadata["start_url"] = tokenData.StartURL
}
if updated.Attributes == nil {
updated.Attributes = make(map[string]string)
@@ -3222,9 +3657,9 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
updated.Attributes["profile_arn"] = tokenData.ProfileArn
}
// NextRefreshAfter is aligned with RefreshLead (5min)
// NextRefreshAfter is aligned with RefreshLead (20min)
if expiresAt, parseErr := time.Parse(time.RFC3339, tokenData.ExpiresAt); parseErr == nil {
updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute)
updated.NextRefreshAfter = expiresAt.Add(-20 * time.Minute)
}
log.Infof("kiro executor: token refreshed successfully, expires at %s", tokenData.ExpiresAt)
@@ -3278,6 +3713,121 @@ func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error {
return nil
}
// reloadAuthFromFile 从文件重新加载 auth 数据(方案 B: Fallback 机制)
// 当内存中的 token 已过期时,尝试从文件读取最新的 token
// 这解决了后台刷新器已更新文件但内存中 Auth 对象尚未同步的时间差问题
func (e *KiroExecutor) reloadAuthFromFile(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
if auth == nil {
return nil, fmt.Errorf("kiro executor: cannot reload nil auth")
}
// 确定文件路径
var authPath string
if auth.Attributes != nil {
if p := strings.TrimSpace(auth.Attributes["path"]); p != "" {
authPath = p
}
}
if authPath == "" {
fileName := strings.TrimSpace(auth.FileName)
if fileName == "" {
return nil, fmt.Errorf("kiro executor: auth has no file path or filename for reload")
}
if filepath.IsAbs(fileName) {
authPath = fileName
} else if e.cfg != nil && e.cfg.AuthDir != "" {
authPath = filepath.Join(e.cfg.AuthDir, fileName)
} else {
return nil, fmt.Errorf("kiro executor: cannot determine auth file path for reload")
}
}
// 读取文件
raw, err := os.ReadFile(authPath)
if err != nil {
return nil, fmt.Errorf("kiro executor: failed to read auth file %s: %w", authPath, err)
}
// 解析 JSON
var metadata map[string]any
if err := json.Unmarshal(raw, &metadata); err != nil {
return nil, fmt.Errorf("kiro executor: failed to parse auth file %s: %w", authPath, err)
}
// 检查文件中的 token 是否比内存中的更新
fileExpiresAt, _ := metadata["expires_at"].(string)
fileAccessToken, _ := metadata["access_token"].(string)
memExpiresAt, _ := auth.Metadata["expires_at"].(string)
memAccessToken, _ := auth.Metadata["access_token"].(string)
// 文件中必须有有效的 access_token
if fileAccessToken == "" {
return nil, fmt.Errorf("kiro executor: auth file has no access_token field")
}
// 如果有 expires_at检查是否过期
if fileExpiresAt != "" {
fileExpTime, parseErr := time.Parse(time.RFC3339, fileExpiresAt)
if parseErr == nil {
// 如果文件中的 token 也已过期,不使用它
if time.Now().After(fileExpTime) {
log.Debugf("kiro executor: file token also expired at %s, not using", fileExpiresAt)
return nil, fmt.Errorf("kiro executor: file token also expired")
}
}
}
// 判断文件中的 token 是否比内存中的更新
// 条件1: access_token 不同(说明已刷新)
// 条件2: expires_at 更新(说明已刷新)
isNewer := false
// 优先检查 access_token 是否变化
if fileAccessToken != memAccessToken {
isNewer = true
log.Debugf("kiro executor: file access_token differs from memory, using file token")
}
// 如果 access_token 相同,检查 expires_at
if !isNewer && fileExpiresAt != "" && memExpiresAt != "" {
fileExpTime, fileParseErr := time.Parse(time.RFC3339, fileExpiresAt)
memExpTime, memParseErr := time.Parse(time.RFC3339, memExpiresAt)
if fileParseErr == nil && memParseErr == nil && fileExpTime.After(memExpTime) {
isNewer = true
log.Debugf("kiro executor: file expires_at (%s) is newer than memory (%s)", fileExpiresAt, memExpiresAt)
}
}
// 如果文件中没有 expires_at 但 access_token 相同,无法判断是否更新
if !isNewer && fileExpiresAt == "" && fileAccessToken == memAccessToken {
return nil, fmt.Errorf("kiro executor: cannot determine if file token is newer (no expires_at, same access_token)")
}
if !isNewer {
log.Debugf("kiro executor: file token not newer than memory token")
return nil, fmt.Errorf("kiro executor: file token not newer")
}
// 创建更新后的 auth 对象
updated := auth.Clone()
updated.Metadata = metadata
updated.UpdatedAt = time.Now()
// 同步更新 Attributes
if updated.Attributes == nil {
updated.Attributes = make(map[string]string)
}
if accessToken, ok := metadata["access_token"].(string); ok {
updated.Attributes["access_token"] = accessToken
}
if profileArn, ok := metadata["profile_arn"].(string); ok {
updated.Attributes["profile_arn"] = profileArn
}
log.Infof("kiro executor: reloaded auth from file %s, new expires_at: %s", authPath, fileExpiresAt)
return updated, nil
}
// isTokenExpired checks if a JWT access token has expired.
// Returns true if the token is expired or cannot be parsed.
func (e *KiroExecutor) isTokenExpired(accessToken string) bool {

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -69,7 +70,9 @@ func (e *OpenAICompatExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
}
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
baseURL, apiKey := e.resolveCredentials(auth)
@@ -85,18 +88,13 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream)
modelOverride := e.resolveUpstreamModel(req.Model, auth)
if modelOverride != "" {
translated = e.overrideModel(translated, modelOverride)
}
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated, originalTranslated)
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
return resp, errValidate
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
@@ -168,7 +166,9 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
}
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
baseURL, apiKey := e.resolveCredentials(auth)
@@ -176,24 +176,20 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
err = statusErr{code: http.StatusUnauthorized, msg: "missing provider baseURL"}
return nil, err
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
modelOverride := e.resolveUpstreamModel(req.Model, auth)
if modelOverride != "" {
translated = e.overrideModel(translated, modelOverride)
}
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated, originalTranslated)
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
return nil, errValidate
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
@@ -293,14 +289,17 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
}
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
modelForCounting := req.Model
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
translated = e.overrideModel(translated, modelOverride)
modelForCounting = modelOverride
modelForCounting := baseModel
translated, err := thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return cliproxyexecutor.Response{}, err
}
enc, err := tokenizerForModel(modelForCounting)
@@ -336,53 +335,6 @@ func (e *OpenAICompatExecutor) resolveCredentials(auth *cliproxyauth.Auth) (base
return
}
func (e *OpenAICompatExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
if alias == "" || auth == nil || e.cfg == nil {
return ""
}
compat := e.resolveCompatConfig(auth)
if compat == nil {
return ""
}
for i := range compat.Models {
model := compat.Models[i]
if model.Alias != "" {
if strings.EqualFold(model.Alias, alias) {
if model.Name != "" {
return model.Name
}
return alias
}
continue
}
if strings.EqualFold(model.Name, alias) {
return model.Name
}
}
return ""
}
func (e *OpenAICompatExecutor) allowCompatReasoningEffort(model string, auth *cliproxyauth.Auth) bool {
trimmed := strings.TrimSpace(model)
if trimmed == "" || e == nil || e.cfg == nil {
return false
}
compat := e.resolveCompatConfig(auth)
if compat == nil || len(compat.Models) == 0 {
return false
}
for i := range compat.Models {
entry := compat.Models[i]
if strings.EqualFold(strings.TrimSpace(entry.Alias), trimmed) {
return true
}
if strings.EqualFold(strings.TrimSpace(entry.Name), trimmed) {
return true
}
}
return false
}
func (e *OpenAICompatExecutor) resolveCompatConfig(auth *cliproxyauth.Auth) *config.OpenAICompatibility {
if auth == nil || e.cfg == nil {
return nil

View File

@@ -1,109 +1,14 @@
package executor
import (
"fmt"
"net/http"
"encoding/json"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ApplyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
// for standard Gemini format payloads. It normalizes the budget when the model supports thinking.
func ApplyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
// Use the alias from metadata if available, as it's registered in the global registry
// with thinking metadata; the upstream model name may not be registered.
lookupModel := util.ResolveOriginalModel(model, metadata)
// Determine which model to use for thinking support check.
// If the alias (lookupModel) is not in the registry, fall back to the upstream model.
thinkingModel := lookupModel
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
thinkingModel = model
}
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
if !ok || (budgetOverride == nil && includeOverride == nil) {
return payload
}
if !util.ModelSupportsThinking(thinkingModel) {
return payload
}
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
budgetOverride = &norm
}
return util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
}
// ApplyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., (high), (8192))
// for Gemini CLI format payloads (nested under "request"). It normalizes the budget when the model supports thinking.
func ApplyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
// Use the alias from metadata if available, as it's registered in the global registry
// with thinking metadata; the upstream model name may not be registered.
lookupModel := util.ResolveOriginalModel(model, metadata)
// Determine which model to use for thinking support check.
// If the alias (lookupModel) is not in the registry, fall back to the upstream model.
thinkingModel := lookupModel
if !util.ModelSupportsThinking(lookupModel) && util.ModelSupportsThinking(model) {
thinkingModel = model
}
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(thinkingModel, metadata)
if !ok || (budgetOverride == nil && includeOverride == nil) {
return payload
}
if !util.ModelSupportsThinking(thinkingModel) {
return payload
}
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(thinkingModel, *budgetOverride)
budgetOverride = &norm
}
return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
}
// ApplyReasoningEffortMetadata applies reasoning effort overrides from metadata to the given JSON path.
// Metadata values take precedence over any existing field when the model supports thinking, intentionally
// overwriting caller-provided values to honor suffix/default metadata priority.
func ApplyReasoningEffortMetadata(payload []byte, metadata map[string]any, model, field string, allowCompat bool) []byte {
if len(metadata) == 0 {
return payload
}
if field == "" {
return payload
}
baseModel := util.ResolveOriginalModel(model, metadata)
if baseModel == "" {
baseModel = model
}
if !util.ModelSupportsThinking(baseModel) && !allowCompat {
return payload
}
if effort, ok := util.ReasoningEffortFromMetadata(metadata); ok && effort != "" {
if util.ModelUsesThinkingLevels(baseModel) || allowCompat {
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
return updated
}
}
}
// Fallback: numeric thinking_budget suffix for level-based (OpenAI-style) models.
if util.ModelUsesThinkingLevels(baseModel) || allowCompat {
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
if effort, ok := util.ThinkingBudgetToEffort(baseModel, *budget); ok && effort != "" {
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
return updated
}
}
}
}
return payload
}
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
// and restricts matches to the given protocol when supplied. Defaults are checked
@@ -113,13 +18,14 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
return payload
}
rules := cfg.Payload
if len(rules.Default) == 0 && len(rules.Override) == 0 {
if len(rules.Default) == 0 && len(rules.DefaultRaw) == 0 && len(rules.Override) == 0 && len(rules.OverrideRaw) == 0 {
return payload
}
model = strings.TrimSpace(model)
if model == "" {
return payload
}
candidates := payloadModelCandidates(cfg, model, protocol)
out := payload
source := original
if len(source) == 0 {
@@ -129,7 +35,7 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
// Apply default rules: first write wins per field across all matching rules.
for i := range rules.Default {
rule := &rules.Default[i]
if !payloadRuleMatchesModel(rule, model, protocol) {
if !payloadRuleMatchesModels(rule, protocol, candidates) {
continue
}
for path, value := range rule.Params {
@@ -151,10 +57,39 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
appliedDefaults[fullPath] = struct{}{}
}
}
// Apply default raw rules: first write wins per field across all matching rules.
for i := range rules.DefaultRaw {
rule := &rules.DefaultRaw[i]
if !payloadRuleMatchesModels(rule, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
// Apply override rules: last write wins per field across all matching rules.
for i := range rules.Override {
rule := &rules.Override[i]
if !payloadRuleMatchesModel(rule, model, protocol) {
if !payloadRuleMatchesModels(rule, protocol, candidates) {
continue
}
for path, value := range rule.Params {
@@ -169,9 +104,43 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
out = updated
}
}
// Apply override raw rules: last write wins per field across all matching rules.
for i := range rules.OverrideRaw {
rule := &rules.OverrideRaw[i]
if !payloadRuleMatchesModels(rule, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
}
}
return out
}
func payloadRuleMatchesModels(rule *config.PayloadRule, protocol string, models []string) bool {
if rule == nil || len(models) == 0 {
return false
}
for _, model := range models {
if payloadRuleMatchesModel(rule, model, protocol) {
return true
}
}
return false
}
func payloadRuleMatchesModel(rule *config.PayloadRule, model, protocol string) bool {
if rule == nil {
return false
@@ -194,6 +163,65 @@ func payloadRuleMatchesModel(rule *config.PayloadRule, model, protocol string) b
return false
}
func payloadModelCandidates(cfg *config.Config, model, protocol string) []string {
model = strings.TrimSpace(model)
if model == "" {
return nil
}
candidates := []string{model}
if cfg == nil {
return candidates
}
aliases := payloadModelAliases(cfg, model, protocol)
if len(aliases) == 0 {
return candidates
}
seen := map[string]struct{}{strings.ToLower(model): struct{}{}}
for _, alias := range aliases {
alias = strings.TrimSpace(alias)
if alias == "" {
continue
}
key := strings.ToLower(alias)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
candidates = append(candidates, alias)
}
return candidates
}
func payloadModelAliases(cfg *config.Config, model, protocol string) []string {
if cfg == nil {
return nil
}
model = strings.TrimSpace(model)
if model == "" {
return nil
}
channel := strings.ToLower(strings.TrimSpace(protocol))
if channel == "" {
return nil
}
entries := cfg.OAuthModelAlias[channel]
if len(entries) == 0 {
return nil
}
aliases := make([]string, 0, 2)
for _, entry := range entries {
if !strings.EqualFold(strings.TrimSpace(entry.Name), model) {
continue
}
alias := strings.TrimSpace(entry.Alias)
if alias == "" {
continue
}
aliases = append(aliases, alias)
}
return aliases
}
// buildPayloadPath combines an optional root path with a relative parameter path.
// When root is empty, the parameter path is used as-is. When root is non-empty,
// the parameter path is treated as relative to root.
@@ -212,6 +240,24 @@ func buildPayloadPath(root, path string) string {
return r + "." + p
}
func payloadRawValue(value any) ([]byte, bool) {
if value == nil {
return nil, false
}
switch typed := value.(type) {
case string:
return []byte(typed), true
case []byte:
return typed, true
default:
raw, errMarshal := json.Marshal(typed)
if errMarshal != nil {
return nil, false
}
return raw, true
}
}
// matchModelPattern performs simple wildcard matching where '*' matches zero or more characters.
// Examples:
//
@@ -256,102 +302,3 @@ func matchModelPattern(pattern, model string) bool {
}
return pi == len(pattern)
}
// NormalizeThinkingConfig normalizes thinking-related fields in the payload
// based on model capabilities. For models without thinking support, it strips
// reasoning fields. For models with level-based thinking, it validates and
// normalizes the reasoning effort level. For models with numeric budget thinking,
// it strips the effort string fields.
func NormalizeThinkingConfig(payload []byte, model string, allowCompat bool) []byte {
if len(payload) == 0 || model == "" {
return payload
}
if !util.ModelSupportsThinking(model) {
if allowCompat {
return payload
}
return StripThinkingFields(payload, false)
}
if util.ModelUsesThinkingLevels(model) {
return NormalizeReasoningEffortLevel(payload, model)
}
// Model supports thinking but uses numeric budgets, not levels.
// Strip effort string fields since they are not applicable.
return StripThinkingFields(payload, true)
}
// StripThinkingFields removes thinking-related fields from the payload for
// models that do not support thinking. If effortOnly is true, only removes
// effort string fields (for models using numeric budgets).
func StripThinkingFields(payload []byte, effortOnly bool) []byte {
fieldsToRemove := []string{
"reasoning_effort",
"reasoning.effort",
}
if !effortOnly {
fieldsToRemove = append([]string{"reasoning", "thinking"}, fieldsToRemove...)
}
out := payload
for _, field := range fieldsToRemove {
if gjson.GetBytes(out, field).Exists() {
out, _ = sjson.DeleteBytes(out, field)
}
}
return out
}
// NormalizeReasoningEffortLevel validates and normalizes the reasoning_effort
// or reasoning.effort field for level-based thinking models.
func NormalizeReasoningEffortLevel(payload []byte, model string) []byte {
out := payload
if effort := gjson.GetBytes(out, "reasoning_effort"); effort.Exists() {
if normalized, ok := util.NormalizeReasoningEffortLevel(model, effort.String()); ok {
out, _ = sjson.SetBytes(out, "reasoning_effort", normalized)
}
}
if effort := gjson.GetBytes(out, "reasoning.effort"); effort.Exists() {
if normalized, ok := util.NormalizeReasoningEffortLevel(model, effort.String()); ok {
out, _ = sjson.SetBytes(out, "reasoning.effort", normalized)
}
}
return out
}
// ValidateThinkingConfig checks for unsupported reasoning levels on level-based models.
// Returns a statusErr with 400 when an unsupported level is supplied to avoid silently
// downgrading requests.
func ValidateThinkingConfig(payload []byte, model string) error {
if len(payload) == 0 || model == "" {
return nil
}
if !util.ModelSupportsThinking(model) || !util.ModelUsesThinkingLevels(model) {
return nil
}
levels := util.GetModelThinkingLevels(model)
checkField := func(path string) error {
if effort := gjson.GetBytes(payload, path); effort.Exists() {
if _, ok := util.NormalizeReasoningEffortLevel(model, effort.String()); !ok {
return statusErr{
code: http.StatusBadRequest,
msg: fmt.Sprintf("unsupported reasoning effort level %q for model %s (supported: %s)", effort.String(), model, strings.Join(levels, ", ")),
}
}
}
return nil
}
if err := checkField("reasoning_effort"); err != nil {
return err
}
if err := checkField("reasoning.effort"); err != nil {
return err
}
return nil
}

View File

@@ -12,6 +12,7 @@ import (
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -65,12 +66,14 @@ func (e *QwenExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
}
func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
token, baseURL := qwenCreds(auth)
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, baseURL := qwenCreds(auth)
if baseURL == "" {
baseURL = "https://portal.qwen.ai/v1"
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
@@ -79,15 +82,16 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return resp, errValidate
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -140,18 +144,22 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data))
var param any
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
return resp, nil
}
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
token, baseURL := qwenCreds(auth)
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, baseURL := qwenCreds(auth)
if baseURL == "" {
baseURL = "https://portal.qwen.ai/v1"
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
@@ -160,15 +168,15 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
body, _ = sjson.SetBytes(body, "model", baseModel)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
body, _ = sjson.SetBytes(body, "model", req.Model)
body = NormalizeThinkingConfig(body, req.Model, false)
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
return nil, errValidate
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
toolsResult := gjson.GetBytes(body, "tools")
// I'm addressing the Qwen3 "poisoning" issue, which is caused by the model needing a tool to be defined. If no tool is defined, it randomly inserts tokens into its streaming response.
// This will have no real consequences. It's just to scare Qwen3.
@@ -176,7 +184,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
}
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
@@ -256,13 +264,15 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
}
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
modelName := gjson.GetBytes(body, "model").String()
if strings.TrimSpace(modelName) == "" {
modelName = req.Model
modelName = baseModel
}
enc, err := tokenizerForModel(modelName)

View File

@@ -0,0 +1,30 @@
package executor
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
)
func TestQwenExecutorParseSuffix(t *testing.T) {
tests := []struct {
name string
model string
wantBase string
wantLevel string
}{
{"no suffix", "qwen-max", "qwen-max", ""},
{"with level suffix", "qwen-max(high)", "qwen-max", "high"},
{"with budget suffix", "qwen-max(16384)", "qwen-max", "16384"},
{"complex model name", "qwen-plus-latest(medium)", "qwen-plus-latest", "medium"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := thinking.ParseSuffix(tt.model)
if result.ModelName != tt.wantBase {
t.Errorf("ParseSuffix(%q).ModelName = %q, want %q", tt.model, result.ModelName, tt.wantBase)
}
})
}
}

View File

@@ -0,0 +1,11 @@
package executor
import (
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/antigravity"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/codex"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/gemini"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/geminicli"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/iflow"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/openai"
)

View File

@@ -236,6 +236,44 @@ func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
return detail, true
}
func parseOpenAIResponsesUsageDetail(usageNode gjson.Result) usage.Detail {
detail := usage.Detail{
InputTokens: usageNode.Get("input_tokens").Int(),
OutputTokens: usageNode.Get("output_tokens").Int(),
TotalTokens: usageNode.Get("total_tokens").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens
}
if cached := usageNode.Get("input_tokens_details.cached_tokens"); cached.Exists() {
detail.CachedTokens = cached.Int()
}
if reasoning := usageNode.Get("output_tokens_details.reasoning_tokens"); reasoning.Exists() {
detail.ReasoningTokens = reasoning.Int()
}
return detail
}
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() {
return usage.Detail{}
}
return parseOpenAIResponsesUsageDetail(usageNode)
}
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
}
usageNode := gjson.GetBytes(payload, "usage")
if !usageNode.Exists() {
return usage.Detail{}, false
}
return parseOpenAIResponsesUsageDetail(usageNode), true
}
func parseClaudeUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() {

487
internal/thinking/apply.go Normal file
View File

@@ -0,0 +1,487 @@
// Package thinking provides unified thinking configuration processing.
package thinking
import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// providerAppliers maps provider names to their ProviderApplier implementations.
var providerAppliers = map[string]ProviderApplier{
"gemini": nil,
"gemini-cli": nil,
"claude": nil,
"openai": nil,
"codex": nil,
"iflow": nil,
"antigravity": nil,
}
// GetProviderApplier returns the ProviderApplier for the given provider name.
// Returns nil if the provider is not registered.
func GetProviderApplier(provider string) ProviderApplier {
return providerAppliers[provider]
}
// RegisterProvider registers a provider applier by name.
func RegisterProvider(name string, applier ProviderApplier) {
providerAppliers[name] = applier
}
// IsUserDefinedModel reports whether the model is a user-defined model that should
// have thinking configuration passed through without validation.
//
// User-defined models are configured via config file's models[] array
// (e.g., openai-compatibility.*.models[], *-api-key.models[]). These models
// are marked with UserDefined=true at registration time.
//
// User-defined models should have their thinking configuration applied directly,
// letting the upstream service validate the configuration.
func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool {
if modelInfo == nil {
return true
}
return modelInfo.UserDefined
}
// ApplyThinking applies thinking configuration to a request body.
//
// This is the unified entry point for all providers. It follows the processing
// order defined in FR25: route check → model capability query → config extraction
// → validation → application.
//
// Suffix Priority: When the model name includes a thinking suffix (e.g., "gemini-2.5-pro(8192)"),
// the suffix configuration takes priority over any thinking parameters in the request body.
// This enables users to override thinking settings via the model name without modifying their
// request payload.
//
// Parameters:
// - body: Original request body JSON
// - model: Model name, optionally with thinking suffix (e.g., "claude-sonnet-4-5(16384)")
// - fromFormat: Source request format (e.g., openai, codex, gemini)
// - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, iflow)
// - providerKey: Provider identifier used for registry model lookups (may differ from toFormat, e.g., openrouter -> openai)
//
// Returns:
// - Modified request body JSON with thinking configuration applied
// - Error if validation fails (ThinkingError). On error, the original body
// is returned (not nil) to enable defensive programming patterns.
//
// Passthrough behavior (returns original body without error):
// - Unknown provider (not in providerAppliers map)
// - modelInfo.Thinking is nil (model doesn't support thinking)
//
// Note: Unknown models (modelInfo is nil) are treated as user-defined models: we skip
// validation and still apply the thinking config so the upstream can validate it.
//
// Example:
//
// // With suffix - suffix config takes priority
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini", "gemini", "gemini")
//
// // Without suffix - uses body config
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini", "gemini", "gemini")
func ApplyThinking(body []byte, model string, fromFormat string, toFormat string, providerKey string) ([]byte, error) {
providerFormat := strings.ToLower(strings.TrimSpace(toFormat))
providerKey = strings.ToLower(strings.TrimSpace(providerKey))
if providerKey == "" {
providerKey = providerFormat
}
fromFormat = strings.ToLower(strings.TrimSpace(fromFormat))
if fromFormat == "" {
fromFormat = providerFormat
}
// 1. Route check: Get provider applier
applier := GetProviderApplier(providerFormat)
if applier == nil {
log.WithFields(log.Fields{
"provider": providerFormat,
"model": model,
}).Debug("thinking: unknown provider, passthrough |")
return body, nil
}
// 2. Parse suffix and get modelInfo
suffixResult := ParseSuffix(model)
baseModel := suffixResult.ModelName
// Use provider-specific lookup to handle capability differences across providers.
modelInfo := registry.LookupModelInfo(baseModel, providerKey)
// 3. Model capability check
// Unknown models are treated as user-defined so thinking config can still be applied.
// The upstream service is responsible for validating the configuration.
if IsUserDefinedModel(modelInfo) {
return applyUserDefinedModel(body, modelInfo, fromFormat, providerFormat, suffixResult)
}
if modelInfo.Thinking == nil {
config := extractThinkingConfig(body, providerFormat)
if hasThinkingConfig(config) {
log.WithFields(log.Fields{
"model": baseModel,
"provider": providerFormat,
}).Debug("thinking: model does not support thinking, stripping config |")
return StripThinkingConfig(body, providerFormat), nil
}
log.WithFields(log.Fields{
"provider": providerFormat,
"model": baseModel,
}).Debug("thinking: model does not support thinking, passthrough |")
return body, nil
}
// 4. Get config: suffix priority over body
var config ThinkingConfig
if suffixResult.HasSuffix {
config = parseSuffixToConfig(suffixResult.RawSuffix, providerFormat, model)
log.WithFields(log.Fields{
"provider": providerFormat,
"model": model,
"mode": config.Mode,
"budget": config.Budget,
"level": config.Level,
}).Debug("thinking: config from model suffix |")
} else {
config = extractThinkingConfig(body, providerFormat)
if hasThinkingConfig(config) {
log.WithFields(log.Fields{
"provider": providerFormat,
"model": modelInfo.ID,
"mode": config.Mode,
"budget": config.Budget,
"level": config.Level,
}).Debug("thinking: original config from request |")
}
}
if !hasThinkingConfig(config) {
log.WithFields(log.Fields{
"provider": providerFormat,
"model": modelInfo.ID,
}).Debug("thinking: no config found, passthrough |")
return body, nil
}
// 5. Validate and normalize configuration
validated, err := ValidateConfig(config, modelInfo, fromFormat, providerFormat, suffixResult.HasSuffix)
if err != nil {
log.WithFields(log.Fields{
"provider": providerFormat,
"model": modelInfo.ID,
"error": err.Error(),
}).Warn("thinking: validation failed |")
// Return original body on validation failure (defensive programming).
// This ensures callers who ignore the error won't receive nil body.
// The upstream service will decide how to handle the unmodified request.
return body, err
}
// Defensive check: ValidateConfig should never return (nil, nil)
if validated == nil {
log.WithFields(log.Fields{
"provider": providerFormat,
"model": modelInfo.ID,
}).Warn("thinking: ValidateConfig returned nil config without error, passthrough |")
return body, nil
}
log.WithFields(log.Fields{
"provider": providerFormat,
"model": modelInfo.ID,
"mode": validated.Mode,
"budget": validated.Budget,
"level": validated.Level,
}).Debug("thinking: processed config to apply |")
// 6. Apply configuration using provider-specific applier
return applier.Apply(body, *validated, modelInfo)
}
// parseSuffixToConfig converts a raw suffix string to ThinkingConfig.
//
// Parsing priority:
// 1. Special values: "none" → ModeNone, "auto"/"-1" → ModeAuto
// 2. Level names: "minimal", "low", "medium", "high", "xhigh" → ModeLevel
// 3. Numeric values: positive integers → ModeBudget, 0 → ModeNone
//
// If none of the above match, returns empty ThinkingConfig (treated as no config).
func parseSuffixToConfig(rawSuffix, provider, model string) ThinkingConfig {
// 1. Try special values first (none, auto, -1)
if mode, ok := ParseSpecialSuffix(rawSuffix); ok {
switch mode {
case ModeNone:
return ThinkingConfig{Mode: ModeNone, Budget: 0}
case ModeAuto:
return ThinkingConfig{Mode: ModeAuto, Budget: -1}
}
}
// 2. Try level parsing (minimal, low, medium, high, xhigh)
if level, ok := ParseLevelSuffix(rawSuffix); ok {
return ThinkingConfig{Mode: ModeLevel, Level: level}
}
// 3. Try numeric parsing
if budget, ok := ParseNumericSuffix(rawSuffix); ok {
if budget == 0 {
return ThinkingConfig{Mode: ModeNone, Budget: 0}
}
return ThinkingConfig{Mode: ModeBudget, Budget: budget}
}
// Unknown suffix format - return empty config
log.WithFields(log.Fields{
"provider": provider,
"model": model,
"raw_suffix": rawSuffix,
}).Debug("thinking: unknown suffix format, treating as no config |")
return ThinkingConfig{}
}
// applyUserDefinedModel applies thinking configuration for user-defined models
// without ThinkingSupport validation.
func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromFormat, toFormat string, suffixResult SuffixResult) ([]byte, error) {
// Get model ID for logging
modelID := ""
if modelInfo != nil {
modelID = modelInfo.ID
} else {
modelID = suffixResult.ModelName
}
// Get config: suffix priority over body
var config ThinkingConfig
if suffixResult.HasSuffix {
config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID)
} else {
config = extractThinkingConfig(body, toFormat)
}
if !hasThinkingConfig(config) {
log.WithFields(log.Fields{
"model": modelID,
"provider": toFormat,
}).Debug("thinking: user-defined model, passthrough (no config) |")
return body, nil
}
applier := GetProviderApplier(toFormat)
if applier == nil {
log.WithFields(log.Fields{
"model": modelID,
"provider": toFormat,
}).Debug("thinking: user-defined model, passthrough (unknown provider) |")
return body, nil
}
log.WithFields(log.Fields{
"provider": toFormat,
"model": modelID,
"mode": config.Mode,
"budget": config.Budget,
"level": config.Level,
}).Debug("thinking: applying config for user-defined model (skip validation)")
config = normalizeUserDefinedConfig(config, fromFormat, toFormat)
return applier.Apply(body, config, modelInfo)
}
func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat string) ThinkingConfig {
if config.Mode != ModeLevel {
return config
}
if !isBudgetBasedProvider(toFormat) || !isLevelBasedProvider(fromFormat) {
return config
}
budget, ok := ConvertLevelToBudget(string(config.Level))
if !ok {
return config
}
config.Mode = ModeBudget
config.Budget = budget
config.Level = ""
return config
}
// extractThinkingConfig extracts provider-specific thinking config from request body.
func extractThinkingConfig(body []byte, provider string) ThinkingConfig {
if len(body) == 0 || !gjson.ValidBytes(body) {
return ThinkingConfig{}
}
switch provider {
case "claude":
return extractClaudeConfig(body)
case "gemini", "gemini-cli", "antigravity":
return extractGeminiConfig(body, provider)
case "openai":
return extractOpenAIConfig(body)
case "codex":
return extractCodexConfig(body)
case "iflow":
config := extractIFlowConfig(body)
if hasThinkingConfig(config) {
return config
}
return extractOpenAIConfig(body)
default:
return ThinkingConfig{}
}
}
func hasThinkingConfig(config ThinkingConfig) bool {
return config.Mode != ModeBudget || config.Budget != 0 || config.Level != ""
}
// extractClaudeConfig extracts thinking configuration from Claude format request body.
//
// Claude API format:
// - thinking.type: "enabled" or "disabled"
// - thinking.budget_tokens: integer (-1=auto, 0=disabled, >0=budget)
//
// Priority: thinking.type="disabled" takes precedence over budget_tokens.
// When type="enabled" without budget_tokens, returns ModeAuto to indicate
// the user wants thinking enabled but didn't specify a budget.
func extractClaudeConfig(body []byte) ThinkingConfig {
thinkingType := gjson.GetBytes(body, "thinking.type").String()
if thinkingType == "disabled" {
return ThinkingConfig{Mode: ModeNone, Budget: 0}
}
// Check budget_tokens
if budget := gjson.GetBytes(body, "thinking.budget_tokens"); budget.Exists() {
value := int(budget.Int())
switch value {
case 0:
return ThinkingConfig{Mode: ModeNone, Budget: 0}
case -1:
return ThinkingConfig{Mode: ModeAuto, Budget: -1}
default:
return ThinkingConfig{Mode: ModeBudget, Budget: value}
}
}
// If type="enabled" but no budget_tokens, treat as auto (user wants thinking but no budget specified)
if thinkingType == "enabled" {
return ThinkingConfig{Mode: ModeAuto, Budget: -1}
}
return ThinkingConfig{}
}
// extractGeminiConfig extracts thinking configuration from Gemini format request body.
//
// Gemini API format:
// - generationConfig.thinkingConfig.thinkingLevel: "none", "auto", or level name (Gemini 3)
// - generationConfig.thinkingConfig.thinkingBudget: integer (Gemini 2.5)
//
// For gemini-cli and antigravity providers, the path is prefixed with "request.".
//
// Priority: thinkingLevel is checked first (Gemini 3 format), then thinkingBudget (Gemini 2.5 format).
// This allows newer Gemini 3 level-based configs to take precedence.
func extractGeminiConfig(body []byte, provider string) ThinkingConfig {
prefix := "generationConfig.thinkingConfig"
if provider == "gemini-cli" || provider == "antigravity" {
prefix = "request.generationConfig.thinkingConfig"
}
// Check thinkingLevel first (Gemini 3 format takes precedence)
if level := gjson.GetBytes(body, prefix+".thinkingLevel"); level.Exists() {
value := level.String()
switch value {
case "none":
return ThinkingConfig{Mode: ModeNone, Budget: 0}
case "auto":
return ThinkingConfig{Mode: ModeAuto, Budget: -1}
default:
return ThinkingConfig{Mode: ModeLevel, Level: ThinkingLevel(value)}
}
}
// Check thinkingBudget (Gemini 2.5 format)
if budget := gjson.GetBytes(body, prefix+".thinkingBudget"); budget.Exists() {
value := int(budget.Int())
switch value {
case 0:
return ThinkingConfig{Mode: ModeNone, Budget: 0}
case -1:
return ThinkingConfig{Mode: ModeAuto, Budget: -1}
default:
return ThinkingConfig{Mode: ModeBudget, Budget: value}
}
}
return ThinkingConfig{}
}
// extractOpenAIConfig extracts thinking configuration from OpenAI format request body.
//
// OpenAI API format:
// - reasoning_effort: "none", "low", "medium", "high" (discrete levels)
//
// OpenAI uses level-based thinking configuration only, no numeric budget support.
// The "none" value is treated specially to return ModeNone.
func extractOpenAIConfig(body []byte) ThinkingConfig {
// Check reasoning_effort (OpenAI Chat Completions format)
if effort := gjson.GetBytes(body, "reasoning_effort"); effort.Exists() {
value := effort.String()
if value == "none" {
return ThinkingConfig{Mode: ModeNone, Budget: 0}
}
return ThinkingConfig{Mode: ModeLevel, Level: ThinkingLevel(value)}
}
return ThinkingConfig{}
}
// extractCodexConfig extracts thinking configuration from Codex format request body.
//
// Codex API format (OpenAI Responses API):
// - reasoning.effort: "none", "low", "medium", "high"
//
// This is similar to OpenAI but uses nested field "reasoning.effort" instead of "reasoning_effort".
func extractCodexConfig(body []byte) ThinkingConfig {
// Check reasoning.effort (Codex / OpenAI Responses API format)
if effort := gjson.GetBytes(body, "reasoning.effort"); effort.Exists() {
value := effort.String()
if value == "none" {
return ThinkingConfig{Mode: ModeNone, Budget: 0}
}
return ThinkingConfig{Mode: ModeLevel, Level: ThinkingLevel(value)}
}
return ThinkingConfig{}
}
// extractIFlowConfig extracts thinking configuration from iFlow format request body.
//
// iFlow API format (supports multiple model families):
// - GLM format: chat_template_kwargs.enable_thinking (boolean)
// - MiniMax format: reasoning_split (boolean)
//
// Returns ModeBudget with Budget=1 as a sentinel value indicating "enabled".
// The actual budget/configuration is determined by the iFlow applier based on model capabilities.
// Budget=1 is used because iFlow models don't use numeric budgets; they only support on/off.
func extractIFlowConfig(body []byte) ThinkingConfig {
// GLM format: chat_template_kwargs.enable_thinking
if enabled := gjson.GetBytes(body, "chat_template_kwargs.enable_thinking"); enabled.Exists() {
if enabled.Bool() {
// Budget=1 is a sentinel meaning "enabled" (iFlow doesn't use numeric budgets)
return ThinkingConfig{Mode: ModeBudget, Budget: 1}
}
return ThinkingConfig{Mode: ModeNone, Budget: 0}
}
// MiniMax format: reasoning_split
if split := gjson.GetBytes(body, "reasoning_split"); split.Exists() {
if split.Bool() {
// Budget=1 is a sentinel meaning "enabled" (iFlow doesn't use numeric budgets)
return ThinkingConfig{Mode: ModeBudget, Budget: 1}
}
return ThinkingConfig{Mode: ModeNone, Budget: 0}
}
return ThinkingConfig{}
}

View File

@@ -0,0 +1,142 @@
package thinking
import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
// levelToBudgetMap defines the standard Level → Budget mapping.
// All keys are lowercase; lookups should use strings.ToLower.
var levelToBudgetMap = map[string]int{
"none": 0,
"auto": -1,
"minimal": 512,
"low": 1024,
"medium": 8192,
"high": 24576,
"xhigh": 32768,
}
// ConvertLevelToBudget converts a thinking level to a budget value.
//
// This is a semantic conversion that maps discrete levels to numeric budgets.
// Level matching is case-insensitive.
//
// Level → Budget mapping:
// - none → 0
// - auto → -1
// - minimal → 512
// - low → 1024
// - medium → 8192
// - high → 24576
// - xhigh → 32768
//
// Returns:
// - budget: The converted budget value
// - ok: true if level is valid, false otherwise
func ConvertLevelToBudget(level string) (int, bool) {
budget, ok := levelToBudgetMap[strings.ToLower(level)]
return budget, ok
}
// BudgetThreshold constants define the upper bounds for each thinking level.
// These are used by ConvertBudgetToLevel for range-based mapping.
const (
// ThresholdMinimal is the upper bound for "minimal" level (1-512)
ThresholdMinimal = 512
// ThresholdLow is the upper bound for "low" level (513-1024)
ThresholdLow = 1024
// ThresholdMedium is the upper bound for "medium" level (1025-8192)
ThresholdMedium = 8192
// ThresholdHigh is the upper bound for "high" level (8193-24576)
ThresholdHigh = 24576
)
// ConvertBudgetToLevel converts a budget value to the nearest thinking level.
//
// This is a semantic conversion that maps numeric budgets to discrete levels.
// Uses threshold-based mapping for range conversion.
//
// Budget → Level thresholds:
// - -1 → auto
// - 0 → none
// - 1-512 → minimal
// - 513-1024 → low
// - 1025-8192 → medium
// - 8193-24576 → high
// - 24577+ → xhigh
//
// Returns:
// - level: The converted thinking level string
// - ok: true if budget is valid, false for invalid negatives (< -1)
func ConvertBudgetToLevel(budget int) (string, bool) {
switch {
case budget < -1:
// Invalid negative values
return "", false
case budget == -1:
return string(LevelAuto), true
case budget == 0:
return string(LevelNone), true
case budget <= ThresholdMinimal:
return string(LevelMinimal), true
case budget <= ThresholdLow:
return string(LevelLow), true
case budget <= ThresholdMedium:
return string(LevelMedium), true
case budget <= ThresholdHigh:
return string(LevelHigh), true
default:
return string(LevelXHigh), true
}
}
// ModelCapability describes the thinking format support of a model.
type ModelCapability int
const (
// CapabilityUnknown indicates modelInfo is nil (passthrough behavior, internal use).
CapabilityUnknown ModelCapability = iota - 1
// CapabilityNone indicates model doesn't support thinking (Thinking is nil).
CapabilityNone
// CapabilityBudgetOnly indicates the model supports numeric budgets only.
CapabilityBudgetOnly
// CapabilityLevelOnly indicates the model supports discrete levels only.
CapabilityLevelOnly
// CapabilityHybrid indicates the model supports both budgets and levels.
CapabilityHybrid
)
// detectModelCapability determines the thinking format capability of a model.
//
// This is an internal function used by validation and conversion helpers.
// It analyzes the model's ThinkingSupport configuration to classify the model:
// - CapabilityNone: modelInfo.Thinking is nil (model doesn't support thinking)
// - CapabilityBudgetOnly: Has Min/Max but no Levels (Claude, Gemini 2.5)
// - CapabilityLevelOnly: Has Levels but no Min/Max (OpenAI, iFlow)
// - CapabilityHybrid: Has both Min/Max and Levels (Gemini 3)
//
// Note: Returns a special sentinel value when modelInfo itself is nil (unknown model).
func detectModelCapability(modelInfo *registry.ModelInfo) ModelCapability {
if modelInfo == nil {
return CapabilityUnknown // sentinel for "passthrough" behavior
}
if modelInfo.Thinking == nil {
return CapabilityNone
}
support := modelInfo.Thinking
hasBudget := support.Min > 0 || support.Max > 0
hasLevels := len(support.Levels) > 0
switch {
case hasBudget && hasLevels:
return CapabilityHybrid
case hasBudget:
return CapabilityBudgetOnly
case hasLevels:
return CapabilityLevelOnly
default:
return CapabilityNone
}
}

View File

@@ -0,0 +1,82 @@
// Package thinking provides unified thinking configuration processing logic.
package thinking
import "net/http"
// ErrorCode represents the type of thinking configuration error.
type ErrorCode string
// Error codes for thinking configuration processing.
const (
// ErrInvalidSuffix indicates the suffix format cannot be parsed.
// Example: "model(abc" (missing closing parenthesis)
ErrInvalidSuffix ErrorCode = "INVALID_SUFFIX"
// ErrUnknownLevel indicates the level value is not in the valid list.
// Example: "model(ultra)" where "ultra" is not a valid level
ErrUnknownLevel ErrorCode = "UNKNOWN_LEVEL"
// ErrThinkingNotSupported indicates the model does not support thinking.
// Example: claude-haiku-4-5 does not have thinking capability
ErrThinkingNotSupported ErrorCode = "THINKING_NOT_SUPPORTED"
// ErrLevelNotSupported indicates the model does not support level mode.
// Example: using level with a budget-only model
ErrLevelNotSupported ErrorCode = "LEVEL_NOT_SUPPORTED"
// ErrBudgetOutOfRange indicates the budget value is outside model range.
// Example: budget 64000 exceeds max 20000
ErrBudgetOutOfRange ErrorCode = "BUDGET_OUT_OF_RANGE"
// ErrProviderMismatch indicates the provider does not match the model.
// Example: applying Claude format to a Gemini model
ErrProviderMismatch ErrorCode = "PROVIDER_MISMATCH"
)
// ThinkingError represents an error that occurred during thinking configuration processing.
//
// This error type provides structured information about the error, including:
// - Code: A machine-readable error code for programmatic handling
// - Message: A human-readable description of the error
// - Model: The model name related to the error (optional)
// - Details: Additional context information (optional)
type ThinkingError struct {
// Code is the machine-readable error code
Code ErrorCode
// Message is the human-readable error description.
// Should be lowercase, no trailing period, with context if applicable.
Message string
// Model is the model name related to this error (optional)
Model string
// Details contains additional context information (optional)
Details map[string]interface{}
}
// Error implements the error interface.
// Returns the message directly without code prefix.
// Use Code field for programmatic error handling.
func (e *ThinkingError) Error() string {
return e.Message
}
// NewThinkingError creates a new ThinkingError with the given code and message.
func NewThinkingError(code ErrorCode, message string) *ThinkingError {
return &ThinkingError{
Code: code,
Message: message,
}
}
// NewThinkingErrorWithModel creates a new ThinkingError with model context.
func NewThinkingErrorWithModel(code ErrorCode, message, model string) *ThinkingError {
return &ThinkingError{
Code: code,
Message: message,
Model: model,
}
}
// StatusCode implements a portable status code interface for HTTP handlers.
func (e *ThinkingError) StatusCode() int {
return http.StatusBadRequest
}

View File

@@ -0,0 +1,201 @@
// Package antigravity implements thinking configuration for Antigravity API format.
//
// Antigravity uses request.generationConfig.thinkingConfig.* path (same as gemini-cli)
// but requires additional normalization for Claude models:
// - Ensure thinking budget < max_tokens
// - Remove thinkingConfig if budget < minimum allowed
package antigravity
import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier applies thinking configuration for Antigravity API format.
type Applier struct{}
var _ thinking.ProviderApplier = (*Applier)(nil)
// NewApplier creates a new Antigravity thinking applier.
func NewApplier() *Applier {
return &Applier{}
}
func init() {
thinking.RegisterProvider("antigravity", NewApplier())
}
// Apply applies thinking configuration to Antigravity request body.
//
// For Claude models, additional constraints are applied:
// - Ensure thinking budget < max_tokens
// - Remove thinkingConfig if budget < minimum allowed
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
if thinking.IsUserDefinedModel(modelInfo) {
return a.applyCompatible(body, config, modelInfo)
}
if modelInfo.Thinking == nil {
return body, nil
}
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
return body, nil
}
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
isClaude := strings.Contains(strings.ToLower(modelInfo.ID), "claude")
// ModeAuto: Always use Budget format with thinkingBudget=-1
if config.Mode == thinking.ModeAuto {
return a.applyBudgetFormat(body, config, modelInfo, isClaude)
}
if config.Mode == thinking.ModeBudget {
return a.applyBudgetFormat(body, config, modelInfo, isClaude)
}
// For non-auto modes, choose format based on model capabilities
support := modelInfo.Thinking
if len(support.Levels) > 0 {
return a.applyLevelFormat(body, config)
}
return a.applyBudgetFormat(body, config, modelInfo, isClaude)
}
func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
return body, nil
}
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
isClaude := false
if modelInfo != nil {
isClaude = strings.Contains(strings.ToLower(modelInfo.ID), "claude")
}
if config.Mode == thinking.ModeAuto {
return a.applyBudgetFormat(body, config, modelInfo, isClaude)
}
if config.Mode == thinking.ModeLevel || (config.Mode == thinking.ModeNone && config.Level != "") {
return a.applyLevelFormat(body, config)
}
return a.applyBudgetFormat(body, config, modelInfo, isClaude)
}
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget")
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
if config.Mode == thinking.ModeNone {
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", false)
if config.Level != "" {
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", string(config.Level))
}
return result, nil
}
// Only handle ModeLevel - budget conversion should be done by upper layer
if config.Mode != thinking.ModeLevel {
return body, nil
}
level := string(config.Level)
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", level)
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", true)
return result, nil
}
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo, isClaude bool) ([]byte, error) {
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingLevel")
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
budget := config.Budget
includeThoughts := false
switch config.Mode {
case thinking.ModeNone:
includeThoughts = false
case thinking.ModeAuto:
includeThoughts = true
default:
includeThoughts = budget > 0
}
// Apply Claude-specific constraints
if isClaude && modelInfo != nil {
budget, result = a.normalizeClaudeBudget(budget, result, modelInfo)
// Check if budget was removed entirely
if budget == -2 {
return result, nil
}
}
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
return result, nil
}
// normalizeClaudeBudget applies Claude-specific constraints to thinking budget.
//
// It handles:
// - Ensuring thinking budget < max_tokens
// - Removing thinkingConfig if budget < minimum allowed
//
// Returns the normalized budget and updated payload.
// Returns budget=-2 as a sentinel indicating thinkingConfig was removed entirely.
func (a *Applier) normalizeClaudeBudget(budget int, payload []byte, modelInfo *registry.ModelInfo) (int, []byte) {
if modelInfo == nil {
return budget, payload
}
// Get effective max tokens
effectiveMax, setDefaultMax := a.effectiveMaxTokens(payload, modelInfo)
if effectiveMax > 0 && budget >= effectiveMax {
budget = effectiveMax - 1
}
// Check minimum budget
minBudget := 0
if modelInfo.Thinking != nil {
minBudget = modelInfo.Thinking.Min
}
if minBudget > 0 && budget >= 0 && budget < minBudget {
// Budget is below minimum, remove thinking config entirely
payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.thinkingConfig")
return -2, payload
}
// Set default max tokens if needed
if setDefaultMax && effectiveMax > 0 {
payload, _ = sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", effectiveMax)
}
return budget, payload
}
// effectiveMaxTokens returns the max tokens to cap thinking:
// prefer request-provided maxOutputTokens; otherwise fall back to model default.
// The boolean indicates whether the value came from the model default (and thus should be written back).
func (a *Applier) effectiveMaxTokens(payload []byte, modelInfo *registry.ModelInfo) (max int, fromModel bool) {
if maxTok := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens"); maxTok.Exists() && maxTok.Int() > 0 {
return int(maxTok.Int()), false
}
if modelInfo != nil && modelInfo.MaxCompletionTokens > 0 {
return modelInfo.MaxCompletionTokens, true
}
return 0, false
}

View File

@@ -0,0 +1,166 @@
// Package claude implements thinking configuration scaffolding for Claude models.
//
// Claude models use the thinking.budget_tokens format with values in the range
// 1024-128000. Some Claude models support ZeroAllowed (sonnet-4-5, opus-4-5),
// while older models do not.
// See: _bmad-output/planning-artifacts/architecture.md#Epic-6
package claude
import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier implements thinking.ProviderApplier for Claude models.
// This applier is stateless and holds no configuration.
type Applier struct{}
// NewApplier creates a new Claude thinking applier.
func NewApplier() *Applier {
return &Applier{}
}
func init() {
thinking.RegisterProvider("claude", NewApplier())
}
// Apply applies thinking configuration to Claude request body.
//
// IMPORTANT: This method expects config to be pre-validated by thinking.ValidateConfig.
// ValidateConfig handles:
// - Mode conversion (Level→Budget, Auto→Budget)
// - Budget clamping to model range
// - ZeroAllowed constraint enforcement
//
// Apply only processes ModeBudget and ModeNone; other modes are passed through unchanged.
//
// Expected output format when enabled:
//
// {
// "thinking": {
// "type": "enabled",
// "budget_tokens": 16384
// }
// }
//
// Expected output format when disabled:
//
// {
// "thinking": {
// "type": "disabled"
// }
// }
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
if thinking.IsUserDefinedModel(modelInfo) {
return applyCompatibleClaude(body, config)
}
if modelInfo.Thinking == nil {
return body, nil
}
// Only process ModeBudget and ModeNone; other modes pass through
// (caller should use ValidateConfig first to normalize modes)
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone {
return body, nil
}
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
// Budget is expected to be pre-validated by ValidateConfig (clamped, ZeroAllowed enforced)
// Decide enabled/disabled based on budget value
if config.Budget == 0 {
result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
return result, nil
}
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
// Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint)
result = a.normalizeClaudeBudget(result, config.Budget, modelInfo)
return result, nil
}
// normalizeClaudeBudget applies Claude-specific constraints to ensure max_tokens > budget_tokens.
// Anthropic API requires this constraint; violating it returns a 400 error.
func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo *registry.ModelInfo) []byte {
if budgetTokens <= 0 {
return body
}
// Ensure the request satisfies Claude constraints:
// 1) Determine effective max_tokens (request overrides model default)
// 2) If budget_tokens >= max_tokens, reduce budget_tokens to max_tokens-1
// 3) If the adjusted budget falls below the model minimum, leave the request unchanged
// 4) If max_tokens came from model default, write it back into the request
effectiveMax, setDefaultMax := a.effectiveMaxTokens(body, modelInfo)
if setDefaultMax && effectiveMax > 0 {
body, _ = sjson.SetBytes(body, "max_tokens", effectiveMax)
}
// Compute the budget we would apply after enforcing budget_tokens < max_tokens.
adjustedBudget := budgetTokens
if effectiveMax > 0 && adjustedBudget >= effectiveMax {
adjustedBudget = effectiveMax - 1
}
minBudget := 0
if modelInfo != nil && modelInfo.Thinking != nil {
minBudget = modelInfo.Thinking.Min
}
if minBudget > 0 && adjustedBudget > 0 && adjustedBudget < minBudget {
// If enforcing the max_tokens constraint would push the budget below the model minimum,
// leave the request unchanged.
return body
}
if adjustedBudget != budgetTokens {
body, _ = sjson.SetBytes(body, "thinking.budget_tokens", adjustedBudget)
}
return body
}
// effectiveMaxTokens returns the max tokens to cap thinking:
// prefer request-provided max_tokens; otherwise fall back to model default.
// The boolean indicates whether the value came from the model default (and thus should be written back).
func (a *Applier) effectiveMaxTokens(body []byte, modelInfo *registry.ModelInfo) (max int, fromModel bool) {
if maxTok := gjson.GetBytes(body, "max_tokens"); maxTok.Exists() && maxTok.Int() > 0 {
return int(maxTok.Int()), false
}
if modelInfo != nil && modelInfo.MaxCompletionTokens > 0 {
return modelInfo.MaxCompletionTokens, true
}
return 0, false
}
func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
return body, nil
}
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
switch config.Mode {
case thinking.ModeNone:
result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
return result, nil
case thinking.ModeAuto:
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
return result, nil
default:
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
return result, nil
}
}

View File

@@ -0,0 +1,131 @@
// Package codex implements thinking configuration for Codex (OpenAI Responses API) models.
//
// Codex models use the reasoning.effort format with discrete levels
// (low/medium/high). This is similar to OpenAI but uses nested field
// "reasoning.effort" instead of "reasoning_effort".
// See: _bmad-output/planning-artifacts/architecture.md#Epic-8
package codex
import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier implements thinking.ProviderApplier for Codex models.
//
// Codex-specific behavior:
// - Output format: reasoning.effort (string: low/medium/high/xhigh)
// - Level-only mode: no numeric budget support
// - Some models support ZeroAllowed (gpt-5.1, gpt-5.2)
type Applier struct{}
var _ thinking.ProviderApplier = (*Applier)(nil)
// NewApplier creates a new Codex thinking applier.
func NewApplier() *Applier {
return &Applier{}
}
func init() {
thinking.RegisterProvider("codex", NewApplier())
}
// Apply applies thinking configuration to Codex request body.
//
// Expected output format:
//
// {
// "reasoning": {
// "effort": "high"
// }
// }
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
if thinking.IsUserDefinedModel(modelInfo) {
return applyCompatibleCodex(body, config)
}
if modelInfo.Thinking == nil {
return body, nil
}
// Only handle ModeLevel and ModeNone; other modes pass through unchanged.
if config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone {
return body, nil
}
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
if config.Mode == thinking.ModeLevel {
result, _ := sjson.SetBytes(body, "reasoning.effort", string(config.Level))
return result, nil
}
effort := ""
support := modelInfo.Thinking
if config.Budget == 0 {
if support.ZeroAllowed || hasLevel(support.Levels, string(thinking.LevelNone)) {
effort = string(thinking.LevelNone)
}
}
if effort == "" && config.Level != "" {
effort = string(config.Level)
}
if effort == "" && len(support.Levels) > 0 {
effort = support.Levels[0]
}
if effort == "" {
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning.effort", effort)
return result, nil
}
func applyCompatibleCodex(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
var effort string
switch config.Mode {
case thinking.ModeLevel:
if config.Level == "" {
return body, nil
}
effort = string(config.Level)
case thinking.ModeNone:
effort = string(thinking.LevelNone)
if config.Level != "" {
effort = string(config.Level)
}
case thinking.ModeAuto:
// Auto mode for user-defined models: pass through as "auto"
effort = string(thinking.LevelAuto)
case thinking.ModeBudget:
// Budget mode: convert budget to level using threshold mapping
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
if !ok {
return body, nil
}
effort = level
default:
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning.effort", effort)
return result, nil
}
func hasLevel(levels []string, target string) bool {
for _, level := range levels {
if strings.EqualFold(strings.TrimSpace(level), target) {
return true
}
}
return false
}

View File

@@ -0,0 +1,169 @@
// Package gemini implements thinking configuration for Gemini models.
//
// Gemini models have two formats:
// - Gemini 2.5: Uses thinkingBudget (numeric)
// - Gemini 3.x: Uses thinkingLevel (string: minimal/low/medium/high)
// or thinkingBudget=-1 for auto/dynamic mode
//
// Output format is determined by ThinkingConfig.Mode and ThinkingSupport.Levels:
// - ModeAuto: Always uses thinkingBudget=-1 (both Gemini 2.5 and 3.x)
// - len(Levels) > 0: Uses thinkingLevel (Gemini 3.x discrete levels)
// - len(Levels) == 0: Uses thinkingBudget (Gemini 2.5)
package gemini
import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier applies thinking configuration for Gemini models.
//
// Gemini-specific behavior:
// - Gemini 2.5: thinkingBudget format, flash series supports ZeroAllowed
// - Gemini 3.x: thinkingLevel format, cannot be disabled
// - Use ThinkingSupport.Levels to decide output format
type Applier struct{}
// NewApplier creates a new Gemini thinking applier.
func NewApplier() *Applier {
return &Applier{}
}
func init() {
thinking.RegisterProvider("gemini", NewApplier())
}
// Apply applies thinking configuration to Gemini request body.
//
// Expected output format (Gemini 2.5):
//
// {
// "generationConfig": {
// "thinkingConfig": {
// "thinkingBudget": 8192,
// "includeThoughts": true
// }
// }
// }
//
// Expected output format (Gemini 3.x):
//
// {
// "generationConfig": {
// "thinkingConfig": {
// "thinkingLevel": "high",
// "includeThoughts": true
// }
// }
// }
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
if thinking.IsUserDefinedModel(modelInfo) {
return a.applyCompatible(body, config)
}
if modelInfo.Thinking == nil {
return body, nil
}
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
return body, nil
}
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
// Choose format based on config.Mode and model capabilities:
// - ModeLevel: use Level format (validation will reject unsupported levels)
// - ModeNone: use Level format if model has Levels, else Budget format
// - ModeBudget/ModeAuto: use Budget format
switch config.Mode {
case thinking.ModeLevel:
return a.applyLevelFormat(body, config)
case thinking.ModeNone:
// ModeNone: route based on model capability (has Levels or not)
if len(modelInfo.Thinking.Levels) > 0 {
return a.applyLevelFormat(body, config)
}
return a.applyBudgetFormat(body, config)
default:
return a.applyBudgetFormat(body, config)
}
}
func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
return body, nil
}
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
if config.Mode == thinking.ModeAuto {
return a.applyBudgetFormat(body, config)
}
if config.Mode == thinking.ModeLevel || (config.Mode == thinking.ModeNone && config.Level != "") {
return a.applyLevelFormat(body, config)
}
return a.applyBudgetFormat(body, config)
}
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
// ModeNone semantics:
// - ModeNone + Budget=0: completely disable thinking (not possible for Level-only models)
// - ModeNone + Budget>0: forced to think but hide output (includeThoughts=false)
// ValidateConfig sets config.Level to the lowest level when ModeNone + Budget > 0.
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
result, _ := sjson.DeleteBytes(body, "generationConfig.thinkingConfig.thinkingBudget")
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.include_thoughts")
if config.Mode == thinking.ModeNone {
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", false)
if config.Level != "" {
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingLevel", string(config.Level))
}
return result, nil
}
// Only handle ModeLevel - budget conversion should be done by upper layer
if config.Mode != thinking.ModeLevel {
return body, nil
}
level := string(config.Level)
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingLevel", level)
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", true)
return result, nil
}
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
result, _ := sjson.DeleteBytes(body, "generationConfig.thinkingConfig.thinkingLevel")
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.include_thoughts")
budget := config.Budget
// ModeNone semantics:
// - ModeNone + Budget=0: completely disable thinking
// - ModeNone + Budget>0: forced to think but hide output (includeThoughts=false)
// When ZeroAllowed=false, ValidateConfig clamps Budget to Min while preserving ModeNone.
includeThoughts := false
switch config.Mode {
case thinking.ModeNone:
includeThoughts = false
case thinking.ModeAuto:
includeThoughts = true
default:
includeThoughts = budget > 0
}
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingBudget", budget)
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", includeThoughts)
return result, nil
}

View File

@@ -0,0 +1,126 @@
// Package geminicli implements thinking configuration for Gemini CLI API format.
//
// Gemini CLI uses request.generationConfig.thinkingConfig.* path instead of
// generationConfig.thinkingConfig.* used by standard Gemini API.
package geminicli
import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier applies thinking configuration for Gemini CLI API format.
type Applier struct{}
var _ thinking.ProviderApplier = (*Applier)(nil)
// NewApplier creates a new Gemini CLI thinking applier.
func NewApplier() *Applier {
return &Applier{}
}
func init() {
thinking.RegisterProvider("gemini-cli", NewApplier())
}
// Apply applies thinking configuration to Gemini CLI request body.
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
if thinking.IsUserDefinedModel(modelInfo) {
return a.applyCompatible(body, config)
}
if modelInfo.Thinking == nil {
return body, nil
}
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
return body, nil
}
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
// ModeAuto: Always use Budget format with thinkingBudget=-1
if config.Mode == thinking.ModeAuto {
return a.applyBudgetFormat(body, config)
}
if config.Mode == thinking.ModeBudget {
return a.applyBudgetFormat(body, config)
}
// For non-auto modes, choose format based on model capabilities
support := modelInfo.Thinking
if len(support.Levels) > 0 {
return a.applyLevelFormat(body, config)
}
return a.applyBudgetFormat(body, config)
}
func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
return body, nil
}
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
if config.Mode == thinking.ModeAuto {
return a.applyBudgetFormat(body, config)
}
if config.Mode == thinking.ModeLevel || (config.Mode == thinking.ModeNone && config.Level != "") {
return a.applyLevelFormat(body, config)
}
return a.applyBudgetFormat(body, config)
}
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget")
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
if config.Mode == thinking.ModeNone {
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", false)
if config.Level != "" {
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", string(config.Level))
}
return result, nil
}
// Only handle ModeLevel - budget conversion should be done by upper layer
if config.Mode != thinking.ModeLevel {
return body, nil
}
level := string(config.Level)
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", level)
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", true)
return result, nil
}
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingLevel")
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
budget := config.Budget
includeThoughts := false
switch config.Mode {
case thinking.ModeNone:
includeThoughts = false
case thinking.ModeAuto:
includeThoughts = true
default:
includeThoughts = budget > 0
}
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
return result, nil
}

View File

@@ -0,0 +1,156 @@
// Package iflow implements thinking configuration for iFlow models (GLM, MiniMax).
//
// iFlow models use boolean toggle semantics:
// - GLM models: chat_template_kwargs.enable_thinking (boolean)
// - MiniMax models: reasoning_split (boolean)
//
// Level values are converted to boolean: none=false, all others=true
// See: _bmad-output/planning-artifacts/architecture.md#Epic-9
package iflow
import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier implements thinking.ProviderApplier for iFlow models.
//
// iFlow-specific behavior:
// - GLM models: enable_thinking boolean + clear_thinking=false
// - MiniMax models: reasoning_split boolean
// - Level to boolean: none=false, others=true
// - No quantized support (only on/off)
type Applier struct{}
var _ thinking.ProviderApplier = (*Applier)(nil)
// NewApplier creates a new iFlow thinking applier.
func NewApplier() *Applier {
return &Applier{}
}
func init() {
thinking.RegisterProvider("iflow", NewApplier())
}
// Apply applies thinking configuration to iFlow request body.
//
// Expected output format (GLM):
//
// {
// "chat_template_kwargs": {
// "enable_thinking": true,
// "clear_thinking": false
// }
// }
//
// Expected output format (MiniMax):
//
// {
// "reasoning_split": true
// }
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
if thinking.IsUserDefinedModel(modelInfo) {
return body, nil
}
if modelInfo.Thinking == nil {
return body, nil
}
if isGLMModel(modelInfo.ID) {
return applyGLM(body, config), nil
}
if isMiniMaxModel(modelInfo.ID) {
return applyMiniMax(body, config), nil
}
return body, nil
}
// configToBoolean converts ThinkingConfig to boolean for iFlow models.
//
// Conversion rules:
// - ModeNone: false
// - ModeAuto: true
// - ModeBudget + Budget=0: false
// - ModeBudget + Budget>0: true
// - ModeLevel + Level="none": false
// - ModeLevel + any other level: true
// - Default (unknown mode): true
func configToBoolean(config thinking.ThinkingConfig) bool {
switch config.Mode {
case thinking.ModeNone:
return false
case thinking.ModeAuto:
return true
case thinking.ModeBudget:
return config.Budget > 0
case thinking.ModeLevel:
return config.Level != thinking.LevelNone
default:
return true
}
}
// applyGLM applies thinking configuration for GLM models.
//
// Output format when enabled:
//
// {"chat_template_kwargs": {"enable_thinking": true, "clear_thinking": false}}
//
// Output format when disabled:
//
// {"chat_template_kwargs": {"enable_thinking": false}}
//
// Note: clear_thinking is only set when thinking is enabled, to preserve
// thinking output in the response.
func applyGLM(body []byte, config thinking.ThinkingConfig) []byte {
enableThinking := configToBoolean(config)
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
result, _ := sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
// clear_thinking only needed when thinking is enabled
if enableThinking {
result, _ = sjson.SetBytes(result, "chat_template_kwargs.clear_thinking", false)
}
return result
}
// applyMiniMax applies thinking configuration for MiniMax models.
//
// Output format:
//
// {"reasoning_split": true/false}
func applyMiniMax(body []byte, config thinking.ThinkingConfig) []byte {
reasoningSplit := configToBoolean(config)
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
result, _ := sjson.SetBytes(body, "reasoning_split", reasoningSplit)
return result
}
// isGLMModel determines if the model is a GLM series model.
// GLM models use chat_template_kwargs.enable_thinking format.
func isGLMModel(modelID string) bool {
return strings.HasPrefix(strings.ToLower(modelID), "glm")
}
// isMiniMaxModel determines if the model is a MiniMax series model.
// MiniMax models use reasoning_split format.
func isMiniMaxModel(modelID string) bool {
return strings.HasPrefix(strings.ToLower(modelID), "minimax")
}

View File

@@ -0,0 +1,128 @@
// Package openai implements thinking configuration for OpenAI/Codex models.
//
// OpenAI models use the reasoning_effort format with discrete levels
// (low/medium/high). Some models support xhigh and none levels.
// See: _bmad-output/planning-artifacts/architecture.md#Epic-8
package openai
import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier implements thinking.ProviderApplier for OpenAI models.
//
// OpenAI-specific behavior:
// - Output format: reasoning_effort (string: low/medium/high/xhigh)
// - Level-only mode: no numeric budget support
// - Some models support ZeroAllowed (gpt-5.1, gpt-5.2)
type Applier struct{}
var _ thinking.ProviderApplier = (*Applier)(nil)
// NewApplier creates a new OpenAI thinking applier.
func NewApplier() *Applier {
return &Applier{}
}
func init() {
thinking.RegisterProvider("openai", NewApplier())
}
// Apply applies thinking configuration to OpenAI request body.
//
// Expected output format:
//
// {
// "reasoning_effort": "high"
// }
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
if thinking.IsUserDefinedModel(modelInfo) {
return applyCompatibleOpenAI(body, config)
}
if modelInfo.Thinking == nil {
return body, nil
}
// Only handle ModeLevel and ModeNone; other modes pass through unchanged.
if config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone {
return body, nil
}
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
if config.Mode == thinking.ModeLevel {
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
return result, nil
}
effort := ""
support := modelInfo.Thinking
if config.Budget == 0 {
if support.ZeroAllowed || hasLevel(support.Levels, string(thinking.LevelNone)) {
effort = string(thinking.LevelNone)
}
}
if effort == "" && config.Level != "" {
effort = string(config.Level)
}
if effort == "" && len(support.Levels) > 0 {
effort = support.Levels[0]
}
if effort == "" {
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
return result, nil
}
func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
var effort string
switch config.Mode {
case thinking.ModeLevel:
if config.Level == "" {
return body, nil
}
effort = string(config.Level)
case thinking.ModeNone:
effort = string(thinking.LevelNone)
if config.Level != "" {
effort = string(config.Level)
}
case thinking.ModeAuto:
// Auto mode for user-defined models: pass through as "auto"
effort = string(thinking.LevelAuto)
case thinking.ModeBudget:
// Budget mode: convert budget to level using threshold mapping
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
if !ok {
return body, nil
}
effort = level
default:
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
return result, nil
}
func hasLevel(levels []string, target string) bool {
for _, level := range levels {
if strings.EqualFold(strings.TrimSpace(level), target) {
return true
}
}
return false
}

View File

@@ -0,0 +1,58 @@
// Package thinking provides unified thinking configuration processing.
package thinking
import (
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// StripThinkingConfig removes thinking configuration fields from request body.
//
// This function is used when a model doesn't support thinking but the request
// contains thinking configuration. The configuration is silently removed to
// prevent upstream API errors.
//
// Parameters:
// - body: Original request body JSON
// - provider: Provider name (determines which fields to strip)
//
// Returns:
// - Modified request body JSON with thinking configuration removed
// - Original body is returned unchanged if:
// - body is empty or invalid JSON
// - provider is unknown
// - no thinking configuration found
func StripThinkingConfig(body []byte, provider string) []byte {
if len(body) == 0 || !gjson.ValidBytes(body) {
return body
}
var paths []string
switch provider {
case "claude":
paths = []string{"thinking"}
case "gemini":
paths = []string{"generationConfig.thinkingConfig"}
case "gemini-cli", "antigravity":
paths = []string{"request.generationConfig.thinkingConfig"}
case "openai":
paths = []string{"reasoning_effort"}
case "codex":
paths = []string{"reasoning.effort"}
case "iflow":
paths = []string{
"chat_template_kwargs.enable_thinking",
"chat_template_kwargs.clear_thinking",
"reasoning_split",
"reasoning_effort",
}
default:
return body
}
result := body
for _, path := range paths {
result, _ = sjson.DeleteBytes(result, path)
}
return result
}

146
internal/thinking/suffix.go Normal file
View File

@@ -0,0 +1,146 @@
// Package thinking provides unified thinking configuration processing.
//
// This file implements suffix parsing functionality for extracting
// thinking configuration from model names in the format model(value).
package thinking
import (
"strconv"
"strings"
)
// ParseSuffix extracts thinking suffix from a model name.
//
// The suffix format is: model-name(value)
// Examples:
// - "claude-sonnet-4-5(16384)" -> ModelName="claude-sonnet-4-5", RawSuffix="16384"
// - "gpt-5.2(high)" -> ModelName="gpt-5.2", RawSuffix="high"
// - "gemini-2.5-pro" -> ModelName="gemini-2.5-pro", HasSuffix=false
//
// This function only extracts the suffix; it does not validate or interpret
// the suffix content. Use ParseNumericSuffix, ParseLevelSuffix, etc. for
// content interpretation.
func ParseSuffix(model string) SuffixResult {
// Find the last opening parenthesis
lastOpen := strings.LastIndex(model, "(")
if lastOpen == -1 {
return SuffixResult{ModelName: model, HasSuffix: false}
}
// Check if the string ends with a closing parenthesis
if !strings.HasSuffix(model, ")") {
return SuffixResult{ModelName: model, HasSuffix: false}
}
// Extract components
modelName := model[:lastOpen]
rawSuffix := model[lastOpen+1 : len(model)-1]
return SuffixResult{
ModelName: modelName,
HasSuffix: true,
RawSuffix: rawSuffix,
}
}
// ParseNumericSuffix attempts to parse a raw suffix as a numeric budget value.
//
// This function parses the raw suffix content (from ParseSuffix.RawSuffix) as an integer.
// Only non-negative integers are considered valid numeric suffixes.
//
// Platform note: The budget value uses Go's int type, which is 32-bit on 32-bit
// systems and 64-bit on 64-bit systems. Values exceeding the platform's int range
// will return ok=false.
//
// Leading zeros are accepted: "08192" parses as 8192.
//
// Examples:
// - "8192" -> budget=8192, ok=true
// - "0" -> budget=0, ok=true (represents ModeNone)
// - "08192" -> budget=8192, ok=true (leading zeros accepted)
// - "-1" -> budget=0, ok=false (negative numbers are not valid numeric suffixes)
// - "high" -> budget=0, ok=false (not a number)
// - "9223372036854775808" -> budget=0, ok=false (overflow on 64-bit systems)
//
// For special handling of -1 as auto mode, use ParseSpecialSuffix instead.
func ParseNumericSuffix(rawSuffix string) (budget int, ok bool) {
if rawSuffix == "" {
return 0, false
}
value, err := strconv.Atoi(rawSuffix)
if err != nil {
return 0, false
}
// Negative numbers are not valid numeric suffixes
// -1 should be handled by special value parsing as "auto"
if value < 0 {
return 0, false
}
return value, true
}
// ParseSpecialSuffix attempts to parse a raw suffix as a special thinking mode value.
//
// This function handles special strings that represent a change in thinking mode:
// - "none" -> ModeNone (disables thinking)
// - "auto" -> ModeAuto (automatic/dynamic thinking)
// - "-1" -> ModeAuto (numeric representation of auto mode)
//
// String values are case-insensitive.
func ParseSpecialSuffix(rawSuffix string) (mode ThinkingMode, ok bool) {
if rawSuffix == "" {
return ModeBudget, false
}
// Case-insensitive matching
switch strings.ToLower(rawSuffix) {
case "none":
return ModeNone, true
case "auto", "-1":
return ModeAuto, true
default:
return ModeBudget, false
}
}
// ParseLevelSuffix attempts to parse a raw suffix as a discrete thinking level.
//
// This function parses the raw suffix content (from ParseSuffix.RawSuffix) as a level.
// Only discrete effort levels are valid: minimal, low, medium, high, xhigh.
// Level matching is case-insensitive.
//
// Special values (none, auto) are NOT handled by this function; use ParseSpecialSuffix
// instead. This separation allows callers to prioritize special value handling.
//
// Examples:
// - "high" -> level=LevelHigh, ok=true
// - "HIGH" -> level=LevelHigh, ok=true (case insensitive)
// - "medium" -> level=LevelMedium, ok=true
// - "none" -> level="", ok=false (special value, use ParseSpecialSuffix)
// - "auto" -> level="", ok=false (special value, use ParseSpecialSuffix)
// - "8192" -> level="", ok=false (numeric, use ParseNumericSuffix)
// - "ultra" -> level="", ok=false (unknown level)
func ParseLevelSuffix(rawSuffix string) (level ThinkingLevel, ok bool) {
if rawSuffix == "" {
return "", false
}
// Case-insensitive matching
switch strings.ToLower(rawSuffix) {
case "minimal":
return LevelMinimal, true
case "low":
return LevelLow, true
case "medium":
return LevelMedium, true
case "high":
return LevelHigh, true
case "xhigh":
return LevelXHigh, true
default:
return "", false
}
}

41
internal/thinking/text.go Normal file
View File

@@ -0,0 +1,41 @@
package thinking
import (
"github.com/tidwall/gjson"
)
// GetThinkingText extracts the thinking text from a content part.
// Handles various formats:
// - Simple string: { "thinking": "text" } or { "text": "text" }
// - Wrapped object: { "thinking": { "text": "text", "cache_control": {...} } }
// - Gemini-style: { "thought": true, "text": "text" }
// Returns the extracted text string.
func GetThinkingText(part gjson.Result) string {
// Try direct text field first (Gemini-style)
if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
return text.String()
}
// Try thinking field
thinkingField := part.Get("thinking")
if !thinkingField.Exists() {
return ""
}
// thinking is a string
if thinkingField.Type == gjson.String {
return thinkingField.String()
}
// thinking is an object with inner text/thinking
if thinkingField.IsObject() {
if inner := thinkingField.Get("text"); inner.Exists() && inner.Type == gjson.String {
return inner.String()
}
if inner := thinkingField.Get("thinking"); inner.Exists() && inner.Type == gjson.String {
return inner.String()
}
}
return ""
}

116
internal/thinking/types.go Normal file
View File

@@ -0,0 +1,116 @@
// Package thinking provides unified thinking configuration processing.
//
// This package offers a unified interface for parsing, validating, and applying
// thinking configurations across various AI providers (Claude, Gemini, OpenAI, iFlow).
package thinking
import "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
// ThinkingMode represents the type of thinking configuration mode.
type ThinkingMode int
const (
// ModeBudget indicates using a numeric budget (corresponds to suffix "(1000)" etc.)
ModeBudget ThinkingMode = iota
// ModeLevel indicates using a discrete level (corresponds to suffix "(high)" etc.)
ModeLevel
// ModeNone indicates thinking is disabled (corresponds to suffix "(none)" or budget=0)
ModeNone
// ModeAuto indicates automatic/dynamic thinking (corresponds to suffix "(auto)" or budget=-1)
ModeAuto
)
// String returns the string representation of ThinkingMode.
func (m ThinkingMode) String() string {
switch m {
case ModeBudget:
return "budget"
case ModeLevel:
return "level"
case ModeNone:
return "none"
case ModeAuto:
return "auto"
default:
return "unknown"
}
}
// ThinkingLevel represents a discrete thinking level.
type ThinkingLevel string
const (
// LevelNone disables thinking
LevelNone ThinkingLevel = "none"
// LevelAuto enables automatic/dynamic thinking
LevelAuto ThinkingLevel = "auto"
// LevelMinimal sets minimal thinking effort
LevelMinimal ThinkingLevel = "minimal"
// LevelLow sets low thinking effort
LevelLow ThinkingLevel = "low"
// LevelMedium sets medium thinking effort
LevelMedium ThinkingLevel = "medium"
// LevelHigh sets high thinking effort
LevelHigh ThinkingLevel = "high"
// LevelXHigh sets extra-high thinking effort
LevelXHigh ThinkingLevel = "xhigh"
)
// ThinkingConfig represents a unified thinking configuration.
//
// This struct is used to pass thinking configuration information between components.
// Depending on Mode, either Budget or Level field is effective:
// - ModeNone: Budget=0, Level is ignored
// - ModeAuto: Budget=-1, Level is ignored
// - ModeBudget: Budget is a positive integer, Level is ignored
// - ModeLevel: Budget is ignored, Level is a valid level
type ThinkingConfig struct {
// Mode specifies the configuration mode
Mode ThinkingMode
// Budget is the thinking budget (token count), only effective when Mode is ModeBudget.
// Special values: 0 means disabled, -1 means automatic
Budget int
// Level is the thinking level, only effective when Mode is ModeLevel
Level ThinkingLevel
}
// SuffixResult represents the result of parsing a model name for thinking suffix.
//
// A thinking suffix is specified in the format model-name(value), where value
// can be a numeric budget (e.g., "16384") or a level name (e.g., "high").
type SuffixResult struct {
// ModelName is the model name with the suffix removed.
// If no suffix was found, this equals the original input.
ModelName string
// HasSuffix indicates whether a valid suffix was found.
HasSuffix bool
// RawSuffix is the content inside the parentheses, without the parentheses.
// Empty string if HasSuffix is false.
RawSuffix string
}
// ProviderApplier defines the interface for provider-specific thinking configuration application.
//
// Types implementing this interface are responsible for converting a unified ThinkingConfig
// into provider-specific format and applying it to the request body.
//
// Implementation requirements:
// - Apply method must be idempotent
// - Must not modify the input config or modelInfo
// - Returns a modified copy of the request body
// - Returns appropriate ThinkingError for unsupported configurations
type ProviderApplier interface {
// Apply applies the thinking configuration to the request body.
//
// Parameters:
// - body: Original request body JSON
// - config: Unified thinking configuration
// - modelInfo: Model registry information containing ThinkingSupport properties
//
// Returns:
// - Modified request body JSON
// - ThinkingError if the configuration is invalid or unsupported
Apply(body []byte, config ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error)
}

View File

@@ -0,0 +1,378 @@
// Package thinking provides unified thinking configuration processing logic.
package thinking
import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
log "github.com/sirupsen/logrus"
)
// ValidateConfig validates a thinking configuration against model capabilities.
//
// This function performs comprehensive validation:
// - Checks if the model supports thinking
// - Auto-converts between Budget and Level formats based on model capability
// - Validates that requested level is in the model's supported levels list
// - Clamps budget values to model's allowed range
// - When converting Budget -> Level for level-only models, clamps the derived standard level to the nearest supported level
// (special values none/auto are preserved)
// - When config comes from a model suffix, strict budget validation is disabled (we clamp instead of error)
//
// Parameters:
// - config: The thinking configuration to validate
// - support: Model's ThinkingSupport properties (nil means no thinking support)
// - fromFormat: Source provider format (used to determine strict validation rules)
// - toFormat: Target provider format
// - fromSuffix: Whether config was sourced from model suffix
//
// Returns:
// - Normalized ThinkingConfig with clamped values
// - ThinkingError if validation fails (ErrThinkingNotSupported, ErrLevelNotSupported, etc.)
//
// Auto-conversion behavior:
// - Budget-only model + Level config → Level converted to Budget
// - Level-only model + Budget config → Budget converted to Level
// - Hybrid model → preserve original format
func ValidateConfig(config ThinkingConfig, modelInfo *registry.ModelInfo, fromFormat, toFormat string, fromSuffix bool) (*ThinkingConfig, error) {
fromFormat, toFormat = strings.ToLower(strings.TrimSpace(fromFormat)), strings.ToLower(strings.TrimSpace(toFormat))
model := "unknown"
support := (*registry.ThinkingSupport)(nil)
if modelInfo != nil {
if modelInfo.ID != "" {
model = modelInfo.ID
}
support = modelInfo.Thinking
}
if support == nil {
if config.Mode != ModeNone {
return nil, NewThinkingErrorWithModel(ErrThinkingNotSupported, "thinking not supported for this model", model)
}
return &config, nil
}
allowClampUnsupported := isBudgetBasedProvider(fromFormat) && isLevelBasedProvider(toFormat)
strictBudget := !fromSuffix && fromFormat != "" && isSameProviderFamily(fromFormat, toFormat)
budgetDerivedFromLevel := false
capability := detectModelCapability(modelInfo)
switch capability {
case CapabilityBudgetOnly:
if config.Mode == ModeLevel {
if config.Level == LevelAuto {
break
}
budget, ok := ConvertLevelToBudget(string(config.Level))
if !ok {
return nil, NewThinkingError(ErrUnknownLevel, fmt.Sprintf("unknown level: %s", config.Level))
}
config.Mode = ModeBudget
config.Budget = budget
config.Level = ""
budgetDerivedFromLevel = true
}
case CapabilityLevelOnly:
if config.Mode == ModeBudget {
level, ok := ConvertBudgetToLevel(config.Budget)
if !ok {
return nil, NewThinkingError(ErrUnknownLevel, fmt.Sprintf("budget %d cannot be converted to a valid level", config.Budget))
}
// When converting Budget -> Level for level-only models, clamp the derived standard level
// to the nearest supported level. Special values (none/auto) are preserved.
config.Mode = ModeLevel
config.Level = clampLevel(ThinkingLevel(level), modelInfo, toFormat)
config.Budget = 0
}
case CapabilityHybrid:
}
if config.Mode == ModeLevel && config.Level == LevelNone {
config.Mode = ModeNone
config.Budget = 0
config.Level = ""
}
if config.Mode == ModeLevel && config.Level == LevelAuto {
config.Mode = ModeAuto
config.Budget = -1
config.Level = ""
}
if config.Mode == ModeBudget && config.Budget == 0 {
config.Mode = ModeNone
config.Level = ""
}
if len(support.Levels) > 0 && config.Mode == ModeLevel {
if !isLevelSupported(string(config.Level), support.Levels) {
if allowClampUnsupported {
config.Level = clampLevel(config.Level, modelInfo, toFormat)
}
if !isLevelSupported(string(config.Level), support.Levels) {
// User explicitly specified an unsupported level - return error
// (budget-derived levels may be clamped based on source format)
validLevels := normalizeLevels(support.Levels)
message := fmt.Sprintf("level %q not supported, valid levels: %s", strings.ToLower(string(config.Level)), strings.Join(validLevels, ", "))
return nil, NewThinkingError(ErrLevelNotSupported, message)
}
}
}
if strictBudget && config.Mode == ModeBudget && !budgetDerivedFromLevel {
min, max := support.Min, support.Max
if min != 0 || max != 0 {
if config.Budget < min || config.Budget > max || (config.Budget == 0 && !support.ZeroAllowed) {
message := fmt.Sprintf("budget %d out of range [%d,%d]", config.Budget, min, max)
return nil, NewThinkingError(ErrBudgetOutOfRange, message)
}
}
}
// Convert ModeAuto to mid-range if dynamic not allowed
if config.Mode == ModeAuto && !support.DynamicAllowed {
config = convertAutoToMidRange(config, support, toFormat, model)
}
if config.Mode == ModeNone && toFormat == "claude" {
// Claude supports explicit disable via thinking.type="disabled".
// Keep Budget=0 so applier can omit budget_tokens.
config.Budget = 0
config.Level = ""
} else {
switch config.Mode {
case ModeBudget, ModeAuto, ModeNone:
config.Budget = clampBudget(config.Budget, modelInfo, toFormat)
}
// ModeNone with clamped Budget > 0: set Level to lowest for Level-only/Hybrid models
// This ensures Apply layer doesn't need to access support.Levels
if config.Mode == ModeNone && config.Budget > 0 && len(support.Levels) > 0 {
config.Level = ThinkingLevel(support.Levels[0])
}
}
return &config, nil
}
// convertAutoToMidRange converts ModeAuto to a mid-range value when dynamic is not allowed.
//
// This function handles the case where a model does not support dynamic/auto thinking.
// The auto mode is silently converted to a fixed value based on model capability:
// - Level-only models: convert to ModeLevel with LevelMedium
// - Budget models: convert to ModeBudget with mid = (Min + Max) / 2
//
// Logging:
// - Debug level when conversion occurs
// - Fields: original_mode, clamped_to, reason
func convertAutoToMidRange(config ThinkingConfig, support *registry.ThinkingSupport, provider, model string) ThinkingConfig {
// For level-only models (has Levels but no Min/Max range), use ModeLevel with medium
if len(support.Levels) > 0 && support.Min == 0 && support.Max == 0 {
config.Mode = ModeLevel
config.Level = LevelMedium
config.Budget = 0
log.WithFields(log.Fields{
"provider": provider,
"model": model,
"original_mode": "auto",
"clamped_to": string(LevelMedium),
}).Debug("thinking: mode converted, dynamic not allowed, using medium level |")
return config
}
// For budget models, use mid-range budget
mid := (support.Min + support.Max) / 2
if mid <= 0 && support.ZeroAllowed {
config.Mode = ModeNone
config.Budget = 0
} else if mid <= 0 {
config.Mode = ModeBudget
config.Budget = support.Min
} else {
config.Mode = ModeBudget
config.Budget = mid
}
log.WithFields(log.Fields{
"provider": provider,
"model": model,
"original_mode": "auto",
"clamped_to": config.Budget,
}).Debug("thinking: mode converted, dynamic not allowed |")
return config
}
// standardLevelOrder defines the canonical ordering of thinking levels from lowest to highest.
var standardLevelOrder = []ThinkingLevel{LevelMinimal, LevelLow, LevelMedium, LevelHigh, LevelXHigh}
// clampLevel clamps the given level to the nearest supported level.
// On tie, prefers the lower level.
func clampLevel(level ThinkingLevel, modelInfo *registry.ModelInfo, provider string) ThinkingLevel {
model := "unknown"
var supported []string
if modelInfo != nil {
if modelInfo.ID != "" {
model = modelInfo.ID
}
if modelInfo.Thinking != nil {
supported = modelInfo.Thinking.Levels
}
}
if len(supported) == 0 || isLevelSupported(string(level), supported) {
return level
}
pos := levelIndex(string(level))
if pos == -1 {
return level
}
bestIdx, bestDist := -1, len(standardLevelOrder)+1
for _, s := range supported {
if idx := levelIndex(strings.TrimSpace(s)); idx != -1 {
if dist := abs(pos - idx); dist < bestDist || (dist == bestDist && idx < bestIdx) {
bestIdx, bestDist = idx, dist
}
}
}
if bestIdx >= 0 {
clamped := standardLevelOrder[bestIdx]
log.WithFields(log.Fields{
"provider": provider,
"model": model,
"original_value": string(level),
"clamped_to": string(clamped),
}).Debug("thinking: level clamped |")
return clamped
}
return level
}
// clampBudget clamps a budget value to the model's supported range.
func clampBudget(value int, modelInfo *registry.ModelInfo, provider string) int {
model := "unknown"
support := (*registry.ThinkingSupport)(nil)
if modelInfo != nil {
if modelInfo.ID != "" {
model = modelInfo.ID
}
support = modelInfo.Thinking
}
if support == nil {
return value
}
// Auto value (-1) passes through without clamping.
if value == -1 {
return value
}
min, max := support.Min, support.Max
if value == 0 && !support.ZeroAllowed {
log.WithFields(log.Fields{
"provider": provider,
"model": model,
"original_value": value,
"clamped_to": min,
"min": min,
"max": max,
}).Warn("thinking: budget zero not allowed |")
return min
}
// Some models are level-only and do not define numeric budget ranges.
if min == 0 && max == 0 {
return value
}
if value < min {
if value == 0 && support.ZeroAllowed {
return 0
}
logClamp(provider, model, value, min, min, max)
return min
}
if value > max {
logClamp(provider, model, value, max, min, max)
return max
}
return value
}
func isLevelSupported(level string, supported []string) bool {
for _, s := range supported {
if strings.EqualFold(level, strings.TrimSpace(s)) {
return true
}
}
return false
}
func levelIndex(level string) int {
for i, l := range standardLevelOrder {
if strings.EqualFold(level, string(l)) {
return i
}
}
return -1
}
func normalizeLevels(levels []string) []string {
out := make([]string, len(levels))
for i, l := range levels {
out[i] = strings.ToLower(strings.TrimSpace(l))
}
return out
}
func isBudgetBasedProvider(provider string) bool {
switch provider {
case "gemini", "gemini-cli", "antigravity", "claude":
return true
default:
return false
}
}
func isLevelBasedProvider(provider string) bool {
switch provider {
case "openai", "openai-response", "codex":
return true
default:
return false
}
}
func isGeminiFamily(provider string) bool {
switch provider {
case "gemini", "gemini-cli", "antigravity":
return true
default:
return false
}
}
func isSameProviderFamily(from, to string) bool {
if from == to {
return true
}
return isGeminiFamily(from) && isGeminiFamily(to)
}
func abs(x int) int {
if x < 0 {
return -x
}
return x
}
func logClamp(provider, model string, original, clampedTo, min, max int) {
log.WithFields(log.Fields{
"provider": provider,
"model": model,
"original_value": original,
"min": min,
"max": max,
"clamped_to": clampedTo,
}).Debug("thinking: budget clamped |")
}

View File

@@ -7,40 +7,16 @@ package claude
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// deriveSessionID generates a stable session ID from the request.
// Uses the hash of the first user message to identify the conversation.
func deriveSessionID(rawJSON []byte) string {
messages := gjson.GetBytes(rawJSON, "messages")
if !messages.IsArray() {
return ""
}
for _, msg := range messages.Array() {
if msg.Get("role").String() == "user" {
content := msg.Get("content").String()
if content == "" {
// Try to get text from content array
content = msg.Get("content.0.text").String()
}
if content != "" {
h := sha256.Sum256([]byte(content))
return hex.EncodeToString(h[:16])
}
}
}
return ""
}
// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the Gemini CLI API.
@@ -60,11 +36,9 @@ func deriveSessionID(rawJSON []byte) string {
// Returns:
// - []byte: The transformed request data in Gemini CLI API format
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
enableThoughtTranslate := true
rawJSON := bytes.Clone(inputRawJSON)
// Derive session ID for signature caching
sessionID := deriveSessionID(rawJSON)
// system instruction
systemInstructionJSON := ""
hasSystemInstruction := false
@@ -122,42 +96,50 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
// Use GetThinkingText to handle wrapped thinking objects
thinkingText := util.GetThinkingText(contentResult)
signatureResult := contentResult.Get("signature")
clientSignature := ""
if signatureResult.Exists() && signatureResult.String() != "" {
clientSignature = signatureResult.String()
}
thinkingText := thinking.GetThinkingText(contentResult)
// Always try cached signature first (more reliable than client-provided)
// Client may send stale or invalid signatures from different sessions
signature := ""
if sessionID != "" && thinkingText != "" {
if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" {
if thinkingText != "" {
if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" {
signature = cachedSig
// log.Debugf("Using cached signature for thinking block")
}
}
// Fallback to client signature only if cache miss and client signature is valid
if signature == "" && cache.HasValidSignature(clientSignature) {
signature = clientSignature
if signature == "" {
signatureResult := contentResult.Get("signature")
clientSignature := ""
if signatureResult.Exists() && signatureResult.String() != "" {
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
if len(arrayClientSignatures) == 2 {
if modelName == arrayClientSignatures[0] {
clientSignature = arrayClientSignatures[1]
}
}
}
if cache.HasValidSignature(modelName, clientSignature) {
signature = clientSignature
}
// log.Debugf("Using client-provided signature for thinking block")
}
// Store for subsequent tool_use in the same message
if cache.HasValidSignature(signature) {
if cache.HasValidSignature(modelName, signature) {
currentMessageThinkingSignature = signature
}
// Skip trailing unsigned thinking blocks on last assistant message
isUnsigned := !cache.HasValidSignature(signature)
isUnsigned := !cache.HasValidSignature(modelName, signature)
// If unsigned, skip entirely (don't convert to text)
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
// Converting to text would break this requirement
if isUnsigned {
// log.Debugf("Dropping unsigned thinking block (no valid signature)")
enableThoughtTranslate = false
continue
}
@@ -205,7 +187,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// This is the approach used in opencode-google-antigravity-auth for Gemini
// and also works for Claude through Antigravity API
const skipSentinel = "skip_thought_signature_validator"
if cache.HasValidSignature(currentMessageThinkingSignature) {
if cache.HasValidSignature(modelName, currentMessageThinkingSignature) {
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
} else {
// No valid signature - use skip sentinel to bypass validation
@@ -385,12 +367,12 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() && util.ModelSupportsThinking(modelName) {
if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() {
if t.Get("type").String() == "enabled" {
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := int(b.Int())
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
}
}
}

View File

@@ -4,6 +4,7 @@ import (
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
"github.com/tidwall/gjson"
)
@@ -75,28 +76,39 @@ func TestConvertClaudeRequestToAntigravity_RoleMapping(t *testing.T) {
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
// Valid signature must be at least 50 characters
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
thinkingText := "Let me think..."
// Pre-cache the signature (simulating a response from the same session)
// The session ID is derived from the first user message hash
// Since there's no user message in this test, we need to add one
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Test user message"}]
},
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
{"type": "text", "text": "Answer"}
]
}
]
}`)
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Check thinking block conversion
firstPart := gjson.Get(outputStr, "request.contents.0.parts.0")
// Check thinking block conversion (now in contents.1 due to user message)
firstPart := gjson.Get(outputStr, "request.contents.1.parts.0")
if !firstPart.Get("thought").Bool() {
t.Error("thinking block should have thought: true")
}
if firstPart.Get("text").String() != "Let me think..." {
if firstPart.Get("text").String() != thinkingText {
t.Error("thinking text mismatch")
}
if firstPart.Get("thoughtSignature").String() != validSignature {
@@ -227,13 +239,19 @@ func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
thinkingText := "Let me think..."
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Test user message"}]
},
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
{
"type": "tool_use",
"id": "call_123",
@@ -245,11 +263,13 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
]
}`)
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Check function call has the signature from the preceding thinking block
part := gjson.Get(outputStr, "request.contents.0.parts.1")
// Check function call has the signature from the preceding thinking block (now in contents.1)
part := gjson.Get(outputStr, "request.contents.1.parts.1")
if part.Get("functionCall.name").String() != "get_weather" {
t.Errorf("Expected functionCall, got %s", part.Raw)
}
@@ -261,24 +281,32 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
// Case: text block followed by thinking block -> should be reordered to thinking first
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
thinkingText := "Planning..."
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Test user message"}]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Here is the plan."},
{"type": "thinking", "thinking": "Planning...", "signature": "` + validSignature + `"}
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"}
]
}
]
}`)
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Verify order: Thinking block MUST be first
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
// Verify order: Thinking block MUST be first (now in contents.1 due to user message)
parts := gjson.Get(outputStr, "request.contents.1.parts").Array()
if len(parts) != 2 {
t.Fatalf("Expected 2 parts, got %d", len(parts))
}
@@ -343,8 +371,8 @@ func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {
if thinkingConfig.Get("thinkingBudget").Int() != 8000 {
t.Errorf("Expected thinkingBudget 8000, got %d", thinkingConfig.Get("thinkingBudget").Int())
}
if !thinkingConfig.Get("include_thoughts").Bool() {
t.Error("include_thoughts should be true")
if !thinkingConfig.Get("includeThoughts").Bool() {
t.Error("includeThoughts should be true")
}
} else {
t.Log("thinkingConfig not present - model may not be registered in test registry")
@@ -460,6 +488,9 @@ func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *t
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
// Last assistant message ends with signed thinking block - should be kept
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
thinkingText := "Valid thinking..."
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
@@ -471,12 +502,14 @@ func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testin
"role": "assistant",
"content": [
{"type": "text", "text": "Here is my answer"},
{"type": "thinking", "thinking": "Valid thinking...", "signature": "abc123validSignature1234567890123456789012345678901234567890"}
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"}
]
}
]
}`)
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)

View File

@@ -41,7 +41,6 @@ type Params struct {
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
// Signature caching support
SessionID string // Session ID derived from request for signature caching
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
}
@@ -70,9 +69,9 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
HasFirstResponse: false,
ResponseType: 0,
ResponseIndex: 0,
SessionID: deriveSessionID(originalRequestRawJSON),
}
}
modelName := gjson.GetBytes(requestRawJSON, "model").String()
params := (*param).(*Params)
@@ -138,14 +137,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
// log.Debug("Branch: signature_delta")
if params.SessionID != "" && params.CurrentThinkingText.Len() > 0 {
cache.CacheSignature(params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String())
if params.CurrentThinkingText.Len() > 0 {
cache.CacheSignature(modelName, params.CurrentThinkingText.String(), thoughtSignature.String())
// log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
params.CurrentThinkingText.Reset()
}
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String())
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.HasContent = true
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
@@ -372,7 +371,7 @@ func resolveStopReason(params *Params) string {
// - string: A Claude-compatible JSON response.
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
_ = originalRequestRawJSON
_ = requestRawJSON
modelName := gjson.GetBytes(requestRawJSON, "model").String()
root := gjson.ParseBytes(rawJSON)
promptTokens := root.Get("response.usageMetadata.promptTokenCount").Int()
@@ -437,7 +436,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
if thinkingSignature != "" {
block, _ = sjson.Set(block, "signature", thinkingSignature)
block, _ = sjson.Set(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
}
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
thinkingBuilder.Reset()

View File

@@ -97,6 +97,7 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
cache.ClearSignatureCache("")
requestJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Cache test"}]}]
}`)
@@ -143,7 +144,7 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, signatureChunk, &param)
// Verify signature was cached
cachedSig := cache.GetCachedSignature(sessionID, thinkingText)
cachedSig := cache.GetCachedSignature("claude-sonnet-4-5-thinking", thinkingText)
if cachedSig != validSignature {
t.Errorf("Expected cached signature '%s', got '%s'", validSignature, cachedSig)
}
@@ -158,6 +159,7 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
cache.ClearSignatureCache("")
requestJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Multi block test"}]}]
}`)
@@ -221,13 +223,12 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
// Process first thinking block
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Thinking, &param)
params := param.(*Params)
sessionID := params.SessionID
firstThinkingText := params.CurrentThinkingText.String()
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Sig, &param)
// Verify first signature cached
if cache.GetCachedSignature(sessionID, firstThinkingText) != validSig1 {
if cache.GetCachedSignature("claude-sonnet-4-5-thinking", firstThinkingText) != validSig1 {
t.Error("First thinking block signature should be cached")
}
@@ -241,76 +242,7 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Sig, &param)
// Verify second signature cached
if cache.GetCachedSignature(sessionID, secondThinkingText) != validSig2 {
if cache.GetCachedSignature("claude-sonnet-4-5-thinking", secondThinkingText) != validSig2 {
t.Error("Second thinking block signature should be cached")
}
}
func TestDeriveSessionIDFromRequest(t *testing.T) {
tests := []struct {
name string
input []byte
wantEmpty bool
}{
{
name: "valid user message",
input: []byte(`{"messages": [{"role": "user", "content": "Hello"}]}`),
wantEmpty: false,
},
{
name: "user message with content array",
input: []byte(`{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}`),
wantEmpty: false,
},
{
name: "no user message",
input: []byte(`{"messages": [{"role": "assistant", "content": "Hi"}]}`),
wantEmpty: true,
},
{
name: "empty messages",
input: []byte(`{"messages": []}`),
wantEmpty: true,
},
{
name: "no messages field",
input: []byte(`{}`),
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := deriveSessionID(tt.input)
if tt.wantEmpty && result != "" {
t.Errorf("Expected empty session ID, got '%s'", result)
}
if !tt.wantEmpty && result == "" {
t.Error("Expected non-empty session ID")
}
})
}
}
func TestDeriveSessionIDFromRequest_Deterministic(t *testing.T) {
input := []byte(`{"messages": [{"role": "user", "content": "Same message"}]}`)
id1 := deriveSessionID(input)
id2 := deriveSessionID(input)
if id1 != id2 {
t.Errorf("Session ID should be deterministic: '%s' != '%s'", id1, id2)
}
}
func TestDeriveSessionIDFromRequest_DifferentMessages(t *testing.T) {
input1 := []byte(`{"messages": [{"role": "user", "content": "Message A"}]}`)
input2 := []byte(`{"messages": [{"role": "user", "content": "Message B"}]}`)
id1 := deriveSessionID(input1)
id2 := deriveSessionID(input2)
if id1 == id2 {
t.Error("Different messages should produce different session IDs")
}
}

View File

@@ -8,6 +8,7 @@ package gemini
import (
"bytes"
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -32,12 +33,12 @@ import (
//
// Returns:
// - []byte: The transformed request data in Gemini API format
func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []byte {
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := bytes.Clone(inputRawJSON)
template := ""
template = `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
template, _ = sjson.Set(template, "model", modelName)
template, _ = sjson.Delete(template, "request.model")
template, errFixCLIToolResponse := fixCLIToolResponse(template)
@@ -97,37 +98,40 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
}
}
// Gemini-specific handling: add skip_thought_signature_validator to functionCall parts
// and remove thinking blocks entirely (Gemini doesn't need to preserve them)
const skipSentinel = "skip_thought_signature_validator"
// Gemini-specific handling for non-Claude models:
// - Add skip_thought_signature_validator to functionCall parts so upstream can bypass signature validation.
// - Also mark thinking parts with the same sentinel when present (we keep the parts; we only annotate them).
if !strings.Contains(modelName, "claude") {
const skipSentinel = "skip_thought_signature_validator"
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
if content.Get("role").String() == "model" {
// First pass: collect indices of thinking parts to remove
var thinkingIndicesToRemove []int64
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
// Mark thinking blocks for removal
if part.Get("thought").Bool() {
thinkingIndicesToRemove = append(thinkingIndicesToRemove, partIdx.Int())
}
// Add skip sentinel to functionCall parts
if part.Get("functionCall").Exists() {
existingSig := part.Get("thoughtSignature").String()
if existingSig == "" || len(existingSig) < 50 {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
if content.Get("role").String() == "model" {
// First pass: collect indices of thinking parts to mark with skip sentinel
var thinkingIndicesToSkipSignature []int64
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
// Collect indices of thinking blocks to mark with skip sentinel
if part.Get("thought").Bool() {
thinkingIndicesToSkipSignature = append(thinkingIndicesToSkipSignature, partIdx.Int())
}
}
return true
})
// Add skip sentinel to functionCall parts
if part.Get("functionCall").Exists() {
existingSig := part.Get("thoughtSignature").String()
if existingSig == "" || len(existingSig) < 50 {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
}
}
return true
})
// Remove thinking blocks in reverse order to preserve indices
for i := len(thinkingIndicesToRemove) - 1; i >= 0; i-- {
idx := thinkingIndicesToRemove[i]
rawJSON, _ = sjson.DeleteBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d", contentIdx.Int(), idx))
// Add skip_thought_signature_validator sentinel to thinking blocks in reverse order to preserve indices
for i := len(thinkingIndicesToSkipSignature) - 1; i >= 0; i-- {
idx := thinkingIndicesToSkipSignature[i]
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), idx), skipSentinel)
}
}
}
return true
})
return true
})
}
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
}

View File

@@ -35,66 +35,19 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Model
out, _ = sjson.SetBytes(out, "model", modelName)
// Reasoning effort -> thinkingBudget/include_thoughts
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
// Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig.
// Inline translation-only mapping; capability checks happen later in ApplyThinking.
re := gjson.GetBytes(rawJSON, "reasoning_effort")
hasOfficialThinking := re.Exists()
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
if re.Exists() {
effort := strings.ToLower(strings.TrimSpace(re.String()))
if util.IsGemini3Model(modelName) {
switch effort {
case "none":
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig")
case "auto":
includeThoughts := true
out = util.ApplyGeminiCLIThinkingLevel(out, "", &includeThoughts)
default:
if level, ok := util.ValidateGemini3ThinkingLevel(modelName, effort); ok {
out = util.ApplyGeminiCLIThinkingLevel(out, level, nil)
}
}
} else if !util.ModelUsesThinkingLevels(modelName) {
out = util.ApplyReasoningEffortToGeminiCLI(out, effort)
}
}
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
// Only apply for models that use numeric budgets, not discrete levels.
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
var setBudget bool
var budget int
if v := tc.Get("thinkingBudget"); v.Exists() {
budget = int(v.Int())
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
setBudget = true
} else if v := tc.Get("thinking_budget"); v.Exists() {
budget = int(v.Int())
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
setBudget = true
}
if v := tc.Get("includeThoughts"); v.Exists() {
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
} else if v := tc.Get("include_thoughts"); v.Exists() {
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
} else if setBudget && budget != 0 {
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
}
}
}
// Claude/Anthropic API format: thinking.type == "enabled" with budget_tokens
// This allows Claude Code and other Claude API clients to pass thinking configuration
if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && util.ModelSupportsThinking(modelName) {
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
if t.Get("type").String() == "enabled" {
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := int(b.Int())
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
}
if effort != "" {
thinkingPath := "request.generationConfig.thinkingConfig"
if effort == "auto" {
out, _ = sjson.SetBytes(out, thinkingPath+".thinkingBudget", -1)
out, _ = sjson.SetBytes(out, thinkingPath+".includeThoughts", true)
} else {
out, _ = sjson.SetBytes(out, thinkingPath+".thinkingLevel", effort)
out, _ = sjson.SetBytes(out, thinkingPath+".includeThoughts", effort != "none")
}
}
}
@@ -113,6 +66,13 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num)
}
// Candidate count (OpenAI 'n' parameter)
if n := gjson.GetBytes(rawJSON, "n"); n.Exists() && n.Type == gjson.Number {
if val := n.Int(); val > 1 {
out, _ = sjson.SetBytes(out, "request.generationConfig.candidateCount", val)
}
}
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
@@ -179,6 +139,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
}
systemPartIndex := 0
for i := 0; i < len(arr); i++ {
m := arr[i]
role := m.Get("role").String()
@@ -188,16 +149,19 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
// system -> request.systemInstruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.String())
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), content.String())
systemPartIndex++
} else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.Get("text").String())
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String())
systemPartIndex++
} else if content.IsArray() {
contents := content.Array()
if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", j), contents[j].Get("text").String())
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
systemPartIndex++
}
}
}
@@ -212,7 +176,10 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
for _, item := range items {
switch item.Get("type").String() {
case "text":
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
text := item.Get("text").String()
if text != "" {
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", text)
}
p++
case "image_url":
imageURL := item.Get("image_url.url").String()
@@ -256,6 +223,10 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
text := item.Get("text").String()
if text != "" {
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", text)
}
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.

View File

@@ -15,6 +15,7 @@ import (
"strings"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -114,15 +115,40 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
}
}
// Include thoughts configuration for reasoning process visibility
// Only apply for models that support thinking and use numeric budgets, not discrete levels.
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
// Check for thinkingBudget first - if present, enable thinking with budget
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() && thinkingBudget.Int() > 0 {
// Translator only does format conversion, ApplyThinking handles model capability validation.
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
if thinkingLevel := thinkingConfig.Get("thinkingLevel"); thinkingLevel.Exists() {
level := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
switch level {
case "":
case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
case "auto":
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
default:
if budget, ok := thinking.ConvertLevelToBudget(level); ok {
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
}
}
} else if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
budget := int(thinkingBudget.Int())
switch budget {
case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
case -1:
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
default:
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
}
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
out, _ = sjson.Set(out, "thinking.type", "enabled")
normalizedBudget := util.NormalizeThinkingBudget(modelName, int(thinkingBudget.Int()))
out, _ = sjson.Set(out, "thinking.budget_tokens", normalizedBudget)
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
// Fallback to include_thoughts if no budget specified
out, _ = sjson.Set(out, "thinking.type", "enabled")
}
}

Some files were not shown because too many files have changed in this diff Show More