Compare commits

...

286 Commits

Author SHA1 Message Date
Luis Pater
82df5bf88a Merge pull request #395 from Xm798/feat/kiro
feat(kiro): add IDC auth code flow, redesign fingerprint and API protocol
2026-02-27 20:50:43 +08:00
Luis Pater
acb1066de8 Merge branch 'router-for-me:main' into main 2026-02-27 20:49:03 +08:00
Luis Pater
27c68f5bb2 fix(auth): replace MarkResult with hook OnResult for result handling 2026-02-27 20:47:46 +08:00
Luis Pater
65a87815e7 Merge pull request #394 from router-for-me/plus
v6.8.31
2026-02-27 16:18:48 +08:00
Luis Pater
b80793ca82 Merge branch 'main' into plus 2026-02-27 16:18:12 +08:00
Luis Pater
601550f238 Merge pull request #393 from cielhaidir/main
feat(kiro): add new Kiro models definition
2026-02-27 16:17:05 +08:00
Luis Pater
41b1cf2273 Merge pull request #1734 from huangusaki/main
feat(registry): add gemini-3.1-flash-image support
2026-02-27 16:12:05 +08:00
huang_usaki
3b4f9f43db feat(registry): add gemini-3.1-flash-image support 2026-02-27 10:20:46 +08:00
“cielhaidir”
37a09ecb23 feat(kiro): add new Kiro models definition 2026-02-27 10:18:59 +08:00
Luis Pater
0da34d3c2d Merge pull request #1668 from lyd123qw2008/fix/codex-usage-limit-retry-after
fix(codex): honor usage_limit_reached resets_at for retry_after
2026-02-27 06:01:44 +08:00
Luis Pater
74bf7eda8f Merge pull request #1686 from lyd123qw2008/fix/auth-refresh-concurrency-limit
fix(auth): limit auto-refresh concurrency to prevent refresh storms
2026-02-27 05:59:20 +08:00
Cyrus
9032042cfa feat(kiro): add Sonnet 4.6 model alias
- Add kiro-claude-sonnet-4-6 alias mapping to claude-sonnet-4-6
2026-02-27 01:02:21 +08:00
Cyrus
030bf5e6c7 feat(kiro): add IDC auth and endpoint improvements, redesign fingerprint system
- Add IAM Identity Center (IDC) authentication with CLI flags (--kiro-idc-login, --kiro-idc-start-url, --kiro-idc-region) and login flow
- Add ProfileArn auto-fetching in Execute/ExecuteStream for imported IDC accounts
- Simplify endpoint preference with map-based alias lookup and getAuthValue helper
- Redesign fingerprint as global singleton with external config and per-account deterministic generation
- Add StartURL and FingerprintConfig fields to Kiro config
- Add AgentContinuationID/AgentTaskType support in Kiro translators
- Add comprehensive tests for executor, fingerprint, SSO OIDC, and AWS helpers
- Add CLI login documentation to README
2026-02-27 00:58:03 +08:00
Luis Pater
d3100085b0 Merge pull request #392 from router-for-me/plus
v6.8.30
2026-02-26 23:16:26 +08:00
Luis Pater
f481d25133 Merge branch 'main' into plus 2026-02-26 23:16:17 +08:00
Luis Pater
8c6c90da74 fix(registry): clean up outdated model definitions in static data 2026-02-26 23:12:40 +08:00
Luis Pater
24bcfd9c03 Merge pull request #1699 from 123hi123/fix/antigravity-primary-model-fallback
fix(antigravity): keep primary model list and backfill empty auths
2026-02-26 04:28:29 +08:00
Luis Pater
816fb4c5da Merge pull request #1682 from sususu98/fix/tool-result-image-parts
fix(antigravity): place tool_result images in functionResponse.parts and unify mimeType
2026-02-25 23:14:35 +08:00
Luis Pater
c1bb77c7c9 Merge pull request #291 from howarddong711/feat/copilot-email-name
feat(copilot): fetch and persist user email and display name on login
2026-02-25 22:23:25 +08:00
Luis Pater
6bcac3a55a Merge branch 'router-for-me:main' into main 2026-02-25 22:21:31 +08:00
Howard Dong
fc346f4537 fix(copilot): add username fallback and consistent file name prefix
- Add 'github-user' fallback in WaitForAuthorization when FetchUserInfo
  returns empty Login (fixes malformed 'github-copilot-.json' filenames)
- Standardize Web API file name to 'github-copilot-<user>.json' to match
  CLI path convention (was 'github-<user>.json')

Addresses Gemini Code Assist review comments on PR #291.
2026-02-25 17:17:51 +08:00
Howard Dong
43e531a3b6 feat(copilot): fetch and persist user email and display name on login
- Expand OAuth scope to include read:user for full profile access
- Add GitHubUserInfo struct with Login, Email, Name fields
- Update FetchUserInfo to return complete user profile
- Add Email and Name fields to CopilotTokenStorage and CopilotAuthBundle
- Fix provider string bug: 'github' -> 'github-copilot' in auth_files.go
- Fix semantic bug: email field was storing username
- Update Label to prefer email over username in both CLI and Web API paths
- Add 9 unit tests covering new functionality
2026-02-25 17:09:40 +08:00
Luis Pater
d24ea4ce2a Merge pull request #1664 from ciberponk/pr/responses-compaction-compat
feat: add codex responses compatibility for compaction payloads
2026-02-25 01:21:59 +08:00
Luis Pater
2c30c981ae Merge pull request #1687 from lyd123qw2008/fix/codex-refresh-token-reused-no-retry
fix(codex): stop retrying refresh_token_reused errors
2026-02-25 01:19:30 +08:00
Luis Pater
aa1da8a858 Merge pull request #1685 from lyd123qw2008/fix/auth-auto-refresh-interval
fix(auth): respect configured auto-refresh interval
2026-02-25 01:13:47 +08:00
Luis Pater
f1e9a787d7 Merge pull request #1676 from piexian/feat/qwen-quota-handling-clean
feat(qwen): add rate limiting and quota error handling
2026-02-25 01:07:55 +08:00
Luis Pater
4eeec297de Merge pull request #288 from router-for-me/plus
v6.8.27
2026-02-25 01:04:57 +08:00
Luis Pater
77cc4ce3a0 Merge branch 'main' into plus 2026-02-25 01:04:15 +08:00
Luis Pater
37dfea1d3f Merge pull request #287 from possible055/main
fix(kiro): support OR-group field matching in truncation detector
2026-02-25 01:02:49 +08:00
Luis Pater
e6626c672a Merge pull request #269 from ClubWeGo/fix/filterOrphanedToolResults
fix: filter out orphaned tool results from history and current context
2026-02-25 01:02:11 +08:00
Luis Pater
c66cb0afd2 Merge pull request #1683 from dusty-du/codex/device-login-flow
Add additive Codex device-code login flow
2026-02-25 00:50:48 +08:00
Luis Pater
fb48eee973 Merge pull request #1680 from canxin121/fix/responses-stream-error-chunks
fix(responses): emit schema-valid SSE chunks
2026-02-25 00:49:06 +08:00
Luis Pater
bb44e5ec44 Merge pull request #1701 from router-for-me/openai
Revert "Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping"
2026-02-25 00:46:13 +08:00
apparition
c785c1a3ca fix(kiro): support OR-group field matching in truncation detector
- Change RequiredFieldsByTool value type from []string to [][]string
- Outer slice = AND (all groups required); inner slice = OR (any one satisfies)
- Fix Bash entry to accept "cmd" or "command", resolving soft-truncation loop
- Update findMissingRequiredFields logic and inline docs accordingly
2026-02-24 22:48:05 +08:00
comalot
514ae341c8 fix(antigravity): deep copy cached model metadata 2026-02-24 20:14:01 +08:00
hkfires
0659ffab75 Revert "Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping" 2026-02-24 19:47:53 +08:00
comalot
8ce07f38dd fix(antigravity): keep primary model list and backfill empty auths 2026-02-24 16:16:44 +08:00
Luis Pater
7cb398d167 Merge pull request #1663 from rensumo/main
feat: implement credential-based round-robin for gemini-cli
2026-02-24 06:02:50 +08:00
Luis Pater
c3e12c5e58 Merge pull request #1654 from alexey-yanchenko/feature/pass-file-inputs
Pass file input from /chat/completions and /responses to codex and claude
2026-02-24 05:53:11 +08:00
Luis Pater
1825fc7503 Merge pull request #1643 from alexey-yanchenko/fix/gemini-prompt-tokens
Fix usage convertation from gemini response to openai format
2026-02-24 05:46:13 +08:00
Luis Pater
48732ba05e Merge pull request #1527 from HEUDavid/feat/auth-hook
feat(auth): add post-auth hook mechanism
2026-02-24 05:33:13 +08:00
canxin121
acf483c9e6 fix(responses): reject invalid SSE data JSON
Guard the openai-response streaming path against truncated/invalid SSE data payloads by validating data: JSON before forwarding; surface a 502 terminal error instead of letting clients crash with JSON parse errors.
2026-02-24 01:42:54 +08:00
lyd123qw2008
3b3e0d1141 test(codex): log non-retryable refresh error and cover single-attempt behavior 2026-02-23 22:41:33 +08:00
lyd123qw2008
7acd428507 fix(codex): stop retrying refresh_token_reused errors 2026-02-23 22:31:30 +08:00
lyd123qw2008
0aaf177640 fix(auth): limit auto-refresh concurrency to prevent refresh storms 2026-02-23 22:28:41 +08:00
lyd123qw2008
450d1227bd fix(auth): respect configured auto-refresh interval 2026-02-23 22:07:50 +08:00
test
492b9c46f0 Add additive Codex device-code login flow 2026-02-23 06:30:04 -05:00
Darley
6e634fe3f9 fix: filter out orphaned tool results from history and current context 2026-02-23 14:33:59 +08:00
sususu98
4e26182d14 fix(antigravity): place tool_result images in functionResponse.parts and unify mimeType
Move base64 image data from Claude tool_result into functionResponse.parts
as inlineData instead of outer sibling parts, preventing context bloat.
Unify all inlineData field naming to camelCase mimeType across Claude,
OpenAI, and Gemini translators. Add comprehensive edge case tests and
Gemini-side regression test for functionResponse.parts preservation.
2026-02-23 13:38:21 +08:00
canxin121
eb7571936c revert: translator changes (path guard)
CI blocks PRs that modify internal/translator. Revert translator edits and keep only the /v1/responses streaming error-chunk fix; file an issue for translator conformance work.
2026-02-23 13:30:43 +08:00
canxin121
5382764d8a fix(responses): include model and usage in translated streams
Ensure response.created and response.completed chunks produced by the OpenAI/Gemini/Claude translators always include required fields (response.model and response.usage) so clients validating Responses SSE do not fail schema validation.
2026-02-23 13:22:06 +08:00
canxin121
49c8ec69d0 fix(openai): emit valid responses stream error chunks
When /v1/responses streaming fails after headers are sent, we now emit a type=error chunk instead of an HTTP-style {error:{...}} payload, preventing AI SDK chunk validation errors.
2026-02-23 12:59:50 +08:00
piexian
3b421c8181 feat(qwen): add rate limiting and quota error handling
- Add 60 requests/minute rate limiting per credential using sliding window
- Detect insufficient_quota errors and set cooldown until next day (Beijing time)
- Map quota errors (HTTP 403/429) to 429 with retryAfter for conductor integration
- Cache Beijing timezone at package level to avoid repeated syscalls
- Add redactAuthID function to protect credentials in logs
- Extract wrapQwenError helper to consolidate error handling
2026-02-23 00:38:46 +08:00
Luis Pater
21d2329947 Merge pull request #261 from router-for-me/plus
v6.8.26
2026-02-23 00:15:36 +08:00
Luis Pater
0993413bab Merge branch 'main' into plus 2026-02-23 00:15:22 +08:00
Luis Pater
713388dd7b Fixed: #1675
fix(gemini): add model definitions for Gemini 3.1 Pro High and Image
2026-02-23 00:12:57 +08:00
Luis Pater
e6c7af0fa9 Merge pull request #1522 from soilSpoon/feature/canceled
feature(proxy): Adds special handling for client cancellations in proxy error handler
2026-02-22 22:02:59 +08:00
Luis Pater
837aa6e3aa Merge branch 'router-for-me:main' into main 2026-02-22 21:52:53 +08:00
Luis Pater
d210be06c2 fix(gemini): update min Thinking value and add Gemini 3.1 Pro Preview model definition 2026-02-22 21:51:32 +08:00
fan
afc8a0f9be refactor: simplify context_management compatibility handling 2026-02-21 22:20:48 +08:00
Luis Pater
af8e9ef458 Merge branch 'router-for-me:main' into main 2026-02-21 21:09:52 +08:00
Luis Pater
cec6f993ad Merge pull request #256 from kavore/fix/oauth-copilot-claude-aliases
fix: add default copilot claude model aliases for oauth routing
2026-02-21 21:09:43 +08:00
Luis Pater
950de29f48 Merge pull request #255 from ladeng07/main
feat(registry): add GPT-4o model variants for GitHub Copilot
2026-02-21 21:09:06 +08:00
Luis Pater
d6ec33e8e1 Merge pull request #1662 from matchch/contribute/cache-user-id
feat: add cache-user-id toggle for Claude cloaking
2026-02-21 20:51:30 +08:00
Luis Pater
081cfe806e fix(gemini): correct Created timestamps for Gemini 3.1 Pro Preview model definitions 2026-02-21 20:47:47 +08:00
hkfires
c1c62a6c04 feat(gemini): add Gemini 3.1 Pro Preview model definitions 2026-02-21 20:42:29 +08:00
lyd123qw2008
a99522224f refactor(codex): make retry-after parsing deterministic for tests 2026-02-21 14:13:38 +08:00
lyd123qw2008
f5d46b9ca2 fix(codex): honor usage_limit_reached resets_at for retry_after 2026-02-21 13:50:23 +08:00
ciberponk
d693d7993b feat: support responses compaction payload compatibility for codex translator 2026-02-21 12:56:10 +08:00
rensumo
5936f9895c feat: implement credential-based round-robin for gemini-cli virtual auths
Changes the RoundRobinSelector to use two-level round-robin when
gemini-cli virtual auths are detected (via gemini_virtual_parent attr):
- Level 1: cycle across credential groups (parent accounts)
- Level 2: cycle within each group's project auths

Credentials start from a random offset (rand.IntN) for fair distribution.
Non-virtual auths and single-credential scenarios fall back to flat RR.

Adds 3 test cases covering multi-credential grouping, single-parent
fallback, and mixed virtual/non-virtual fallback.
2026-02-21 12:49:48 +08:00
matchch
2fdf5d2793 feat: add cache-user-id toggle for Claude cloaking
Default to generating a fresh random user_id per request instead of
reusing cached IDs. Add cache-user-id config option to opt in to the
previous caching behavior.

- Add CacheUserID field to CloakConfig
- Extract user_id cache logic to dedicated file
- Generate fresh user_id by default, cache only when enabled
- Add tests for both paths
2026-02-21 12:31:20 +08:00
kavore
b3da00d2ed fix: add default copilot claude model aliases for oauth routing 2026-02-20 21:59:21 +03:00
LMark
740277a9f2 refactor(registry): deduplicate GitHub Copilot GPT-4o model definitions 2026-02-21 02:32:06 +08:00
LMark
f91807b6b9 Add GPT-4o model variants while keeping Gemini 3.1 Pro preview 2026-02-21 01:41:01 +08:00
Luis Pater
57d18bb226 Merge branch 'router-for-me:main' into main 2026-02-20 22:42:01 +08:00
Luis Pater
10b9c6cb8a Merge pull request #252 from DragonBaiMo/fix/kiro-thinking-stream-dedup
fix(kiro): stop duplicated thinking on OpenAI and preserve Claude multi-turn thinking
2026-02-20 22:41:48 +08:00
Luis Pater
b24786f8a7 Merge pull request #250 from TonyRL/feat/copilot-gemini-3.1
feat(registry): add Gemini 3.1 Pro to GitHub Copilot provider
2026-02-20 22:40:41 +08:00
Luis Pater
7b0eb41ebc Merge pull request #1660 from Grivn/fix/claude-token-url
fix(claude): use api.anthropic.com for OAuth token exchange
2026-02-20 21:52:08 +08:00
DragonBaiMo
70949929db fix(kiro): deduplicate thinking stream emission 2026-02-20 20:34:40 +08:00
DragonBaiMo
7c9c89dace fix(kiro): keep thinking enabled across request formats 2026-02-20 20:34:40 +08:00
Grivn
ef5901c81b fix(claude): use api.anthropic.com for OAuth token exchange
console.anthropic.com is now protected by a Cloudflare managed challenge
that blocks all non-browser POST requests to /v1/oauth/token, causing
`-claude-login` to fail with a 403 error.

Switch to api.anthropic.com which hosts the same OAuth token endpoint
without the Cloudflare managed challenge.

Fixes #1659

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 20:11:27 +08:00
Luis Pater
d4829c82f7 Merge pull request #1652 from thebtf/fix/claude-translator-arguments
fix(translator): handle tool call arguments in codex→claude streaming translator
2026-02-20 19:50:20 +08:00
Luis Pater
a5f4166a9b Merge pull request #1644 from possible055/main
feat: add Gemini 3.1 Pro Preview model definition
2026-02-20 19:44:59 +08:00
Alexey Yanchenko
0cbfe7f457 Pass file input from /chat/completions and /responses to codex and claude 2026-02-20 10:25:44 +07:00
Tony
f2b1ec4f9e feat(registry): add Gemini 3.1 Pro to GitHub Copilot provider 2026-02-20 04:23:42 +08:00
Kirill Turanskiy
1cc21cc45b fix: prevent duplicate function call arguments when delta events precede done
Non-spark codex models (gpt-5.3-codex, gpt-5.2-codex) stream function call
arguments via multiple delta events followed by a done event. The done handler
unconditionally emitted the full arguments, duplicating what deltas already
streamed. This produced invalid double JSON that Claude Code couldn't parse,
causing tool calls to fail with missing parameters and infinite retry loops.

Add HasReceivedArgumentsDelta flag to track whether delta events were received.
The done handler now only emits arguments when no deltas preceded it (spark
models), while delta-based streaming continues to work for non-spark models.
2026-02-19 23:18:14 +03:00
Kirill Turanskiy
07cf616e2b fix: handle response.function_call_arguments.done in codex→claude streaming translator
Some Codex models (e.g. gpt-5.3-codex-spark) send function call arguments
in a single "done" event without preceding "delta" events. The streaming
translator only handled "delta" events, causing tool call arguments to be
lost — resulting in empty tool inputs and infinite retry loops in clients
like Claude Code.

Emit the full arguments from the "done" event as a single input_json_delta
so downstream clients receive the complete tool input.
2026-02-19 23:18:14 +03:00
Luis Pater
2b8c466e88 refactor(executor, handlers): replace channel-based streams with StreamResult for consistency
- Updated `ExecuteStream` functions in executors to use `StreamResult` instead of channels.
- Enhanced upstream header handling in OpenAI handlers.
- Improved maintainability and alignment across executors and handlers.
2026-02-19 22:07:14 +08:00
Luis Pater
ca2174ea48 Merge pull request #249 from router-for-me/plus
v6.8.22
2026-02-19 21:58:42 +08:00
Luis Pater
c09fb2a79d Merge branch 'main' into plus 2026-02-19 21:58:04 +08:00
Luis Pater
4445a165e9 test(handlers): add tests for passthrough headers behavior in WriteErrorResponse 2026-02-19 21:49:44 +08:00
Luis Pater
e92e2af71a Merge branch 'codex/pr-1626' into dev 2026-02-19 21:33:23 +08:00
Luis Pater
a6bdd9a652 feat: add passthrough headers configuration
- Introduced `passthrough-headers` option in configuration to control forwarding of upstream response headers.
- Updated handlers to respect the passthrough headers setting.
- Added tests to verify behavior when passthrough is enabled or disabled.
2026-02-19 21:31:29 +08:00
Luis Pater
349a6349b3 Merge pull request #1645 from tinyc0der/fix/antigravity-tool-result-json
fix(antigravity): prevent invalid JSON when tool_result has no content
2026-02-19 21:01:25 +08:00
TinyCoder
00822770ec fix(antigravity): prevent invalid JSON when tool_result has no content
sjson.SetRaw with an empty string produces malformed JSON (e.g. "result":}).
This happens when a Claude tool_result block has no content field, causing
functionResponseResult.Raw to be "". Guard against this by falling back to
sjson.Set with an empty string only when .Raw is empty.
2026-02-19 17:08:39 +07:00
apparition
1a0ceda0fc feat: add Gemini 3.1 Pro Preview model definition 2026-02-19 17:43:08 +08:00
Alexey Yanchenko
b9ae4ab803 Fix usage convertation from gemini response to openai format 2026-02-19 15:34:59 +07:00
Luis Pater
72add453d2 docs: add OmniRoute to README 2026-02-19 13:23:25 +08:00
Luis Pater
2789396435 fix: ensure connection-scoped headers are filtered in upstream requests
- Added `connectionScopedHeaders` utility to respect "Connection" header directives.
- Updated `FilterUpstreamHeaders` to remove connection-scoped headers dynamically.
- Refactored and tested upstream header filtering with additional validations.
- Adjusted upstream header handling during retries to replace headers safely.
2026-02-19 13:19:10 +08:00
Luis Pater
61da7bd981 Merge PR #1626 into codex/pr-1626 2026-02-19 04:49:14 +08:00
Luis Pater
ae4c502792 Merge pull request #248 from router-for-me/plus
v6.8.21
2026-02-19 04:42:44 +08:00
Luis Pater
ec6068060b Merge branch 'main' into plus 2026-02-19 04:42:35 +08:00
Luis Pater
ecb01d3dcd Merge pull request #244 from PancakeZik/feat/sonnet-4-6
feat: add Claude Sonnet 4.6 model support for Kiro provider
2026-02-19 04:31:20 +08:00
Luis Pater
22c0c00bd4 Merge branch 'main' into feat/sonnet-4-6 2026-02-19 04:31:07 +08:00
Luis Pater
9eb3e7a6c4 Merge pull request #243 from gl11tchy/feat/claude-sonnet-4-6
feat(registry): add Claude Sonnet 4.6 model definitions
2026-02-19 04:29:39 +08:00
Luis Pater
357c191510 Merge pull request #242 from ultraplan-bit/main
Improve Copilot provider based on ericc-ch/copilot-api comparison
2026-02-19 04:27:02 +08:00
Luis Pater
5db244af76 Merge pull request #240 from TonyRL/feat/copilot-sonnet-4.6
feat(registry): add Sonnet 4.6 to GitHub Copilot provider
2026-02-19 04:26:28 +08:00
Luis Pater
dc375d1b74 Merge pull request #239 from TonyRL/feat/copilot-codex-5.3
feat(registry): add GPT-5.3 Codex to GitHub Copilot provider
2026-02-19 04:25:25 +08:00
Luis Pater
9c040445af Merge pull request #1635 from thebtf/fix/openai-translator-tool-streaming
fix: handle tool call argument streaming in Codex→OpenAI translator
2026-02-19 04:22:12 +08:00
Luis Pater
fff866424e Merge pull request #1628 from thebtf/fix/masquerading-headers
fix: update Claude masquerading headers and configurable defaults
2026-02-19 04:19:59 +08:00
Luis Pater
2d12becfd6 Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping
fix: clamp reasoning_effort to valid OpenAI-format values
2026-02-19 04:15:19 +08:00
Luis Pater
252f7e0751 Merge pull request #1625 from thebtf/feat/tool-prefix-config
feat: add per-auth tool_prefix_disabled option
2026-02-19 04:07:22 +08:00
Luis Pater
b2b17528cb Merge branch 'pr-1624' into dev
# Conflicts:
#	internal/runtime/executor/claude_executor.go
#	internal/runtime/executor/claude_executor_test.go
2026-02-19 04:05:04 +08:00
Luis Pater
55f938164b Merge pull request #1618 from alexey-yanchenko/fix/completions-usage
Fix empty usage in /v1/completions
2026-02-19 03:57:11 +08:00
Luis Pater
76294f0c59 Merge pull request #1608 from thebtf/fix/tool-reference-proxy-prefix-mainline
fix: add proxy_ prefix handling for tool_reference content blocks
2026-02-19 03:50:34 +08:00
Luis Pater
2bcee78c6e feat(tui): add standalone mode and API-based log polling
- Implemented `--standalone` mode to launch an embedded server for TUI.
- Enhanced TUI client to support API-based log polling when log hooks are unavailable.
- Added authentication gate for password input and connection handling.
- Improved localization and UX for logs, authentication, and status bar rendering.
2026-02-19 03:19:18 +08:00
Luis Pater
7fe8246a9f Merge branch 'tui' into dev 2026-02-19 03:18:24 +08:00
Luis Pater
93fe58e31e feat(tui): add standalone mode and API-based log polling
- Implemented `--standalone` mode to launch an embedded server for TUI.
- Enhanced TUI client to support API-based log polling when log hooks are unavailable.
- Added authentication gate for password input and connection handling.
- Improved localization and UX for logs, authentication, and status bar rendering.
2026-02-19 03:18:08 +08:00
Luis Pater
e5b5dc870f chore(executor): remove unused Openai-Beta header from Codex executor 2026-02-19 02:19:48 +08:00
Luis Pater
a54877c023 Merge branch 'dev' 2026-02-19 02:03:41 +08:00
Luis Pater
bb86a0c0c4 feat(logging, executor): add request logging tests and WebSocket-based Codex executor
- Introduced unit tests for request logging middleware to enhance coverage.
- Added WebSocket-based Codex executor to support Responses API upgrade.
- Updated middleware logic to selectively capture request bodies for memory efficiency.
- Enhanced Codex configuration handling with new WebSocket attributes.
2026-02-19 01:57:02 +08:00
Kirill Turanskiy
5fa23c7f41 fix: handle tool call argument streaming in Codex→OpenAI translator
The OpenAI Chat Completions translator was silently dropping
response.function_call_arguments.delta and
response.function_call_arguments.done Codex SSE events, meaning
tool call arguments were never streamed incrementally to clients.

Add proper handling mirroring the proven Claude translator pattern:

- response.output_item.added: announce tool call (id, name, empty args)
- response.function_call_arguments.delta: stream argument chunks
- response.function_call_arguments.done: emit full args if no deltas
- response.output_item.done: defensive fallback for backward compat

State tracking via HasReceivedArgumentsDelta and HasToolCallAnnounced
ensures no duplicate argument emission and correct behavior for models
like codex-spark that skip delta events entirely.
2026-02-18 19:09:05 +03:00
gl11tchy
f9a09b7f23 style: sort model entries per review feedback 2026-02-18 15:06:28 +00:00
Joao
b0cde626fe feat: add Claude Sonnet 4.6 model support for Kiro provider 2026-02-18 13:51:23 +00:00
gl11tchy
e42ef9a95d feat(registry): add Claude Sonnet 4.6 model definitions
Add claude-sonnet-4-6 to:
- Claude OAuth provider (model_definitions_static_data.go)
- Antigravity model config (thinking + non-thinking entries)
- GitHub Copilot provider (model_definitions.go)

Ref: https://docs.anthropic.com/en/docs/about-claude/models
2026-02-18 13:43:22 +00:00
ultraplan-bit
abf1629ec7 Merge branch 'main' of https://github.com/ultraplan-bit/CLIProxyAPIPlus 2026-02-18 08:56:06 +08:00
Kirill Turanskiy
73dc0b10b8 fix: update Claude masquerading headers and make them configurable
Update hardcoded X-Stainless-* and User-Agent defaults to match
Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (verified via
diagnostic proxy capture 2026-02-17).

Changes:
- X-Stainless-Os/Arch: dynamic via runtime.GOOS/GOARCH
- X-Stainless-Package-Version: 0.55.1 → 0.74.0
- X-Stainless-Timeout: 60 → 600
- User-Agent: claude-cli/1.0.83 (external, cli) → claude-cli/2.1.44 (external, sdk-cli)

Add claude-header-defaults config section so values can be updated
without recompilation when Claude Code releases new versions.
2026-02-18 03:38:51 +03:00
Kirill Turanskiy
2ea95266e3 fix: clamp reasoning_effort to valid OpenAI-format values
CPA-internal thinking levels like 'xhigh' and 'minimal' are not accepted
by OpenAI-format providers (MiniMax, etc.). The OpenAI applier now maps
non-standard levels to the nearest valid reasoning_effort value before
writing to the request body:

  xhigh   → high
  minimal → low
  auto    → medium
2026-02-18 03:36:42 +03:00
Tony
922d4141c0 feat(registry): add Sonnet 4.6 to GitHub Copilot provider 2026-02-18 05:17:23 +08:00
Kirill Turanskiy
1f8f198c45 feat: passthrough upstream response headers to clients
CPA previously stripped ALL response headers from upstream AI provider
APIs, preventing clients from seeing rate-limit info, request IDs,
server-timing and other useful headers.

Changes:
- Add Headers field to Response and StreamResult structs
- Add FilterUpstreamHeaders helper (hop-by-hop + security denylist)
- Add WriteUpstreamHeaders helper (respects CPA-set headers)
- ExecuteWithAuthManager/ExecuteCountWithAuthManager now return headers
- ExecuteStreamWithAuthManager returns headers from initial connection
- All 11 provider executors populate Response.Headers
- All handler call sites write filtered upstream headers before response

Filtered headers (not forwarded):
- RFC 7230 hop-by-hop: Connection, Transfer-Encoding, Keep-Alive, etc.
- Security: Set-Cookie
- CPA-managed: Content-Length, Content-Encoding
2026-02-18 00:16:22 +03:00
Tony
c55275342c feat(registry): add GPT-5.3 Codex to GitHub Copilot provider 2026-02-18 03:04:27 +08:00
Kirill Turanskiy
9261b0c20b feat: add per-auth tool_prefix_disabled option
Allow disabling the proxy_ tool name prefix on a per-account basis.
Users who route their own Anthropic account through CPA can set
"tool_prefix_disabled": true in their OAuth auth JSON to send tool
names unchanged to Anthropic.

Default behavior is fully preserved — prefix is applied unless
explicitly disabled.

Changes:
- Add ToolPrefixDisabled() accessor to Auth (reads metadata key
  "tool_prefix_disabled" or "tool-prefix-disabled")
- Gate all 6 prefix apply/strip points with the new flag
- Add unit tests for the accessor
2026-02-17 21:48:19 +03:00
Kirill Turanskiy
7cc725496e fix: skip proxy_ prefix for built-in tools in message history
The proxy_ prefix logic correctly skips built-in tools (those with a
non-empty "type" field) in tools[] definitions but does not skip them
in messages[].content[] tool_use blocks or tool_choice. This causes
web_search in conversation history to become proxy_web_search, which
Anthropic does not recognize.

Fix: collect built-in tool names from tools[] into a set and also
maintain a hardcoded fallback set (web_search, code_execution,
text_editor, computer) for cases where the built-in tool appears in
history but not in the current request's tools[] array. Skip prefixing
in messages and tool_choice when name matches a built-in.
2026-02-17 21:42:32 +03:00
ultraplan-bit
5726a99c80 Improve Copilot provider based on ericc-ch/copilot-api comparison
- Fix X-Initiator detection: check for any assistant/tool role
  in messages instead of only the last message role, matching
  the correct agent detection for multi-turn tool conversations
- Add x-github-api-version: 2025-04-01 header for API compatibility
- Support Business/Enterprise accounts by using Endpoints.API from
  the Copilot token response instead of hardcoded base URL
- Fix Responses API vision detection: detect vision content before
  input normalization removes the messages array
- Add 8 test cases covering the above fixes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 22:11:17 +08:00
ultraplan-bit
b5756bf729 Fix Copilot 0x model incorrectly consuming premium requests
Change Openai-Intent header from "conversation-edits" to
"conversation-panel" to avoid triggering GitHub's premium
execution path, which caused included models (0x multiplier)
to be billed as premium requests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 21:17:18 +08:00
Alexey Yanchenko
709d999f9f Add usage to /v1/completions 2026-02-17 17:21:03 +07:00
Kirill Turanskiy
24c18614f0 fix: skip built-in tools in tool_reference prefix + refactor to switch
- Collect built-in tool names (those with a "type" field like
  web_search, code_execution) and skip prefixing tool_reference
  blocks that reference them, preventing name mismatch.
- Refactor if-else if chains to switch statements in all three
  prefix functions for idiomatic Go style.
2026-02-16 19:37:11 +03:00
Kirill Turanskiy
603f06a762 fix: handle tool_reference nested inside tool_result.content[]
tool_reference blocks can appear nested inside tool_result.content[]
arrays, not just at the top level of messages[].content[]. The prefix
logic now iterates into tool_result blocks with array content to find
and prefix/strip nested tool_reference.tool_name fields.
2026-02-16 19:06:24 +03:00
Kirill Turanskiy
98f0a3e3bd fix: add proxy_ prefix handling for tool_reference content blocks (#1)
applyClaudeToolPrefix, stripClaudeToolPrefixFromResponse, and
stripClaudeToolPrefixFromStreamLine now handle "tool_reference" blocks
(field "tool_name") in addition to "tool_use" blocks (field "name").

Without this fix, tool_reference blocks in conversation history retain
their original unprefixed names while tool definitions carry the proxy_
prefix, causing Anthropic API 400 errors: "Tool reference 'X' not found
in available tools."

Co-authored-by: Kirill Turanskiy <kt@novamedia.ru>
2026-02-16 19:06:24 +03:00
Luis Pater
e186ccb0d4 Merge pull request #234 from detroittommy879/feature/add-kilocode-provider
Add Kilo Code provider with dynamic model fetching
2026-02-16 23:54:29 +08:00
Luis Pater
8fc0b08b70 Merge pull request #233 from ultraplan-bit/fix/copilot-codex-responses-translation
Fix Copilot codex model Responses API translation for Claude Code
2026-02-16 23:51:42 +08:00
Luis Pater
52a257dc24 Merge pull request #237 from router-for-me/plus
v6.8.18
2026-02-16 23:50:00 +08:00
Luis Pater
a12d907f55 Merge branch 'main' into plus 2026-02-16 23:49:50 +08:00
Luis Pater
453aaf8774 chore(runtime): update Qwen executor user agent and headers for compatibility with new runtime standards 2026-02-16 23:29:47 +08:00
Supra4E8C
1b1ab1fb9b Merge pull request #1606 from router-for-me/add-qwen-3.5
feat(registry): add Qwen 3.5 Plus model definitions
2026-02-16 23:10:53 +08:00
Supra4E8C
a9d0bb72da feat(registry): add Qwen 3.5 Plus model definitions 2026-02-16 22:55:37 +08:00
DetroitTommy
d328e54e4b refactor(kilo): address code review suggestions for robustness 2026-02-15 17:26:29 -05:00
DetroitTommy
5a7932cba4 Added Kilo Code as a provider, with auth. It fetches the free models, tested them (works), for paid models someone will have to experiment so only the free ones are known to work 2026-02-15 14:54:20 -05:00
DetroitTommy
1dbeb0827a added kilocode auth, needs adjusting 2026-02-15 13:44:26 -05:00
lhpqaq
2c8821891c fix(tui): update with review 2026-02-16 00:24:25 +08:00
haopeng
0a2555b0f3 Update internal/tui/auth_tab.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-16 00:11:31 +08:00
lhpqaq
020df41efe chore(tui): update readme, fix usage 2026-02-16 00:04:04 +08:00
ultraplan-bit
f8f8cf17ce Fix Copilot codex model Responses API translation for Claude Code
- Add response.function_call_arguments.delta handler for tool call parameters
- Rewrite normalizeGitHubCopilotResponsesInput to produce structured input
  array (message/function_call/function_call_output) instead of flattened
  text, fixing infinite loop in multi-turn tool-use conversations
- Skip flattenAssistantContent for messages containing tool_use blocks,
  preventing function_call items from being destroyed
- Add reasoning/thinking stream & non-stream support
- Fix stop_reason mapping (max_tokens/stop) and cached token reporting
- Update test to match new array-based input format

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 18:04:45 +08:00
lhpqaq
f31f7f701a feat(tui): add i18n 2026-02-15 15:42:59 +08:00
Supra4E8C
b5fe78eb70 Merge pull request #1597 from router-for-me/kimi-fix
feat(registry): add support for 'kimi' channel in model definitions
2026-02-15 15:35:17 +08:00
Supra4E8C
d1f667cf8d feat(registry): add support for 'kimi' channel in model definitions 2026-02-15 15:21:33 +08:00
lhpqaq
54ad7c1b6b feat(tui): add manager tui 2026-02-15 14:52:40 +08:00
Luis Pater
d560c20c26 Merge branch 'router-for-me:main' into main 2026-02-15 14:49:13 +08:00
Luis Pater
5abeca1f9e Merge pull request #231 from ChrAlpha/main
feat(models): add Thinking support to GitHub Copilot models
2026-02-15 14:48:04 +08:00
Luis Pater
294eac3a88 Merge branch 'main' into main 2026-02-15 14:47:52 +08:00
Luis Pater
a31104020c Merge pull request #230 from ultraplan-bit/main
fix(copilot): forward Claude-format tools to Copilot Responses API
2026-02-15 14:45:27 +08:00
Luis Pater
65bec4d734 Merge pull request #229 from Buywatermelon/fix/issue-222-kiro-alias-deletion
fix: preserve explicitly deleted kiro aliases across config reload
2026-02-15 14:42:42 +08:00
Luis Pater
edb2993838 Merge pull request #228 from xilu0/fix/antigravity-fetch-models-logging
fix(antigravity): add warn-level logging to silent failure paths in FetchAntigravityModels
2026-02-15 14:42:13 +08:00
Luis Pater
c0d8e0dec7 Merge pull request #226 from Skyuno/refactor/websearch-alignment
refactor(kiro): Kiro Web Search Logic & Executor Alignment
2026-02-15 14:41:46 +08:00
ChrAlpha
795da13d5d feat(tests): add comprehensive GitHub Copilot tests for reasoning effort levels 2026-02-15 06:40:52 +00:00
Luis Pater
55789df275 chore(docker): update Go base image to 1.26-alpine 2026-02-15 14:26:44 +08:00
ChrAlpha
9e652a3540 fix(github-copilot): remove 'xhigh' level from Thinking support 2026-02-15 06:12:08 +00:00
Luis Pater
46a6782065 refactor(all): replace manual pointer assignments with new to enhance code readability and maintainability 2026-02-15 14:10:10 +08:00
Luis Pater
c359f61859 fix(auth): normalize Gemini credential file prefix for consistency 2026-02-15 13:59:33 +08:00
Luis Pater
908c8eab5b Merge pull request #1543 from sususu98/feat/gemini-cli-google-one
feat(gemini-cli): add Google One login and improve auto-discovery
2026-02-15 13:58:21 +08:00
Luis Pater
f5f2c69233 Merge pull request #1595 from alexey-yanchenko/feature/cache-usage-from-codex-to-chat-completions
Pass cache usage from codex to openai chat completions
2026-02-15 13:56:46 +08:00
Alexey Yanchenko
63d4de5eea Pass cache usage from codex to openai chat completions 2026-02-15 12:04:15 +07:00
ChrAlpha
af15083496 feat(models): add Thinking support to GitHub Copilot models
Enhance the model definitions by introducing Thinking support with various levels for each model.
2026-02-15 03:16:08 +00:00
ultraplan-bit
c4722e42b1 fix(copilot): forward Claude-format tools to Copilot Responses API
The normalizeGitHubCopilotResponsesTools filter required type="function",
which dropped Claude-format tools (no type field, uses input_schema).
Relax the filter to accept tools without a type field and map input_schema
to parameters so tools are correctly sent to the upstream API.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 21:58:15 +08:00
Dave
f9a991365f Update internal/runtime/executor/antigravity_executor.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-14 10:56:36 +08:00
y
6df16bedba fix: preserve explicitly deleted kiro aliases across config reload (#222)
The delete handler now sets the channel value to nil instead of removing
the map key, and the sanitization loop preserves nil/empty channel entries
as 'disabled' markers.  This prevents SanitizeOAuthModelAlias from
re-injecting default kiro aliases after a user explicitly deletes them
through the management API.
2026-02-14 09:40:05 +08:00
Skyuno
632a2fd2f2 refactor: align GenerateSearchIndicatorEvents return type with other event builders
Change GenerateSearchIndicatorEvents to return [][]byte instead of []sseEvent
for consistency with BuildFallbackTextEvents and other event building functions.

Benefits:
- Consistent API across all event generation functions
- Eliminates intermediate sseEvent type conversion in caller
- Simplifies usage by returning ready-to-send SSE byte slices

This addresses the code quality feedback from PR #226 review.
2026-02-13 22:04:09 +08:00
Skyuno
5626637fbd security: remove query content from web search logs to prevent PII leakage
- Remove search query from iteration logs (Info level)
- Remove query and toolUseId from analysis logs (Info level)
- Remove query from non-stream result logs (Info level)
- Remove query from tool injection logs (Info level)
- Remove query from tool_use detection logs (Debug level)

This addresses the security concern raised in PR #226 review about
potential PII exposure in search query logs.
2026-02-13 22:04:09 +08:00
Skyuno
2db89211a9 kiro: use payloadRequestedModel for response model name
Align Kiro executor with all other executors (Claude, Gemini, OpenAI,
etc.) by using payloadRequestedModel(opts, req.Model) instead of
req.Model when constructing response model names.

This ensures model aliases are correctly reflected in responses:
- Execute: BuildClaudeResponse + TranslateNonStream
- ExecuteStream: streamToChannel
- handleWebSearchStream: BuildClaudeMessageStartEvent
- handleWebSearch: via executeNonStreamFallback (automatic)

Previously Kiro was the only executor using req.Model directly,
which exposed internal routed names instead of the user's alias.
2026-02-13 22:04:09 +08:00
Skyuno
587371eb14 refactor: align web search with executor layer patterns
Consolidate web search handler, SSE event generation, stream analysis,
and MCP HTTP I/O into the executor layer. Merge the separate
kiro_websearch_handler.go back into kiro_executor.go to align with
the single-file-per-executor convention. Translator retains only pure
data types, detection, and payload transformation.

Key changes:
- Move SSE construction (search indicators, fallback text, message_start)
  from translator to executor, consistent with streamToChannel pattern
- Move MCP handler (callMcpAPI, setMcpHeaders, fetchToolDescription)
  from translator to executor alongside other HTTP I/O
- Reuse applyDynamicFingerprint for MCP UA headers (eliminate duplication)
- Centralize MCP endpoint URL via BuildMcpEndpoint in translator
- Add atomic Set/GetWebSearchDescription for cross-layer tool desc cache
- Thread context.Context through MCP HTTP calls for cancellation support
- Thread usage reporter through all web search API call paths
- Add token expiry pre-check before MCP/GAR calls
- Clean up dead code (GenerateMessageID, webSearchAuthContext fp logic,
  ContainsWebSearchTool, StripWebSearchTool)
2026-02-13 22:04:09 +08:00
xiluo
75818b1e25 fix(antigravity): add warn-level logging to silent failure paths in FetchAntigravityModels
Add log.Warnf calls to all 7 silent return nil paths so operators can
diagnose why specific antigravity accounts fail to fetch models and get
unregistered without any log trail.

Covers: token errors, request creation failures, context cancellation,
network errors (after exhausting fallback URLs), body read errors,
unexpected HTTP status codes, and missing models field in response.
2026-02-13 18:01:46 +08:00
이대희
a45c6defa7 Merge remote-tracking branch 'upstream/main' into feature/canceled 2026-02-13 15:07:32 +09:00
Luis Pater
cbe56955a9 Merge pull request #227 from router-for-me/plus
v6.8.15
2026-02-13 12:50:52 +08:00
Luis Pater
8ea6ac913d Merge branch 'main' into plus 2026-02-13 12:50:39 +08:00
Luis Pater
ae1e8a5191 chore(runtime, registry): update Codex client version and GPT-5.3 model creation date 2026-02-13 12:47:48 +08:00
Luis Pater
b3ccc55f09 Merge pull request #1574 from fbettag/feat/gpt-5.3-codex-spark
feat(registry): add gpt-5.3-codex-spark model definition
2026-02-13 12:46:08 +08:00
이대희
40bee3e8d9 Merge branch 'main' into feature/canceled 2026-02-13 13:37:55 +09:00
Franz Bettag
1ce56d7413 Update internal/registry/model_definitions_static_data.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-12 23:37:27 +01:00
Franz Bettag
41a78be3a2 feat(registry): add gpt-5.3-codex-spark model definition 2026-02-12 23:24:08 +01:00
Luis Pater
1ff5de9a31 docs(readme): add CLIProxyAPI Dashboard to project list 2026-02-13 00:40:39 +08:00
Luis Pater
46a6853046 Merge pull request #1568 from itsmylife44/add-cliproxyapi-dashboard
Add CLIProxyAPI Dashboard to 'Who is with us?' section
2026-02-13 00:37:41 +08:00
xSpaM
4b2d40bd67 Add CLIProxyAPI Dashboard to 'Who is with us?' section 2026-02-12 17:15:46 +01:00
Luis Pater
726f1a590c Merge branch 'router-for-me:main' into main 2026-02-12 22:43:44 +08:00
Luis Pater
575881cb59 feat(registry): add new model definition for MiniMax-M2.5 2026-02-12 22:43:01 +08:00
Luis Pater
d02df0141b Merge pull request #224 from Buywatermelon/fix/kiro-assistant-first-message
fix(kiro): prepend placeholder user message when conversation starts with assistant role
2026-02-12 15:11:10 +08:00
Luis Pater
e4bc9da913 Merge pull request #220 from jellyfish-p/main
fix(kiro): 修复之前提交的错误的application/cbor请求处理逻辑
2026-02-12 15:10:42 +08:00
Luis Pater
8c6be49625 Merge pull request #218 from ClubWeGo/fix/merge-assistant-tool-calls
fix: prevent merging assistant messages with tool_calls
2026-02-12 15:10:00 +08:00
Luis Pater
c727e4251f ci(github): trigger Docker image workflow on version tags matching v* 2026-02-12 15:09:16 +08:00
Luis Pater
99266be998 Merge pull request #216 from starsdream666/main
增加kiro新模型并根据其他提供商同模型配置Thinking
2026-02-12 15:08:37 +08:00
Luis Pater
d0f3fd96f8 Merge pull request #225 from router-for-me/main
v6.8.13
2026-02-12 15:06:32 +08:00
hkfires
f361b2716d feat(registry): add glm-5 model to iflow 2026-02-12 11:13:28 +08:00
y
086d8d0d0b fix(kiro): prepend placeholder user message when conversation starts with assistant role
Kiro/AmazonQ API requires the conversation history to start with a user message.
Some clients (e.g., OpenClaw) send conversations starting with an assistant message,
which is valid for the native Claude API but causes 'Improperly formed request' (400)
on the Kiro endpoint.

This fix detects when the first message has role=assistant and prepends a minimal
placeholder user message ('.') to satisfy the Kiro API's message ordering requirement.

Upstream error: {"message":"Improperly formed request.","reason":null}
Verified: original request returns 400, fixed request returns 200.
2026-02-12 11:09:47 +08:00
jellyfish-p
627dee1dac fix(kiro): 修复之前提交的错误的application/cbor请求处理逻辑 2026-02-12 09:57:34 +08:00
이대희
93147dddeb Improves error handling for canceled requests
Adds explicit handling for context.Canceled errors in the reverse proxy error handler to return 499 status code without logging, which is more appropriate for client-side cancellations during polling.

Also adds a test case to verify this behavior.
2026-02-12 10:39:45 +09:00
이대희
c0f9b15a58 Merge remote-tracking branch 'upstream/main' into feature/canceled 2026-02-12 10:33:49 +09:00
이대희
6f2fbdcbae Update internal/api/modules/amp/proxy.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-12 10:30:05 +09:00
Darley
55c3197fb8 fix(kiro): merge adjacent assistant messages while preserving tool_calls 2026-02-12 07:30:36 +08:00
HEUDavid
65debb874f feat/auth-hook: refactor RequstInfo to preserve original HTTP semantics 2026-02-12 07:11:17 +08:00
HEUDavid
3caadac003 feat/auth-hook: add post auth hook [CR] 2026-02-12 07:11:17 +08:00
HEUDavid
6a9e3a6b84 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
269972440a feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
cce13e6ad2 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
8a565dcad8 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
d536110404 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
48e957ddff feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
94563d622c feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
Darley
5a2cf0d53c fix: prevent merging assistant messages with tool_calls
Adjacent assistant messages where any message contains tool_calls
were being merged by MergeAdjacentMessages, causing tool_calls to
be silently dropped. This led to orphaned tool results that could
not match any toolUse in history, resulting in Kiro API returning
'Improperly formed request.'

Now assistant messages with tool_calls are kept separate during
merge, preserving the tool call chain integrity.
2026-02-12 01:53:40 +08:00
starsdream666
2573358173 根据其他提供商同模型配置Thinking 2026-02-12 00:41:13 +08:00
starsdream666
09cd3cff91 增加kiro新模型:deepseek-3.2,minimax-m2.1,qwen3-coder-next,gpt-4o,gpt-4,gpt-4-turbo,gpt-3.5-turbo 2026-02-12 00:35:24 +08:00
starsdream666
ab0bf1b517 Merge branch 'router-for-me:main' into main 2026-02-11 16:20:20 +00:00
Luis Pater
58e09f8e5f Merge pull request #1542 from APE-147/fix/gemini-antigravity-schema-sanitization
fix(schema): sanitize Gemini-incompatible tool metadata fields
2026-02-11 21:34:04 +08:00
Luis Pater
2334a2b174 Merge branch 'router-for-me:main' into main 2026-02-11 21:09:34 +08:00
Luis Pater
bc61bf36b2 Merge pull request #214 from anilcancakir/fix/github-copilot-model-alias-suffix
fix(auth): strip model suffix in GitHub Copilot executor before upstream call
2026-02-11 21:06:58 +08:00
Luis Pater
7726a44ca2 Merge pull request #212 from Skyuno/fix/orphaned-tool-results
fix(kiro): filter orphaned tool_results from compacted conversations
2026-02-11 21:06:20 +08:00
Luis Pater
dc55fb0ce3 Merge pull request #211 from Skyuno/fix/kiro-websearch
fix(kiro): fully implement Kiro web search tool via MCP integration
2026-02-11 21:05:21 +08:00
Luis Pater
a146c6c0aa Merge pull request #1523 from xxddff/feature/removeUserField
fix(codex): remove unsupported 'user' field from /v1/responses payload
2026-02-11 20:38:16 +08:00
Luis Pater
4c133d3ea9 test(sdk/watcher): add tests for excluded models merging and priority parsing logic
- Added unit tests for combining OAuth excluded models across global and attribute-specific scopes.
- Implemented priority attribute parsing with support for different formats and trimming.
2026-02-11 20:35:13 +08:00
starsdream666
544238772a Merge branch 'router-for-me:main' into main 2026-02-11 10:58:06 +00:00
sususu98
f3ccd85ba1 feat(gemini-cli): add Google One login and improve auto-discovery
Add Google One personal account login to Gemini CLI OAuth flow:
- CLI --login shows mode menu (Code Assist vs Google One)
- Web management API accepts project_id=GOOGLE_ONE sentinel
- Auto-discover project via onboardUser without cloudaicompanionProject when project is unresolved

Improve robustness of auto-discovery and token handling:
- Add context-aware auto-discovery polling (30s timeout, 2s interval)
- Distinguish network errors from project-selection-required errors
- Refresh expired access tokens in readAuthFile before project lookup
- Extend project_id auto-fill to gemini auth type (was antigravity-only)

Unify credential file naming to geminicli- prefix for both CLI and web.

Add extractAccessToken unit tests (9 cases).
2026-02-11 17:53:03 +08:00
RGBadmin
dc279de443 refactor: reduce code duplication in extractExcludedModelsFromMetadata 2026-02-11 15:57:16 +08:00
RGBadmin
bf1634bda0 refactor: simplify per-account excluded_models merge in routing 2026-02-11 15:57:15 +08:00
Nathan
166d2d24d9 fix(schema): remove Gemini-incompatible tool metadata fields
Sanitize tool schemas by stripping prefill, enumTitles, $id, and patternProperties to prevent Gemini INVALID_ARGUMENT 400 errors, and add unit and executor-level tests to lock in the behavior.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 18:29:17 +11:00
RGBadmin
4cbcc835d1 feat: read per-account excluded_models at routing time 2026-02-11 15:21:19 +08:00
RGBadmin
b93026d83a feat: merge per-account excluded_models with global config 2026-02-11 15:21:15 +08:00
RGBadmin
5ed2133ff9 feat: add per-account excluded_models and priority parsing 2026-02-11 15:21:12 +08:00
Luis Pater
e9dd44e623 Merge pull request #209 from Buywatermelon/feature/default-kiro-aliases
feat(config): add default Kiro model aliases for standard Claude model names
2026-02-11 15:09:00 +08:00
Luis Pater
cc8c4ffb5f Merge branch 'router-for-me:main' into main 2026-02-11 15:07:06 +08:00
Luis Pater
1510bfcb6f fix(translator): improve content handling for system and user messages
- Added support for single and array-based `content` cases.
- Enhanced `system_instruction` structure population logic.
- Improved handling of user role assignment for string-based `content`.
2026-02-11 15:04:01 +08:00
Anilcan Cakir
bcd2208b51 fix(auth): strip model suffix in GitHub Copilot executor before upstream call
GitHub Copilot API rejects model names with suffixes (e.g. claude-opus-4.6(medium)).
The OAuthModelAlias resolution correctly maps aliases like 'opus(medium)' to
'claude-opus-4.6(medium)' preserving the suffix, but the executor must strip the
suffix before sending to the upstream API since Copilot only accepts bare model names.

Update normalizeModel in github_copilot_executor to strip suffixes using
thinking.ParseSuffix, matching the pattern used by other executors.

Also add test coverage for:
- OAuthModelAliasChannel github-copilot and kiro channel resolution
- Suffix preservation in alias resolution for github-copilot
- normalizeModel suffix stripping in github_copilot_executor
2026-02-10 23:34:19 +03:00
Skyuno
09b19f5c4e fix(kiro): filter orphaned tool_results from compacted conversations 2026-02-11 00:23:05 +08:00
Skyuno
7b01ca0e2e fix(kiro): implement web search MCP integration for streaming and non-streaming paths
Add complete web search functionality that routes pure web_search requests to the Kiro MCP endpoint instead of the normal GAR API.

Executor changes (kiro_executor.go):

- Add web_search detection in Execute() and ExecuteStream() entry points using HasWebSearchTool() to intercept pure web_search requests before normal processing

- Add 'kiro' format passthrough in buildKiroPayloadForFormat() for pre-built payloads used by callKiroRawAndBuffer()

- Implement handleWebSearchStream(): streaming search loop with MCP search -> InjectToolResultsClaude -> callKiroAndBuffer, supporting up to 5 search iterations with model-driven re-search

- Implement handleWebSearch(): non-streaming path that performs single MCP search, injects tool results, calls normal Execute path, and appends server_tool_use indicators to response

- Add helper methods: callKiroAndBuffer(), callKiroRawAndBuffer(), callKiroDirectStream(), sendFallbackText(), executeNonStreamFallback()

Web search core logic (kiro_websearch.go) [NEW]:

- Define MCP JSON-RPC 2.0 types (McpRequest, McpResponse, McpResult, McpContent, McpError)

- Define WebSearchResults/WebSearchResult structs for parsing MCP search results

- HasWebSearchTool(): detect pure web_search requests (single-tool array only)

- ContainsWebSearchTool(): detect web_search in mixed-tool arrays

- ExtractSearchQuery(): parse search query from Claude Code's tool_use message format

- CreateMcpRequest(): build MCP tools/call request with Kiro-compatible ID format

- InjectToolResultsClaude(): append assistant tool_use + user tool_result messages to Claude-format payload for GAR translation pipeline

- InjectToolResults(): modify Kiro-format payload directly with toolResults in currentMessage context

- InjectSearchIndicatorsInResponse(): prepend server_tool_use + web_search_tool_result content blocks to non-streaming response for Claude Code search count display

- ReplaceWebSearchToolDescription(): swap restrictive Kiro tool description with minimal re-search-friendly version

- StripWebSearchTool(): remove web_search from tools array

- FormatSearchContextPrompt() / FormatToolResultText(): format search results for injection

- SSE event generation: SseEvent type, GenerateWebSearchEvents() (11-event sequence), GenerateSearchIndicatorEvents() (server_tool_use + web_search_tool_result pairs)

- Stream analysis: AnalyzeBufferedStream() to detect stop_reason and web_search tool_use in buffered chunks, FilterChunksForClient() to strip tool_use blocks and adjust indices, AdjustSSEChunk() / AdjustStreamIndices() for content block index offset management

MCP API handler (kiro_websearch_handler.go) [NEW]:

- WebSearchHandler struct with MCP endpoint, HTTP client, auth token, fingerprint, and custom auth attributes

- FetchToolDescription(): sync.Once-guarded MCP tools/list call to cache web_search tool description

- GetWebSearchDescription(): thread-safe cached description retrieval

- CallMcpAPI(): MCP API caller with retry logic (exponential backoff, retryable on 502/503/504), AWS-aligned headers via setMcpHeaders()

- ParseSearchResults(): extract WebSearchResults from MCP JSON-RPC response

- setMcpHeaders(): set Content-Type, Kiro agent headers, dynamic fingerprint User-Agent, AWS SDK identifiers, Bearer auth, and custom auth attributes

Claude request translation (kiro_claude_request.go):

- Rename web_search -> remote_web_search in convertClaudeToolsToKiro() with dynamic description from GetWebSearchDescription() or hardcoded fallback

- Rename web_search -> remote_web_search in BuildAssistantMessageStruct() for tool_use content blocks

- Add remoteWebSearchDescription constant as fallback when MCP tools/list hasn't been fetched
2026-02-11 00:02:30 +08:00
starsdream666
9c65e17a21 Merge branch 'router-for-me:main' into main 2026-02-10 14:41:20 +00:00
Skyuno
fe6fc628ed Revert "fix: filter out web_search/websearch tools unsupported by Kiro API"
This reverts commit 5dc936a9a4.
2026-02-10 22:24:46 +08:00
Skyuno
8192eeabc8 Revert "feat: inject web_search alternative hint instead of silently filtering"
This reverts commit 3c7a5afdcc.
2026-02-10 22:24:46 +08:00
y
c3f1cdd7e5 feat(config): add default Kiro model aliases for standard Claude model names
Kiro models are exposed with kiro- prefix (e.g., kiro-claude-sonnet-4-5),
which prevents clients like Claude Code from using standard model names
(e.g., claude-sonnet-4-20250514).

This change injects default oauth-model-alias entries for the kiro channel
when no user-configured aliases exist, following the same pattern as the
existing Antigravity defaults. The aliases map standard Claude model names
(both with and without date suffixes) to their kiro-prefixed counterparts.

Default aliases added:
- claude-sonnet-4-5-20250929 / claude-sonnet-4-5 → kiro-claude-sonnet-4-5
- claude-sonnet-4-20250514 / claude-sonnet-4 → kiro-claude-sonnet-4
- claude-opus-4-6 → kiro-claude-opus-4-6
- claude-opus-4-5-20251101 / claude-opus-4-5 → kiro-claude-opus-4-5
- claude-haiku-4-5-20251001 / claude-haiku-4-5 → kiro-claude-haiku-4-5

All aliases use fork: true to preserve the original kiro-* names.
User-configured kiro aliases are respected and not overridden.

Closes router-for-me/CLIProxyAPIPlus#208
2026-02-10 19:01:07 +08:00
Chén Mù
c6bd91b86b Merge pull request #1519 from router-for-me/thinking
feat(translator): support Claude thinking type adaptive
2026-02-10 18:31:56 +08:00
이대희
ce0c6aa82b Update internal/api/modules/amp/proxy.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-10 19:07:49 +09:00
hkfires
349ddcaa89 fix(registry): correct max completion tokens for opus 4.6 thinking 2026-02-10 18:05:40 +08:00
xxddff
bb9fe52f1e Update internal/translator/codex/openai/responses/codex_openai-responses_request_test.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-10 18:24:58 +09:00
xxddff
afe4c1bfb7 更新internal/translator/codex/openai/responses/codex_openai-responses_request.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-10 18:24:26 +09:00
이대희
3c85d2a4d7 feature(proxy): Adds special handling for client cancellations in proxy error handler
Silences logging for client cancellations during polling to reduce noise in logs.
Client-side cancellations are common during long-running operations and should not be treated as errors.
2026-02-10 18:02:08 +09:00
xxddff
865af9f19e Implement test for user field deletion
Add test to verify deletion of user field in response
2026-02-10 17:38:49 +09:00
xxddff
2b97cb98b5 Delete 'user' field from raw JSON
Remove the 'user' field from the raw JSON as requested.
2026-02-10 17:35:54 +09:00
hkfires
938a799263 feat(translator): support Claude thinking type adaptive 2026-02-10 16:20:32 +08:00
Luis Pater
e17d4f8d98 Merge pull request #207 from router-for-me/plus
v6.8.9
2026-02-10 15:43:45 +08:00
Luis Pater
c8cae1f74d Merge branch 'main' into plus 2026-02-10 15:43:31 +08:00
Luis Pater
0040d78496 refactor(sdk): simplify provider lifecycle and registration logic 2026-02-10 15:39:26 +08:00
hkfires
896de027cc docs(config): reorder antigravity model alias example 2026-02-10 10:13:54 +08:00
hkfires
fc329ebf37 docs(config): simplify oauth model alias example 2026-02-10 10:12:28 +08:00
starsdream666
15bc99f6ea Merge branch 'router-for-me:main' into main 2026-02-10 01:45:05 +00:00
Luis Pater
91841a5519 Merge branch 'router-for-me:main' into main 2026-02-10 02:10:29 +08:00
Luis Pater
eaab1d6824 Merge pull request #1506 from masrurimz/fix-sse-model-mapping
fix(amp): rewrite response.model in Responses API SSE events
2026-02-10 02:08:11 +08:00
Muhammad Zahid Masruri
0cfe310df6 ci: retrigger workflows
Amp-Thread-ID: https://ampcode.com/threads/T-019c264f-1cb9-7420-a68b-876030db6716
2026-02-10 00:09:11 +07:00
Muhammad Zahid Masruri
918b6955e4 fix(amp): rewrite model name in response.model for Responses API SSE events
The ResponseRewriter's modelFieldPaths was missing 'response.model',
causing the mapped model name to leak through SSE streaming events
(response.created, response.in_progress, response.completed) in the
OpenAI Responses API (/v1/responses).

This caused Amp CLI to report 'Unknown OpenAI model' errors when
model mapping was active (e.g., gpt-5.2-codex -> gpt-5.3-codex),
because the mapped name reached Amp's backend via telemetry.

Also sorted modelFieldPaths alphabetically per review feedback
and added regression tests for all rewrite paths.

Fixes #1463
2026-02-09 23:52:59 +07:00
starsdream666
3ec7991e5f Merge branch 'router-for-me:main' into main 2026-02-09 14:18:04 +00:00
Luis Pater
532fbf00d4 Merge pull request #204 from router-for-me/plus
v6.8.7
2026-02-09 20:00:36 +08:00
Luis Pater
45b6fffd7f Merge branch 'main' into plus 2026-02-09 20:00:16 +08:00
Luis Pater
5a3eb08739 Merge pull request #1502 from router-for-me/iflow
feat(executor): add session ID and HMAC-SHA256 signature generation for iFlow API requests
2026-02-09 19:56:12 +08:00
Luis Pater
0dff329162 Merge pull request #1492 from router-for-me/management
fix(management): ensure management.html is available synchronously and improve asset sync handling
2026-02-09 19:55:21 +08:00
hkfires
49c1740b47 feat(executor): add session ID and HMAC-SHA256 signature generation for iFlow API requests 2026-02-09 19:29:42 +08:00
hkfires
3fbee51e9f fix(management): ensure management.html is available synchronously and improve asset sync handling 2026-02-09 08:32:58 +08:00
Luis Pater
a3dc56d2a0 Merge branch 'router-for-me:main' into main 2026-02-09 02:07:02 +08:00
Luis Pater
63643c44a1 Fixed: #1484
fix(translator): restructure message content handling to support multiple content types

- Consolidated `input_text` and `output_text` handling into a single case.
- Added support for processing `input_image` content with associated URLs.
2026-02-09 02:05:38 +08:00
Luis Pater
1d93608dbe Merge pull request #203 from JokerRun/fix/copilot-premium-usage-inflation
fix(copilot): prevent premium request count inflation for Claude models
2026-02-08 20:42:51 +08:00
Luis Pater
d125b7de92 Merge pull request #199 from ravindrabarthwal/add-claude-opus-4.6-github-copilot
feat: add Claude Opus 4.6 to GitHub Copilot models
2026-02-08 20:41:20 +08:00
Luis Pater
d5654ee316 Merge branch 'router-for-me:main' into main 2026-02-08 20:40:18 +08:00
Luis Pater
3b34521ad9 Merge pull request #1479 from router-for-me/management
refactor(management): streamline control panel management and implement sync throttling
2026-02-08 20:37:29 +08:00
hkfires
7197fb350b fix(config): prune default descendants when merging new yaml nodes 2026-02-08 19:05:52 +08:00
hkfires
6e349bfcc7 fix(config): avoid writing known defaults during merge 2026-02-08 18:47:44 +08:00
hkfires
234056072d refactor(management): streamline control panel management and implement sync throttling 2026-02-08 10:42:49 +08:00
rico
76330f4bff feat(copilot): add Claude Opus 4.6 model definition
> 添加 copilot claude opus 4.6 支持 (ref: PR #199)
2026-02-08 02:38:06 +08:00
rico
d468eec6ec fix(copilot): prevent premium request count inflation for Claude models
> Copilot Premium usage significantly amplified when using amp

- Add X-Initiator header (user/agent) based on last message role to
  prevent Copilot from billing all requests as premium user-initiated
- Add flattenAssistantContent() to convert assistant content from array
  to string, preventing Claude from re-answering all previous prompts
- Align Copilot headers (User-Agent, Editor-Version, Openai-Intent) with
  pi-ai reference implementation

Closes #113

Amp-Thread-ID: https://ampcode.com/threads/T-019c392b-736e-7489-a06b-f94f7c75f7c0
Co-authored-by: Amp <amp@ampcode.com>
2026-02-08 02:22:10 +08:00
starsdream666
40e85a6759 Merge branch 'router-for-me:main' into main 2026-02-07 16:37:51 +00:00
Ravindra Barthwal
9bc6cc5b41 feat: add Claude Opus 4.6 to GitHub Copilot models
GitHub Copilot now supports claude-opus-4.6 but it was missing from
the proxy's model definitions. Fixes #196.
2026-02-07 14:58:34 +05:30
starsdream666
cc116ce67d Merge branch 'router-for-me:main' into main 2026-02-06 16:11:26 +00:00
starsdream666
40efc2ba43 修改工作流 2026-02-06 03:29:31 +00:00
204 changed files with 22377 additions and 2087 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- run: git fetch --force --tags
- uses: actions/setup-go@v4
with:
go-version: '>=1.24.0'
go-version: '>=1.26.0'
cache: true
- name: Generate Build Metadata
run: |

3
.gitignore vendored
View File

@@ -3,10 +3,11 @@ cli-proxy-api
cliproxy
*.exe
# Configuration
config.yaml
.env
.mcp.json
# Generated content
bin/*
logs/*

View File

@@ -1,4 +1,4 @@
FROM golang:1.24-alpine AS builder
FROM golang:1.26-alpine AS builder
WORKDIR /app

View File

@@ -27,6 +27,51 @@ The Plus release stays in lockstep with the mainline features.
## Kiro Authentication
### CLI Login
> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions.
**AWS Builder ID** (recommended):
```bash
# Device code flow
./CLIProxyAPI --kiro-aws-login
# Authorization code flow
./CLIProxyAPI --kiro-aws-authcode
```
**Import token from Kiro IDE:**
```bash
./CLIProxyAPI --kiro-import
```
To get a token from Kiro IDE:
1. Open Kiro IDE and login with Google (or GitHub)
2. Find the token file: `~/.kiro/kiro-auth-token.json`
3. Run: `./CLIProxyAPI --kiro-import`
**AWS IAM Identity Center (IDC):**
```bash
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
# Specify region
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
```
**Additional flags:**
| Flag | Description |
|------|-------------|
| `--no-browser` | Don't open browser automatically, print URL instead |
| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session |
| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) |
| `--kiro-idc-region` | IDC region (default: `us-east-1`) |
| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` |
### Web-based OAuth Login
Access the Kiro OAuth web interface at:

View File

@@ -27,6 +27,51 @@
## Kiro 认证
### 命令行登录
> **注意:** 由于 AWS Cognito 限制Google/GitHub 登录不可用于第三方应用。
**AWS Builder ID**(推荐):
```bash
# 设备码流程
./CLIProxyAPI --kiro-aws-login
# 授权码流程
./CLIProxyAPI --kiro-aws-authcode
```
**从 Kiro IDE 导入令牌:**
```bash
./CLIProxyAPI --kiro-import
```
获取令牌步骤:
1. 打开 Kiro IDE使用 Google或 GitHub登录
2. 找到令牌文件:`~/.kiro/kiro-auth-token.json`
3. 运行:`./CLIProxyAPI --kiro-import`
**AWS IAM Identity Center (IDC)**
```bash
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
# 指定区域
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
```
**附加参数:**
| 参数 | 说明 |
|------|------|
| `--no-browser` | 不自动打开浏览器,打印 URL |
| `--no-incognito` | 使用已有浏览器会话Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 |
| `--kiro-idc-start-url` | IDC Start URL`--kiro-idc-login` 必需) |
| `--kiro-idc-region` | IDC 区域(默认:`us-east-1` |
| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` |
### 网页端 OAuth 登录
访问 Kiro OAuth 网页认证界面:

View File

@@ -8,6 +8,7 @@ import (
"errors"
"flag"
"fmt"
"io"
"io/fs"
"net/url"
"os"
@@ -26,6 +27,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
@@ -70,8 +72,10 @@ func main() {
// Command-line flags to control the application's behavior.
var login bool
var codexLogin bool
var codexDeviceLogin bool
var claudeLogin bool
var qwenLogin bool
var kiloLogin bool
var iflowLogin bool
var iflowCookie bool
var noBrowser bool
@@ -83,19 +87,27 @@ func main() {
var kiroAWSLogin bool
var kiroAWSAuthCode bool
var kiroImport bool
var kiroIDCLogin bool
var kiroIDCStartURL string
var kiroIDCRegion string
var kiroIDCFlow string
var githubCopilotLogin bool
var projectID string
var vertexImport string
var configPath string
var password string
var tuiMode bool
var standalone bool
var noIncognito bool
var useIncognito bool
// Define command-line flags for different operation modes.
flag.BoolVar(&login, "login", false, "Login Google Account")
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
@@ -109,11 +121,17 @@ func main() {
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
flag.BoolVar(&kiroAWSAuthCode, "kiro-aws-authcode", false, "Login to Kiro using AWS Builder ID (authorization code flow, better UX)")
flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
flag.BoolVar(&kiroIDCLogin, "kiro-idc-login", false, "Login to Kiro using IAM Identity Center (IDC)")
flag.StringVar(&kiroIDCStartURL, "kiro-idc-start-url", "", "IDC start URL (required with --kiro-idc-login)")
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
flag.StringVar(&password, "password", "", "")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
flag.CommandLine.Usage = func() {
out := flag.CommandLine.Output()
@@ -475,7 +493,7 @@ func main() {
}
// Register built-in access providers before constructing services.
configaccess.Register()
configaccess.Register(&cfg.SDKConfig)
// Handle different command modes based on the provided flags.
@@ -494,11 +512,16 @@ func main() {
} else if codexLogin {
// Handle Codex login
cmd.DoCodexLogin(cfg, options)
} else if codexDeviceLogin {
// Handle Codex device-code login
cmd.DoCodexDeviceLogin(cfg, options)
} else if claudeLogin {
// Handle Claude login
cmd.DoClaudeLogin(cfg, options)
} else if qwenLogin {
cmd.DoQwenLogin(cfg, options)
} else if kiloLogin {
cmd.DoKiloLogin(cfg, options)
} else if iflowLogin {
cmd.DoIFlowLogin(cfg, options)
} else if iflowCookie {
@@ -511,24 +534,34 @@ func main() {
// Note: This config mutation is safe - auth commands exit after completion
// and don't share config with StartService (which is in the else branch)
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroLogin(cfg, options)
} else if kiroGoogleLogin {
// For Kiro auth, default to incognito mode for multi-account support
// Users can explicitly override with --no-incognito
// Note: This config mutation is safe - auth commands exit after completion
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroGoogleLogin(cfg, options)
} else if kiroAWSLogin {
// For Kiro auth, default to incognito mode for multi-account support
// Users can explicitly override with --no-incognito
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroAWSLogin(cfg, options)
} else if kiroAWSAuthCode {
// For Kiro auth with authorization code flow (better UX)
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroAWSAuthCodeLogin(cfg, options)
} else if kiroImport {
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroImport(cfg, options)
} else if kiroIDCLogin {
// For Kiro IDC auth, default to incognito mode for multi-account support
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroIDCLogin(cfg, options, kiroIDCStartURL, kiroIDCRegion, kiroIDCFlow)
} else {
// In cloud deploy mode without config file, just wait for shutdown signals
if isCloudDeploy && !configFileExists {
@@ -536,15 +569,89 @@ func main() {
cmd.WaitForCloudDeploy()
return
}
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
if tuiMode {
if standalone {
// Standalone mode: start an embedded local server and connect TUI client to it.
managementasset.StartAutoUpdater(context.Background(), configFilePath)
hook := tui.NewLogHook(2000)
hook.SetFormatter(&logging.LogFormatter{})
log.AddHook(hook)
// 初始化并启动 Kiro token 后台刷新
if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg)
defer kiro.StopGlobalRefreshManager()
origStdout := os.Stdout
origStderr := os.Stderr
origLogOutput := log.StandardLogger().Out
log.SetOutput(io.Discard)
devNull, errOpenDevNull := os.Open(os.DevNull)
if errOpenDevNull == nil {
os.Stdout = devNull
os.Stderr = devNull
}
restoreIO := func() {
os.Stdout = origStdout
os.Stderr = origStderr
log.SetOutput(origLogOutput)
if devNull != nil {
_ = devNull.Close()
}
}
localMgmtPassword := fmt.Sprintf("tui-%d-%d", os.Getpid(), time.Now().UnixNano())
if password == "" {
password = localMgmtPassword
}
cancel, done := cmd.StartServiceBackground(cfg, configFilePath, password)
client := tui.NewClient(cfg.Port, password)
ready := false
backoff := 100 * time.Millisecond
for i := 0; i < 30; i++ {
if _, errGetConfig := client.GetConfig(); errGetConfig == nil {
ready = true
break
}
time.Sleep(backoff)
if backoff < time.Second {
backoff = time.Duration(float64(backoff) * 1.5)
}
}
if !ready {
restoreIO()
cancel()
<-done
fmt.Fprintf(os.Stderr, "TUI error: embedded server is not ready\n")
return
}
if errRun := tui.Run(cfg.Port, password, hook, origStdout); errRun != nil {
restoreIO()
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
} else {
restoreIO()
}
cancel()
<-done
} else {
// Default TUI mode: pure management client.
// The proxy server must already be running.
if errRun := tui.Run(cfg.Port, password, nil, os.Stdout); errRun != nil {
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
}
}
} else {
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg)
defer kiro.StopGlobalRefreshManager()
}
cmd.StartService(cfg, configFilePath, password)
}
cmd.StartService(cfg, configFilePath, password)
}
}

View File

@@ -1,6 +1,6 @@
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
host: ""
host: ''
# Server port
port: 8317
@@ -8,8 +8,8 @@ port: 8317
# TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
tls:
enable: false
cert: ""
key: ""
cert: ''
key: ''
# Management API settings
remote-management:
@@ -20,22 +20,22 @@ remote-management:
# Management key. If a plaintext value is provided here, it will be hashed on startup.
# All management requests (even from localhost) require this key.
# Leave empty to disable the Management API entirely (404 for all /v0/management routes).
secret-key: ""
secret-key: ''
# Disable the bundled management control panel asset download and HTTP route when true.
disable-control-panel: false
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
# Authentication directory (supports ~ for home directory)
auth-dir: "~/.cli-proxy-api"
auth-dir: '~/.cli-proxy-api'
# API keys for authentication
api-keys:
- "your-api-key-1"
- "your-api-key-2"
- "your-api-key-3"
- 'your-api-key-1'
- 'your-api-key-2'
- 'your-api-key-3'
# Enable debug logging
debug: false
@@ -43,7 +43,7 @@ debug: false
# Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety.
pprof:
enable: false
addr: "127.0.0.1:8316"
addr: '127.0.0.1:8316'
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
commercial-mode: false
@@ -68,11 +68,15 @@ error-logs-max-files: 10
usage-statistics-enabled: false
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
proxy-url: ""
proxy-url: ''
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
force-model-prefix: false
# When true, forward filtered upstream response headers to downstream clients.
# Default is false (disabled).
passthrough-headers: false
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
request-retry: 3
@@ -86,7 +90,7 @@ quota-exceeded:
# Routing strategy for selecting credentials when multiple match.
routing:
strategy: "round-robin" # round-robin (default), fill-first
strategy: 'round-robin' # round-robin (default), fill-first
# When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false
@@ -160,17 +164,43 @@ nonstream-keepalive-interval: 0
# sensitive-words: # optional: words to obfuscate with zero-width characters
# - "API"
# - "proxy"
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
# Default headers for Claude API requests. Update when Claude Code releases new versions.
# These are used as fallbacks when the client does not send its own headers.
# claude-header-defaults:
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
# package-version: "0.74.0"
# runtime-version: "v24.3.0"
# timeout: "600"
# Kiro (AWS CodeWhisperer) configuration
# Note: Kiro API currently only operates in us-east-1 region
#kiro:
# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file
# agent-task-type: "" # optional: "vibe" or empty (API default)
# start-url: "https://your-company.awsapps.com/start" # optional: IDC start URL (preset for login)
# region: "us-east-1" # optional: OIDC region for IDC login and token refresh
# - access-token: "aoaAAAAA..." # or provide tokens directly
# refresh-token: "aorAAAAA..."
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override
# Kilocode (OAuth-based code assistant)
# Note: Kilocode uses OAuth device flow authentication.
# Use the CLI command: ./server --kilo-login
# This will save credentials to the auth directory (default: ~/.cli-proxy-api/)
# oauth-model-alias:
# kilo:
# - name: "minimax/minimax-m2.5:free"
# alias: "minimax-m2.5"
# - name: "z-ai/glm-5:free"
# alias: "glm-5"
# oauth-excluded-models:
# kilo:
# - "kilo-claude-opus-4-6" # exclude specific models (exact match)
# - "*:free" # wildcard matching suffix (e.g. all free models)
# OpenAI compatibility providers
# openai-compatibility:
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
@@ -239,7 +269,7 @@ nonstream-keepalive-interval: 0
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# You can repeat the same name with different aliases to expose multiple client model names.
#oauth-model-alias:
# oauth-model-alias:
# antigravity:
# - name: "rev19-uic3-1p"
# alias: "gemini-2.5-computer-use-preview-10-2025"
@@ -265,9 +295,6 @@ nonstream-keepalive-interval: 0
# aistudio:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# antigravity:
# - name: "gemini-3-pro-preview"
# alias: "g3p"
# claude:
# - name: "claude-sonnet-4-5-20250929"
# alias: "cs4.5"

View File

@@ -7,80 +7,71 @@ The `github.com/router-for-me/CLIProxyAPI/v6/sdk/access` package centralizes inb
```go
import (
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
```
Add the module with `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access`.
## Provider Registry
Providers are registered globally and then attached to a `Manager` as a snapshot:
- `RegisterProvider(type, provider)` installs a pre-initialized provider instance.
- Registration order is preserved the first time each `type` is seen.
- `RegisteredProviders()` returns the providers in that order.
## Manager Lifecycle
```go
manager := sdkaccess.NewManager()
providers, err := sdkaccess.BuildProviders(cfg)
if err != nil {
return err
}
manager.SetProviders(providers)
manager.SetProviders(sdkaccess.RegisteredProviders())
```
* `NewManager` constructs an empty manager.
* `SetProviders` replaces the provider slice using a defensive copy.
* `Providers` retrieves a snapshot that can be iterated safely from other goroutines.
* `BuildProviders` translates `config.Config` access declarations into runnable providers. When the config omits explicit providers but defines inline API keys, the helper auto-installs the built-in `config-api-key` provider.
If the manager itself is `nil` or no providers are configured, the call returns `nil, nil`, allowing callers to treat access control as disabled.
## Authenticating Requests
```go
result, err := manager.Authenticate(ctx, req)
result, authErr := manager.Authenticate(ctx, req)
switch {
case err == nil:
case authErr == nil:
// Authentication succeeded; result describes the provider and principal.
case errors.Is(err, sdkaccess.ErrNoCredentials):
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNoCredentials):
// No recognizable credentials were supplied.
case errors.Is(err, sdkaccess.ErrInvalidCredential):
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInvalidCredential):
// Supplied credentials were present but rejected.
default:
// Transport-level failure was returned by a provider.
// Internal/transport failure was returned by a provider.
}
```
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that surface `ErrNotHandled`, and tracks whether any provider reported `ErrNoCredentials` or `ErrInvalidCredential` for downstream error reporting.
If the manager itself is `nil` or no providers are registered, the call returns `nil, nil`, allowing callers to treat access control as disabled without branching on errors.
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that return `AuthErrorCodeNotHandled`, and aggregates `AuthErrorCodeNoCredentials` / `AuthErrorCodeInvalidCredential` for a final result.
Each `Result` includes the provider identifier, the resolved principal, and optional metadata (for example, which header carried the credential).
## Configuration Layout
## Built-in `config-api-key` Provider
The manager expects access providers under the `auth.providers` key inside `config.yaml`:
The proxy includes one built-in access provider:
- `config-api-key`: Validates API keys declared under top-level `api-keys`.
- Credential sources: `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, `?key=`, `?auth_token=`
- Metadata: `Result.Metadata["source"]` is set to the matched source label.
In the CLI server and `sdk/cliproxy`, this provider is registered automatically based on the loaded configuration.
```yaml
auth:
providers:
- name: inline-api
type: config-api-key
api-keys:
- sk-test-123
- sk-prod-456
api-keys:
- sk-test-123
- sk-prod-456
```
Fields map directly to `config.AccessProvider`: `name` labels the provider, `type` selects the registered factory, `sdk` can name an external module, `api-keys` seeds inline credentials, and `config` passes provider-specific options.
## Loading Providers from External Go Modules
### Loading providers from external SDK modules
To consume a provider shipped in another Go module, point the `sdk` field at the module path and import it for its registration side effect:
```yaml
auth:
providers:
- name: partner-auth
type: partner-token
sdk: github.com/acme/xplatform/sdk/access/providers/partner
config:
region: us-west-2
audience: cli-proxy
```
To consume a provider shipped in another Go module, import it for its registration side effect:
```go
import (
@@ -89,19 +80,11 @@ import (
)
```
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before `BuildProviders` is called.
## Built-in Providers
The SDK ships with one provider out of the box:
- `config-api-key`: Validates API keys declared inline or under top-level `api-keys`. It accepts the key from `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, or the `?key=` query string and reports `ErrInvalidCredential` when no match is found.
Additional providers can be delivered by third-party packages. When a provider package is imported, it registers itself with `sdkaccess.RegisterProvider`.
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before you call `RegisteredProviders()` (or before `cliproxy.NewBuilder().Build()`).
### Metadata and auditing
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, or `query-key`). Populate this map in custom providers to enrich logs and downstream auditing.
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, `query-key`, `query-auth-token`). Populate this map in custom providers to enrich logs and downstream auditing.
## Writing Custom Providers
@@ -110,13 +93,13 @@ type customProvider struct{}
func (p *customProvider) Identifier() string { return "my-provider" }
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
token := r.Header.Get("X-Custom")
if token == "" {
return nil, sdkaccess.ErrNoCredentials
return nil, sdkaccess.NewNotHandledError()
}
if token != "expected" {
return nil, sdkaccess.ErrInvalidCredential
return nil, sdkaccess.NewInvalidCredentialError()
}
return &sdkaccess.Result{
Provider: p.Identifier(),
@@ -126,51 +109,46 @@ func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sd
}
func init() {
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
return &customProvider{}, nil
})
sdkaccess.RegisterProvider("custom", &customProvider{})
}
```
A provider must implement `Identifier()` and `Authenticate()`. To expose it to configuration, call `RegisterProvider` inside `init`. Provider factories receive the specific `AccessProvider` block plus the full root configuration for contextual needs.
A provider must implement `Identifier()` and `Authenticate()`. To make it available to the access manager, call `RegisterProvider` inside `init` with an initialized provider instance.
## Error Semantics
- `ErrNoCredentials`: no credentials were present or recognized by any provider.
- `ErrInvalidCredential`: at least one provider processed the credentials but rejected them.
- `ErrNotHandled`: instructs the manager to fall through to the next provider without affecting aggregate error reporting.
- `NewNoCredentialsError()` (`AuthErrorCodeNoCredentials`): no credentials were present or recognized. (HTTP 401)
- `NewInvalidCredentialError()` (`AuthErrorCodeInvalidCredential`): credentials were present but rejected. (HTTP 401)
- `NewNotHandledError()` (`AuthErrorCodeNotHandled`): fall through to the next provider.
- `NewInternalAuthError(message, cause)` (`AuthErrorCodeInternal`): transport/system failure. (HTTP 500)
Return custom errors to surface transport failures; they propagate immediately to the caller instead of being masked.
Errors propagate immediately to the caller unless they are classified as `not_handled` / `no_credentials` / `invalid_credential` and can be aggregated by the manager.
## Integration with cliproxy Service
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a preconfigured manager allows you to extend or override the default providers:
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a manager lets you reuse the same instance in your host process:
```go
coreCfg, _ := config.LoadConfig("config.yaml")
providers, _ := sdkaccess.BuildProviders(coreCfg)
manager := sdkaccess.NewManager()
manager.SetProviders(providers)
accessManager := sdkaccess.NewManager()
svc, _ := cliproxy.NewBuilder().
WithConfig(coreCfg).
WithAccessManager(manager).
WithConfigPath("config.yaml").
WithRequestAccessManager(accessManager).
Build()
```
The service reuses the manager for every inbound request, ensuring consistent authentication across embedded deployments and the canonical CLI binary.
Register any custom providers (typically via blank imports) before calling `Build()` so they are present in the global registry snapshot.
### Hot reloading providers
### Hot reloading
When configuration changes, rebuild providers and swap them into the manager:
When configuration changes, refresh any config-backed providers and then reset the manager's provider chain:
```go
providers, err := sdkaccess.BuildProviders(newCfg)
if err != nil {
log.Errorf("reload auth providers failed: %v", err)
return
}
accessManager.SetProviders(providers)
// configaccess is github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access
configaccess.Register(&newCfg.SDKConfig)
accessManager.SetProviders(sdkaccess.RegisteredProviders())
```
This mirrors the behaviour in `cliproxy.Service.refreshAccessProviders` and `api.Server.applyAccessConfig`, enabling runtime updates without restarting the process.
This mirrors the behaviour in `internal/access.ApplyAccessProviders`, enabling runtime updates without restarting the process.

View File

@@ -7,80 +7,71 @@
```go
import (
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
```
通过 `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access` 添加依赖。
## Provider Registry
访问提供者是全局注册,然后以快照形式挂到 `Manager` 上:
- `RegisterProvider(type, provider)` 注册一个已经初始化好的 provider 实例。
- 每个 `type` 第一次出现时会记录其注册顺序。
- `RegisteredProviders()` 会按该顺序返回 provider 列表。
## 管理器生命周期
```go
manager := sdkaccess.NewManager()
providers, err := sdkaccess.BuildProviders(cfg)
if err != nil {
return err
}
manager.SetProviders(providers)
manager.SetProviders(sdkaccess.RegisteredProviders())
```
- `NewManager` 创建空管理器。
- `SetProviders` 替换提供者切片并做防御性拷贝。
- `Providers` 返回适合并发读取的快照。
- `BuildProviders``config.Config` 中的访问配置转换成可运行的提供者。当配置没有显式声明但包含顶层 `api-keys` 时,会自动挂载内建的 `config-api-key` 提供者。
如果管理器本身为 `nil` 或未配置任何 provider调用会返回 `nil, nil`,可视为关闭访问控制。
## 认证请求
```go
result, err := manager.Authenticate(ctx, req)
result, authErr := manager.Authenticate(ctx, req)
switch {
case err == nil:
case authErr == nil:
// Authentication succeeded; result carries provider and principal.
case errors.Is(err, sdkaccess.ErrNoCredentials):
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNoCredentials):
// No recognizable credentials were supplied.
case errors.Is(err, sdkaccess.ErrInvalidCredential):
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInvalidCredential):
// Credentials were present but rejected.
default:
// Provider surfaced a transport-level failure.
}
```
`Manager.Authenticate`配置顺序遍历提供者。遇到成功立即返回,`ErrNotHandled` 会继续尝试下一个;若发现 `ErrNoCredentials` `ErrInvalidCredential`会在遍历结束后汇总给调用方。
若管理器本身为 `nil` 或尚未注册提供者,调用会返回 `nil, nil`,让调用方无需针对错误做额外分支即可关闭访问控制。
`Manager.Authenticate` 按顺序遍历 provider遇到成功立即返回,`AuthErrorCodeNotHandled` 会继续尝试下一个;`AuthErrorCodeNoCredentials` / `AuthErrorCodeInvalidCredential` 会在遍历结束后汇总给调用方。
`Result` 提供认证提供者标识、解析出的主体以及可选元数据(例如凭证来源)。
## 配置结构
## 内建 `config-api-key` Provider
`config.yaml``auth.providers` 下定义访问提供者:
代理内置一个访问提供者:
- `config-api-key`:校验 `config.yaml` 顶层的 `api-keys`
- 凭证来源:`Authorization: Bearer``X-Goog-Api-Key``X-Api-Key``?key=``?auth_token=`
- 元数据:`Result.Metadata["source"]` 会写入匹配到的来源标识
在 CLI 服务端与 `sdk/cliproxy` 中,该 provider 会根据加载到的配置自动注册。
```yaml
auth:
providers:
- name: inline-api
type: config-api-key
api-keys:
- sk-test-123
- sk-prod-456
api-keys:
- sk-test-123
- sk-prod-456
```
条目映射到 `config.AccessProvider``name` 指定实例名,`type` 选择注册的工厂,`sdk` 可引用第三方模块,`api-keys` 提供内联凭证,`config` 用于传递特定选项。
## 引入外部 Go 模块提供者
### 引入外部 SDK 提供者
若要消费其它 Go 模块输出的访问提供者,可在配置里填写 `sdk` 字段并在代码中引入该包,利用其 `init` 注册过程:
```yaml
auth:
providers:
- name: partner-auth
type: partner-token
sdk: github.com/acme/xplatform/sdk/access/providers/partner
config:
region: us-west-2
audience: cli-proxy
```
若要消费其它 Go 模块输出的访问提供者,直接用空白标识符导入以触发其 `init` 注册即可:
```go
import (
@@ -89,19 +80,11 @@ import (
)
```
通过空白标识符导入可确保 `init` 调用,先于 `BuildProviders` 完成 `sdkaccess.RegisterProvider`
## 内建提供者
当前 SDK 默认内置:
- `config-api-key`:校验配置中的 API Key。它从 `Authorization: Bearer``X-Goog-Api-Key``X-Api-Key` 以及查询参数 `?key=` 提取凭证,不匹配时抛出 `ErrInvalidCredential`
导入第三方包即可通过 `sdkaccess.RegisterProvider` 注册更多类型。
空白导入可确保 `init` 先执行,从而在你调用 `RegisteredProviders()`(或 `cliproxy.NewBuilder().Build()`)之前完成 `sdkaccess.RegisterProvider`
### 元数据与审计
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization``x-goog-api-key``x-api-key``query-key`)。自定义提供者同样可以填充该 Map以便丰富日志与审计场景。
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization``x-goog-api-key``x-api-key``query-key``query-auth-token`)。自定义提供者同样可以填充该 Map以便丰富日志与审计场景。
## 编写自定义提供者
@@ -110,13 +93,13 @@ type customProvider struct{}
func (p *customProvider) Identifier() string { return "my-provider" }
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
token := r.Header.Get("X-Custom")
if token == "" {
return nil, sdkaccess.ErrNoCredentials
return nil, sdkaccess.NewNotHandledError()
}
if token != "expected" {
return nil, sdkaccess.ErrInvalidCredential
return nil, sdkaccess.NewInvalidCredentialError()
}
return &sdkaccess.Result{
Provider: p.Identifier(),
@@ -126,51 +109,46 @@ func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sd
}
func init() {
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
return &customProvider{}, nil
})
sdkaccess.RegisterProvider("custom", &customProvider{})
}
```
自定义提供者需要实现 `Identifier()``Authenticate()`。在 `init` 中调用 `RegisterProvider` 暴露给配置层,工厂函数既能读取当前条目,也能访问完整根配置
自定义提供者需要实现 `Identifier()``Authenticate()`。在 `init`用已初始化实例调用 `RegisterProvider` 注册到全局 registry
## 错误语义
- `ErrNoCredentials`:任何提供者都未识别到凭证。
- `ErrInvalidCredential`:至少一个提供者处理了凭证但判定无效。
- `ErrNotHandled`:告诉管理器跳到下一个提供者,不影响最终错误统计
- `NewNoCredentialsError()``AuthErrorCodeNoCredentials`未提供或未识别到凭证。HTTP 401
- `NewInvalidCredentialError()``AuthErrorCodeInvalidCredential`凭证存在但校验失败。HTTP 401
- `NewNotHandledError()``AuthErrorCodeNotHandled`:告诉管理器跳到下一个 provider
- `NewInternalAuthError(message, cause)``AuthErrorCodeInternal`):网络/系统错误。HTTP 500
自定义错误(例如网络异常)会马上冒泡返回。
除可汇总的 `not_handled` / `no_credentials` / `invalid_credential` 外,其它错误会立即冒泡返回。
## 与 cliproxy 集成
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果需要扩展内置行为,可传入自定义管理器:
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果希望在宿主进程里复用同一个 `Manager` 实例,可传入自定义管理器:
```go
coreCfg, _ := config.LoadConfig("config.yaml")
providers, _ := sdkaccess.BuildProviders(coreCfg)
manager := sdkaccess.NewManager()
manager.SetProviders(providers)
accessManager := sdkaccess.NewManager()
svc, _ := cliproxy.NewBuilder().
WithConfig(coreCfg).
WithAccessManager(manager).
WithConfigPath("config.yaml").
WithRequestAccessManager(accessManager).
Build()
```
服务会复用该管理器处理每一个入站请求,实现与 CLI 二进制一致的访问控制体验
请在调用 `Build()` 之前完成自定义 provider 的注册(通常通过空白导入触发 `init`),以确保它们被包含在全局 registry 的快照中
### 动态热更新提供者
当配置发生变化时,可以重新构建提供者并替换当前列表
当配置发生变化时,刷新依赖配置的 provider然后重置 manager 的 provider 链
```go
providers, err := sdkaccess.BuildProviders(newCfg)
if err != nil {
log.Errorf("reload auth providers failed: %v", err)
return
}
accessManager.SetProviders(providers)
// configaccess is github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access
configaccess.Register(&newCfg.SDKConfig)
accessManager.SetProviders(sdkaccess.RegisteredProviders())
```
这一流程与 `cliproxy.Service.refreshAccessProviders``api.Server.applyAccessConfig` 保持一致,避免为更新访问策略而重启进程。
这一流程与 `internal/access.ApplyAccessProviders` 保持一致,避免为更新访问策略而重启进程。

View File

@@ -159,13 +159,13 @@ func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request,
return clipexec.Response{}, errors.New("count tokens not implemented")
}
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) {
ch := make(chan clipexec.StreamChunk, 1)
go func() {
defer close(ch)
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
}()
return ch, nil
return &clipexec.StreamResult{Chunks: ch}, nil
}
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {

View File

@@ -58,7 +58,7 @@ func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, c
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
}
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (*clipexec.StreamResult, error) {
return nil, errors.New("echo executor: ExecuteStream not implemented")
}

23
go.mod
View File

@@ -1,9 +1,13 @@
module github.com/router-for-me/CLIProxyAPI/v6
go 1.24.0
go 1.26.0
require (
github.com/andybalholm/brotli v1.0.6
github.com/atotto/clipboard v0.1.4
github.com/charmbracelet/bubbles v1.0.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/fsnotify/fsnotify v1.9.0
github.com/fxamacker/cbor/v2 v2.9.0
github.com/gin-gonic/gin v1.10.1
@@ -33,8 +37,16 @@ require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/ProtonMail/go-crypto v1.3.0 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/ansi v0.11.6 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.9.0 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
github.com/cloudflare/circl v1.6.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
@@ -42,6 +54,7 @@ require (
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-git/gcfg/v2 v2.0.2 // indirect
@@ -58,19 +71,27 @@ require (
github.com/kevinburke/ssh_config v1.4.0 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/minio/sha256-simd v1.0.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pjbgf/sha1cd v0.5.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/sergi/go-diff v1.4.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/sys v0.38.0 // indirect

45
go.sum
View File

@@ -10,10 +10,34 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
@@ -33,6 +57,8 @@ github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
@@ -101,8 +127,14 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
@@ -114,6 +146,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
@@ -124,6 +162,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
@@ -161,6 +201,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
@@ -168,12 +210,15 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -4,19 +4,28 @@ import (
"context"
"net/http"
"strings"
"sync"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
var registerOnce sync.Once
// Register ensures the config-access provider is available to the access manager.
func Register() {
registerOnce.Do(func() {
sdkaccess.RegisterProvider(sdkconfig.AccessProviderTypeConfigAPIKey, newProvider)
})
func Register(cfg *sdkconfig.SDKConfig) {
if cfg == nil {
sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
return
}
keys := normalizeKeys(cfg.APIKeys)
if len(keys) == 0 {
sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
return
}
sdkaccess.RegisterProvider(
sdkaccess.AccessProviderTypeConfigAPIKey,
newProvider(sdkaccess.DefaultAccessProviderName, keys),
)
}
type provider struct {
@@ -24,34 +33,31 @@ type provider struct {
keys map[string]struct{}
}
func newProvider(cfg *sdkconfig.AccessProvider, _ *sdkconfig.SDKConfig) (sdkaccess.Provider, error) {
name := cfg.Name
if name == "" {
name = sdkconfig.DefaultAccessProviderName
func newProvider(name string, keys []string) *provider {
providerName := strings.TrimSpace(name)
if providerName == "" {
providerName = sdkaccess.DefaultAccessProviderName
}
keys := make(map[string]struct{}, len(cfg.APIKeys))
for _, key := range cfg.APIKeys {
if key == "" {
continue
}
keys[key] = struct{}{}
keySet := make(map[string]struct{}, len(keys))
for _, key := range keys {
keySet[key] = struct{}{}
}
return &provider{name: name, keys: keys}, nil
return &provider{name: providerName, keys: keySet}
}
func (p *provider) Identifier() string {
if p == nil || p.name == "" {
return sdkconfig.DefaultAccessProviderName
return sdkaccess.DefaultAccessProviderName
}
return p.name
}
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, error) {
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
if p == nil {
return nil, sdkaccess.ErrNotHandled
return nil, sdkaccess.NewNotHandledError()
}
if len(p.keys) == 0 {
return nil, sdkaccess.ErrNotHandled
return nil, sdkaccess.NewNotHandledError()
}
authHeader := r.Header.Get("Authorization")
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
@@ -63,7 +69,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
queryAuthToken = r.URL.Query().Get("auth_token")
}
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
return nil, sdkaccess.ErrNoCredentials
return nil, sdkaccess.NewNoCredentialsError()
}
apiKey := extractBearerToken(authHeader)
@@ -94,7 +100,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
}
}
return nil, sdkaccess.ErrInvalidCredential
return nil, sdkaccess.NewInvalidCredentialError()
}
func extractBearerToken(header string) string {
@@ -110,3 +116,26 @@ func extractBearerToken(header string) string {
}
return strings.TrimSpace(parts[1])
}
func normalizeKeys(keys []string) []string {
if len(keys) == 0 {
return nil
}
normalized := make([]string, 0, len(keys))
seen := make(map[string]struct{}, len(keys))
for _, key := range keys {
trimmedKey := strings.TrimSpace(key)
if trimmedKey == "" {
continue
}
if _, exists := seen[trimmedKey]; exists {
continue
}
seen[trimmedKey] = struct{}{}
normalized = append(normalized, trimmedKey)
}
if len(normalized) == 0 {
return nil
}
return normalized
}

View File

@@ -6,9 +6,9 @@ import (
"sort"
"strings"
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
sdkConfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
log "github.com/sirupsen/logrus"
)
@@ -17,26 +17,26 @@ import (
// ordered provider slice along with the identifiers of providers that were added, updated, or
// removed compared to the previous configuration.
func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Provider) (result []sdkaccess.Provider, added, updated, removed []string, err error) {
_ = oldCfg
if newCfg == nil {
return nil, nil, nil, nil, nil
}
result = sdkaccess.RegisteredProviders()
existingMap := make(map[string]sdkaccess.Provider, len(existing))
for _, provider := range existing {
if provider == nil {
providerID := identifierFromProvider(provider)
if providerID == "" {
continue
}
existingMap[provider.Identifier()] = provider
existingMap[providerID] = provider
}
oldCfgMap := accessProviderMap(oldCfg)
newEntries := collectProviderEntries(newCfg)
result = make([]sdkaccess.Provider, 0, len(newEntries))
finalIDs := make(map[string]struct{}, len(newEntries))
finalIDs := make(map[string]struct{}, len(result))
isInlineProvider := func(id string) bool {
return strings.EqualFold(id, sdkConfig.DefaultAccessProviderName)
return strings.EqualFold(id, sdkaccess.DefaultAccessProviderName)
}
appendChange := func(list *[]string, id string) {
if isInlineProvider(id) {
@@ -45,85 +45,28 @@ func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Prov
*list = append(*list, id)
}
for _, providerCfg := range newEntries {
key := providerIdentifier(providerCfg)
if key == "" {
for _, provider := range result {
providerID := identifierFromProvider(provider)
if providerID == "" {
continue
}
finalIDs[providerID] = struct{}{}
forceRebuild := strings.EqualFold(strings.TrimSpace(providerCfg.Type), sdkConfig.AccessProviderTypeConfigAPIKey)
if oldCfgProvider, ok := oldCfgMap[key]; ok {
isAliased := oldCfgProvider == providerCfg
if !forceRebuild && !isAliased && providerConfigEqual(oldCfgProvider, providerCfg) {
if existingProvider, okExisting := existingMap[key]; okExisting {
result = append(result, existingProvider)
finalIDs[key] = struct{}{}
continue
}
}
existingProvider, exists := existingMap[providerID]
if !exists {
appendChange(&added, providerID)
continue
}
provider, buildErr := sdkaccess.BuildProvider(providerCfg, &newCfg.SDKConfig)
if buildErr != nil {
return nil, nil, nil, nil, buildErr
}
if _, ok := oldCfgMap[key]; ok {
if _, existed := existingMap[key]; existed {
appendChange(&updated, key)
} else {
appendChange(&added, key)
}
} else {
appendChange(&added, key)
}
result = append(result, provider)
finalIDs[key] = struct{}{}
}
if len(result) == 0 {
if inline := sdkConfig.MakeInlineAPIKeyProvider(newCfg.APIKeys); inline != nil {
key := providerIdentifier(inline)
if key != "" {
if oldCfgProvider, ok := oldCfgMap[key]; ok {
if providerConfigEqual(oldCfgProvider, inline) {
if existingProvider, okExisting := existingMap[key]; okExisting {
result = append(result, existingProvider)
finalIDs[key] = struct{}{}
goto inlineDone
}
}
}
provider, buildErr := sdkaccess.BuildProvider(inline, &newCfg.SDKConfig)
if buildErr != nil {
return nil, nil, nil, nil, buildErr
}
if _, existed := existingMap[key]; existed {
appendChange(&updated, key)
} else if _, hadOld := oldCfgMap[key]; hadOld {
appendChange(&updated, key)
} else {
appendChange(&added, key)
}
result = append(result, provider)
finalIDs[key] = struct{}{}
}
}
inlineDone:
}
removedSet := make(map[string]struct{})
for id := range existingMap {
if _, ok := finalIDs[id]; !ok {
if isInlineProvider(id) {
continue
}
removedSet[id] = struct{}{}
if !providerInstanceEqual(existingProvider, provider) {
appendChange(&updated, providerID)
}
}
removed = make([]string, 0, len(removedSet))
for id := range removedSet {
removed = append(removed, id)
for providerID := range existingMap {
if _, exists := finalIDs[providerID]; exists {
continue
}
appendChange(&removed, providerID)
}
sort.Strings(added)
@@ -142,6 +85,7 @@ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Con
}
existing := manager.Providers()
configaccess.Register(&newCfg.SDKConfig)
providers, added, updated, removed, err := ReconcileProviders(oldCfg, newCfg, existing)
if err != nil {
log.Errorf("failed to reconcile request auth providers: %v", err)
@@ -160,111 +104,24 @@ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Con
return false, nil
}
func accessProviderMap(cfg *config.Config) map[string]*sdkConfig.AccessProvider {
result := make(map[string]*sdkConfig.AccessProvider)
if cfg == nil {
return result
}
for i := range cfg.Access.Providers {
providerCfg := &cfg.Access.Providers[i]
if providerCfg.Type == "" {
continue
}
key := providerIdentifier(providerCfg)
if key == "" {
continue
}
result[key] = providerCfg
}
if len(result) == 0 && len(cfg.APIKeys) > 0 {
if provider := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); provider != nil {
if key := providerIdentifier(provider); key != "" {
result[key] = provider
}
}
}
return result
}
func collectProviderEntries(cfg *config.Config) []*sdkConfig.AccessProvider {
entries := make([]*sdkConfig.AccessProvider, 0, len(cfg.Access.Providers))
for i := range cfg.Access.Providers {
providerCfg := &cfg.Access.Providers[i]
if providerCfg.Type == "" {
continue
}
if key := providerIdentifier(providerCfg); key != "" {
entries = append(entries, providerCfg)
}
}
if len(entries) == 0 && len(cfg.APIKeys) > 0 {
if inline := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); inline != nil {
entries = append(entries, inline)
}
}
return entries
}
func providerIdentifier(provider *sdkConfig.AccessProvider) string {
func identifierFromProvider(provider sdkaccess.Provider) string {
if provider == nil {
return ""
}
if name := strings.TrimSpace(provider.Name); name != "" {
return name
}
typ := strings.TrimSpace(provider.Type)
if typ == "" {
return ""
}
if strings.EqualFold(typ, sdkConfig.AccessProviderTypeConfigAPIKey) {
return sdkConfig.DefaultAccessProviderName
}
return typ
return strings.TrimSpace(provider.Identifier())
}
func providerConfigEqual(a, b *sdkConfig.AccessProvider) bool {
func providerInstanceEqual(a, b sdkaccess.Provider) bool {
if a == nil || b == nil {
return a == nil && b == nil
}
if !strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) {
if reflect.TypeOf(a) != reflect.TypeOf(b) {
return false
}
if strings.TrimSpace(a.SDK) != strings.TrimSpace(b.SDK) {
return false
valueA := reflect.ValueOf(a)
valueB := reflect.ValueOf(b)
if valueA.Kind() == reflect.Pointer && valueB.Kind() == reflect.Pointer {
return valueA.Pointer() == valueB.Pointer()
}
if !stringSetEqual(a.APIKeys, b.APIKeys) {
return false
}
if len(a.Config) != len(b.Config) {
return false
}
if len(a.Config) > 0 && !reflect.DeepEqual(a.Config, b.Config) {
return false
}
return true
}
func stringSetEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
if len(a) == 0 {
return true
}
seen := make(map[string]int, len(a))
for _, val := range a {
seen[val]++
}
for _, val := range b {
count := seen[val]
if count == 0 {
return false
}
if count == 1 {
delete(seen, val)
} else {
seen[val] = count - 1
}
}
return len(seen) == 0
return reflect.DeepEqual(a, b)
}

View File

@@ -1,6 +1,7 @@
package management
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -189,9 +190,21 @@ func (h *Handler) APICall(c *gin.Context) {
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
}
// When caller indicates CBOR in request headers, convert JSON string payload to CBOR bytes.
useCBORPayload := headerContainsValue(reqHeaders, "Content-Type", "application/cbor")
var requestBody io.Reader
if body.Data != "" {
requestBody = strings.NewReader(body.Data)
if useCBORPayload {
cborPayload, errEncode := encodeJSONStringToCBOR(body.Data)
if errEncode != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json data for cbor content-type"})
return
}
requestBody = bytes.NewReader(cborPayload)
} else {
requestBody = strings.NewReader(body.Data)
}
}
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
@@ -234,10 +247,18 @@ func (h *Handler) APICall(c *gin.Context) {
return
}
// For CBOR upstream responses, decode into plain text or JSON string before returning.
responseBodyText := string(respBody)
if headerContainsValue(reqHeaders, "Accept", "application/cbor") || strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "application/cbor") {
if decodedBody, errDecode := decodeCBORBodyToTextOrJSON(respBody); errDecode == nil {
responseBodyText = decodedBody
}
}
response := apiCallResponse{
StatusCode: resp.StatusCode,
Header: resp.Header,
Body: string(respBody),
Body: responseBodyText,
}
// If this is a GitHub Copilot token endpoint response, try to enrich with quota information
@@ -747,6 +768,83 @@ func buildProxyTransport(proxyStr string) *http.Transport {
return nil
}
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
func headerContainsValue(headers map[string]string, targetKey, targetValue string) bool {
if len(headers) == 0 {
return false
}
for key, value := range headers {
if !strings.EqualFold(strings.TrimSpace(key), strings.TrimSpace(targetKey)) {
continue
}
if strings.Contains(strings.ToLower(value), strings.ToLower(strings.TrimSpace(targetValue))) {
return true
}
}
return false
}
// encodeJSONStringToCBOR converts a JSON string payload into CBOR bytes.
func encodeJSONStringToCBOR(jsonString string) ([]byte, error) {
var payload any
if errUnmarshal := json.Unmarshal([]byte(jsonString), &payload); errUnmarshal != nil {
return nil, errUnmarshal
}
return cbor.Marshal(payload)
}
// decodeCBORBodyToTextOrJSON decodes CBOR bytes to plain text (for string payloads) or JSON string.
func decodeCBORBodyToTextOrJSON(raw []byte) (string, error) {
if len(raw) == 0 {
return "", nil
}
var payload any
if errUnmarshal := cbor.Unmarshal(raw, &payload); errUnmarshal != nil {
return "", errUnmarshal
}
jsonCompatible := cborValueToJSONCompatible(payload)
switch typed := jsonCompatible.(type) {
case string:
return typed, nil
case []byte:
return string(typed), nil
default:
jsonBytes, errMarshal := json.Marshal(jsonCompatible)
if errMarshal != nil {
return "", errMarshal
}
return string(jsonBytes), nil
}
}
// cborValueToJSONCompatible recursively converts CBOR-decoded values into JSON-marshalable values.
func cborValueToJSONCompatible(value any) any {
switch typed := value.(type) {
case map[any]any:
out := make(map[string]any, len(typed))
for key, item := range typed {
out[fmt.Sprint(key)] = cborValueToJSONCompatible(item)
}
return out
case map[string]any:
out := make(map[string]any, len(typed))
for key, item := range typed {
out[key] = cborValueToJSONCompatible(item)
}
return out
case []any:
out := make([]any, len(typed))
for i, item := range typed {
out[i] = cborValueToJSONCompatible(item)
}
return out
default:
return typed
}
}
// QuotaDetail represents quota information for a specific resource type
type QuotaDetail struct {
Entitlement float64 `json:"entitlement"`

View File

@@ -29,6 +29,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
@@ -411,6 +412,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
if !auth.LastRefreshedAt.IsZero() {
entry["last_refresh"] = auth.LastRefreshedAt
}
if !auth.NextRetryAfter.IsZero() {
entry["next_retry_after"] = auth.NextRetryAfter
}
if path != "" {
entry["path"] = path
entry["source"] = "file"
@@ -813,6 +817,87 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
}
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
return
}
var req struct {
Name string `json:"name"`
Prefix *string `json:"prefix"`
ProxyURL *string `json:"proxy_url"`
Priority *int `json:"priority"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
}
name := strings.TrimSpace(req.Name)
if name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
ctx := c.Request.Context()
// Find auth by name or ID
var targetAuth *coreauth.Auth
if auth, ok := h.authManager.GetByID(name); ok {
targetAuth = auth
} else {
auths := h.authManager.List()
for _, auth := range auths {
if auth.FileName == name {
targetAuth = auth
break
}
}
}
if targetAuth == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
return
}
changed := false
if req.Prefix != nil {
targetAuth.Prefix = *req.Prefix
changed = true
}
if req.ProxyURL != nil {
targetAuth.ProxyURL = *req.ProxyURL
changed = true
}
if req.Priority != nil {
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if *req.Priority == 0 {
delete(targetAuth.Metadata, "priority")
} else {
targetAuth.Metadata["priority"] = *req.Priority
}
changed = true
}
if !changed {
c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"})
return
}
targetAuth.UpdatedAt = time.Now()
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok"})
}
func (h *Handler) disableAuth(ctx context.Context, id string) {
if h == nil || h.authManager == nil {
return
@@ -869,11 +954,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
if store == nil {
return "", fmt.Errorf("token store unavailable")
}
if h.postAuthHook != nil {
if err := h.postAuthHook(ctx, record); err != nil {
return "", fmt.Errorf("post-auth hook failed: %w", err)
}
}
return store.Save(ctx, record)
}
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Claude authentication...")
@@ -1018,6 +1109,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
@@ -1193,6 +1285,30 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
}
ts.ProjectID = strings.Join(projects, ",")
ts.Checked = true
} else if strings.EqualFold(requestedProjectID, "GOOGLE_ONE") {
ts.Auto = false
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
log.Errorf("Google One auto-discovery failed: %v", errSetup)
SetOAuthSessionError(state, "Google One auto-discovery failed")
return
}
if strings.TrimSpace(ts.ProjectID) == "" {
log.Error("Google One auto-discovery returned empty project ID")
SetOAuthSessionError(state, "Google One auto-discovery returned empty project ID")
return
}
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
if errCheck != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
return
}
ts.Checked = isChecked
if !isChecked {
log.Error("Cloud AI API is not enabled for the auto-discovered project")
SetOAuthSessionError(state, "Cloud AI API not enabled")
return
}
} else {
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
@@ -1252,6 +1368,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
func (h *Handler) RequestCodexToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Codex authentication...")
@@ -1397,6 +1514,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Antigravity authentication...")
@@ -1561,6 +1679,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
func (h *Handler) RequestQwenToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Qwen authentication...")
@@ -1616,6 +1735,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
func (h *Handler) RequestKimiToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Kimi authentication...")
@@ -1692,6 +1812,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
func (h *Handler) RequestIFlowToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing iFlow authentication...")
@@ -1811,8 +1932,6 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
// Initialize Copilot auth service
// We need to import "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" first if not present
// Assuming copilot package is imported as "copilot"
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
// Initiate device flow
@@ -1826,7 +1945,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
authURL := deviceCode.VerificationURI
userCode := deviceCode.UserCode
RegisterOAuthSession(state, "github")
RegisterOAuthSession(state, "github-copilot")
go func() {
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
@@ -1838,9 +1957,13 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
return
}
username, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
userInfo, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
if errUser != nil {
log.Warnf("Failed to fetch user info: %v", errUser)
}
username := userInfo.Login
if username == "" {
username = "github-user"
}
@@ -1849,18 +1972,26 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
TokenType: tokenData.TokenType,
Scope: tokenData.Scope,
Username: username,
Email: userInfo.Email,
Name: userInfo.Name,
Type: "github-copilot",
}
fileName := fmt.Sprintf("github-%s.json", username)
fileName := fmt.Sprintf("github-copilot-%s.json", username)
label := userInfo.Email
if label == "" {
label = username
}
record := &coreauth.Auth{
ID: fileName,
Provider: "github",
Provider: "github-copilot",
Label: label,
FileName: fileName,
Storage: tokenStorage,
Metadata: map[string]any{
"email": username,
"email": userInfo.Email,
"username": username,
"name": userInfo.Name,
},
}
@@ -1874,7 +2005,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
fmt.Println("You can now use GitHub Copilot services through this CLI")
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("github")
CompleteOAuthSessionsByProvider("github-copilot")
}()
c.JSON(200, gin.H{
@@ -2124,7 +2255,48 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
}
}
if projectID == "" {
return &projectSelectionRequiredError{}
// Auto-discovery: try onboardUser without specifying a project
// to let Google auto-provision one (matches Gemini CLI headless behavior
// and Antigravity's FetchProjectID pattern).
autoOnboardReq := map[string]any{
"tierId": tierID,
"metadata": metadata,
}
autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
defer autoCancel()
for attempt := 1; ; attempt++ {
var onboardResp map[string]any
if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
}
if done, okDone := onboardResp["done"].(bool); okDone && done {
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
switch v := resp["cloudaicompanionProject"].(type) {
case string:
projectID = strings.TrimSpace(v)
case map[string]any:
if id, okID := v["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
}
}
break
}
log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
select {
case <-autoCtx.Done():
return &projectSelectionRequiredError{}
case <-time.After(2 * time.Second):
}
}
if projectID == "" {
return &projectSelectionRequiredError{}
}
log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
}
onboardReqBody := map[string]any{
@@ -2374,6 +2546,14 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "wait"})
}
// PopulateAuthContext extracts request info and adds it to the context
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
info := &coreauth.RequestInfo{
Query: c.Request.URL.Query(),
Headers: c.Request.Header,
}
return coreauth.WithRequestInfo(ctx, info)
}
const kiroCallbackPort = 9876
func (h *Handler) RequestKiroToken(c *gin.Context) {
@@ -2668,3 +2848,88 @@ func generateKiroPKCE() (verifier, challenge string, err error) {
return verifier, challenge, nil
}
func (h *Handler) RequestKiloToken(c *gin.Context) {
ctx := context.Background()
fmt.Println("Initializing Kilo authentication...")
state := fmt.Sprintf("kil-%d", time.Now().UnixNano())
kilocodeAuth := kilo.NewKiloAuth()
resp, err := kilocodeAuth.InitiateDeviceFlow(ctx)
if err != nil {
log.Errorf("Failed to initiate device flow: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate device flow"})
return
}
RegisterOAuthSession(state, "kilo")
go func() {
fmt.Printf("Please visit %s and enter code: %s\n", resp.VerificationURL, resp.Code)
status, err := kilocodeAuth.PollForToken(ctx, resp.Code)
if err != nil {
SetOAuthSessionError(state, "Authentication failed")
fmt.Printf("Authentication failed: %v\n", err)
return
}
profile, err := kilocodeAuth.GetProfile(ctx, status.Token)
if err != nil {
log.Warnf("Failed to fetch profile: %v", err)
profile = &kilo.Profile{Email: status.UserEmail}
}
var orgID string
if len(profile.Orgs) > 0 {
orgID = profile.Orgs[0].ID
}
defaults, err := kilocodeAuth.GetDefaults(ctx, status.Token, orgID)
if err != nil {
defaults = &kilo.Defaults{}
}
ts := &kilo.KiloTokenStorage{
Token: status.Token,
OrganizationID: orgID,
Model: defaults.Model,
Email: status.UserEmail,
Type: "kilo",
}
fileName := kilo.CredentialFileName(status.UserEmail)
record := &coreauth.Auth{
ID: fileName,
Provider: "kilo",
FileName: fileName,
Storage: ts,
Metadata: map[string]any{
"email": status.UserEmail,
"organization_id": orgID,
"model": defaults.Model,
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("kilo")
}()
c.JSON(200, gin.H{
"status": "ok",
"url": resp.VerificationURL,
"state": state,
"user_code": resp.Code,
"verification_uri": resp.VerificationURL,
})
}

View File

@@ -28,8 +28,7 @@ func (h *Handler) GetConfig(c *gin.Context) {
c.JSON(200, gin.H{})
return
}
cfgCopy := *h.cfg
c.JSON(200, &cfgCopy)
c.JSON(200, new(*h.cfg))
}
type releaseInfo struct {

View File

@@ -109,14 +109,13 @@ func (h *Handler) GetAPIKeys(c *gin.Context) { c.JSON(200, gin.H{"api-keys": h.c
func (h *Handler) PutAPIKeys(c *gin.Context) {
h.putStringList(c, func(v []string) {
h.cfg.APIKeys = append([]string(nil), v...)
h.cfg.Access.Providers = nil
}, nil)
}
func (h *Handler) PatchAPIKeys(c *gin.Context) {
h.patchStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
h.patchStringList(c, &h.cfg.APIKeys, func() {})
}
func (h *Handler) DeleteAPIKeys(c *gin.Context) {
h.deleteFromStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
h.deleteFromStringList(c, &h.cfg.APIKeys, func() {})
}
// gemini-api-key: []GeminiKey
@@ -797,10 +796,10 @@ func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
delete(h.cfg.OAuthModelAlias, channel)
if len(h.cfg.OAuthModelAlias) == 0 {
h.cfg.OAuthModelAlias = nil
}
// Set to nil instead of deleting the key so that the "explicitly disabled"
// marker survives config reload and prevents SanitizeOAuthModelAlias from
// re-injecting default aliases (fixes #222).
h.cfg.OAuthModelAlias[channel] = nil
h.persist(c)
}

View File

@@ -47,6 +47,7 @@ type Handler struct {
allowRemoteOverride bool
envSecret string
logDir string
postAuthHook coreauth.PostAuthHook
}
// NewHandler creates a new management handler instance.
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
h.logDir = dir
}
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
h.postAuthHook = hook
}
// Middleware enforces access control for management endpoints.
// All requests (local and remote) require a valid management key.
// Additionally, remote access requires allow-remote-management=true.

View File

@@ -15,10 +15,12 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
// It captures detailed information about the request and response, including headers and body,
// and uses the provided RequestLogger to record this data. When logging is disabled in the
// logger, it still captures data so that upstream errors can be persisted.
// and uses the provided RequestLogger to record this data. When full request logging is disabled,
// body capture is limited to small known-size payloads to avoid large per-request memory spikes.
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return func(c *gin.Context) {
if logger == nil {
@@ -26,7 +28,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return
}
if c.Request.Method == http.MethodGet {
if shouldSkipMethodForRequestLogging(c.Request) {
c.Next()
return
}
@@ -37,8 +39,10 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return
}
loggerEnabled := logger.IsEnabled()
// Capture request information
requestInfo, err := captureRequestInfo(c)
requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
if err != nil {
// Log error but continue processing
// In a real implementation, you might want to use a proper logger here
@@ -48,7 +52,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
// Create response writer wrapper
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
if !logger.IsEnabled() {
if !loggerEnabled {
wrapper.logOnErrorOnly = true
}
c.Writer = wrapper
@@ -64,10 +68,47 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
}
}
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
if req == nil {
return true
}
if req.Method != http.MethodGet {
return false
}
return !isResponsesWebsocketUpgrade(req)
}
func isResponsesWebsocketUpgrade(req *http.Request) bool {
if req == nil || req.URL == nil {
return false
}
if req.URL.Path != "/v1/responses" {
return false
}
return strings.EqualFold(strings.TrimSpace(req.Header.Get("Upgrade")), "websocket")
}
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
if loggerEnabled {
return true
}
if req == nil || req.Body == nil {
return false
}
contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
if strings.HasPrefix(contentType, "multipart/form-data") {
return false
}
if req.ContentLength <= 0 {
return false
}
return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes
}
// captureRequestInfo extracts relevant information from the incoming HTTP request.
// It captures the URL, method, headers, and body. The request body is read and then
// restored so that it can be processed by subsequent handlers.
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
// Capture URL with sensitive query parameters masked
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
url := c.Request.URL.Path
@@ -86,7 +127,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
// Capture request body
var body []byte
if c.Request.Body != nil {
if captureBody && c.Request.Body != nil {
// Read the body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {

View File

@@ -0,0 +1,138 @@
package middleware
import (
"io"
"net/http"
"net/url"
"strings"
"testing"
)
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
tests := []struct {
name string
req *http.Request
skip bool
}{
{
name: "nil request",
req: nil,
skip: true,
},
{
name: "post request should not skip",
req: &http.Request{
Method: http.MethodPost,
URL: &url.URL{Path: "/v1/responses"},
},
skip: false,
},
{
name: "plain get should skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/models"},
Header: http.Header{},
},
skip: true,
},
{
name: "responses websocket upgrade should not skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/responses"},
Header: http.Header{"Upgrade": []string{"websocket"}},
},
skip: false,
},
{
name: "responses get without upgrade should skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/responses"},
Header: http.Header{},
},
skip: true,
},
}
for i := range tests {
got := shouldSkipMethodForRequestLogging(tests[i].req)
if got != tests[i].skip {
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
}
}
}
func TestShouldCaptureRequestBody(t *testing.T) {
tests := []struct {
name string
loggerEnabled bool
req *http.Request
want bool
}{
{
name: "logger enabled always captures",
loggerEnabled: true,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("{}")),
ContentLength: -1,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: true,
},
{
name: "nil request",
loggerEnabled: false,
req: nil,
want: false,
},
{
name: "small known size json in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("{}")),
ContentLength: 2,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: true,
},
{
name: "large known size skipped in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("x")),
ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: false,
},
{
name: "unknown size skipped in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("x")),
ContentLength: -1,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: false,
},
{
name: "multipart skipped in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("x")),
ContentLength: 1,
Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
},
want: false,
},
}
for i := range tests {
got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req)
if got != tests[i].want {
t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want)
}
}
}

View File

@@ -14,6 +14,8 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
)
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct {
URL string // URL is the request URL.
@@ -223,8 +225,8 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
// Only fall back to request payload hints when Content-Type is not set yet.
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
bodyStr := string(w.requestInfo.Body)
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
return bytes.Contains(w.requestInfo.Body, []byte(`"stream": true`)) ||
bytes.Contains(w.requestInfo.Body, []byte(`"stream":true`))
}
return false
@@ -310,7 +312,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
return nil
}
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
}
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
@@ -361,16 +363,32 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
return time.Time{}
}
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
if c != nil {
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
switch value := bodyOverride.(type) {
case []byte:
if len(value) > 0 {
return bytes.Clone(value)
}
case string:
if strings.TrimSpace(value) != "" {
return []byte(value)
}
}
}
}
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
return w.requestInfo.Body
}
return nil
}
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
if w.requestInfo == nil {
return nil
}
var requestBody []byte
if len(w.requestInfo.Body) > 0 {
requestBody = w.requestInfo.Body
}
if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
}); ok {

View File

@@ -0,0 +1,43 @@
package middleware
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{
requestInfo: &RequestInfo{Body: []byte("original-body")},
}
body := wrapper.extractRequestBody(c)
if string(body) != "original-body" {
t.Fatalf("request body = %q, want %q", string(body), "original-body")
}
c.Set(requestBodyOverrideContextKey, []byte("override-body"))
body = wrapper.extractRequestBody(c)
if string(body) != "override-body" {
t.Fatalf("request body = %q, want %q", string(body), "override-body")
}
}
func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
c.Set(requestBodyOverrideContextKey, "override-as-string")
body := wrapper.extractRequestBody(c)
if string(body) != "override-as-string" {
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
}
}

View File

@@ -127,8 +127,7 @@ func (m *AmpModule) Register(ctx modules.Context) error {
m.modelMapper = NewModelMapper(settings.ModelMappings)
// Store initial config for partial reload comparison
settingsCopy := settings
m.lastConfig = &settingsCopy
m.lastConfig = new(settings)
// Initialize localhost restriction setting (hot-reloadable)
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)

View File

@@ -215,7 +215,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
// Don't log as error for context canceled - it's usually client closing connection
if errors.Is(err, context.Canceled) {
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
return
} else {
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
}

View File

@@ -493,6 +493,30 @@ func TestReverseProxy_ErrorHandler(t *testing.T) {
}
}
func TestReverseProxy_ErrorHandler_ContextCanceled(t *testing.T) {
// Test that context.Canceled errors return 499 without generic error response
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource(""))
if err != nil {
t.Fatal(err)
}
// Create a canceled context to trigger the cancellation path
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
req := httptest.NewRequest(http.MethodGet, "/test", nil).WithContext(ctx)
rr := httptest.NewRecorder()
// Directly invoke the ErrorHandler with context.Canceled
proxy.ErrorHandler(rr, req, context.Canceled)
// Body should be empty for canceled requests (no JSON error response)
body := rr.Body.Bytes()
if len(body) > 0 {
t.Fatalf("expected empty body for canceled context, got: %s", body)
}
}
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
// Upstream returns gzipped JSON without Content-Encoding header
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@@ -122,7 +122,7 @@ func (rw *ResponseRewriter) Flush() {
}
// modelFieldPaths lists all JSON paths where model name may appear
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility

View File

@@ -0,0 +1,110 @@
package amp
import (
"testing"
)
func TestRewriteModelInResponse_TopLevel(t *testing.T) {
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
input := []byte(`{"id":"resp_1","model":"gpt-5.3-codex","output":[]}`)
result := rw.rewriteModelInResponse(input)
expected := `{"id":"resp_1","model":"gpt-5.2-codex","output":[]}`
if string(result) != expected {
t.Errorf("expected %s, got %s", expected, string(result))
}
}
func TestRewriteModelInResponse_ResponseModel(t *testing.T) {
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
input := []byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"completed"}}`)
result := rw.rewriteModelInResponse(input)
expected := `{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"completed"}}`
if string(result) != expected {
t.Errorf("expected %s, got %s", expected, string(result))
}
}
func TestRewriteModelInResponse_ResponseCreated(t *testing.T) {
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
input := []byte(`{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"in_progress"}}`)
result := rw.rewriteModelInResponse(input)
expected := `{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"in_progress"}}`
if string(result) != expected {
t.Errorf("expected %s, got %s", expected, string(result))
}
}
func TestRewriteModelInResponse_NoModelField(t *testing.T) {
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
input := []byte(`{"type":"response.output_item.added","item":{"id":"item_1","type":"message"}}`)
result := rw.rewriteModelInResponse(input)
if string(result) != string(input) {
t.Errorf("expected no modification, got %s", string(result))
}
}
func TestRewriteModelInResponse_EmptyOriginalModel(t *testing.T) {
rw := &ResponseRewriter{originalModel: ""}
input := []byte(`{"model":"gpt-5.3-codex"}`)
result := rw.rewriteModelInResponse(input)
if string(result) != string(input) {
t.Errorf("expected no modification when originalModel is empty, got %s", string(result))
}
}
func TestRewriteStreamChunk_SSEWithResponseModel(t *testing.T) {
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
chunk := []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.3-codex\",\"status\":\"completed\"}}\n\n")
result := rw.rewriteStreamChunk(chunk)
expected := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.2-codex\",\"status\":\"completed\"}}\n\n"
if string(result) != expected {
t.Errorf("expected %s, got %s", expected, string(result))
}
}
func TestRewriteStreamChunk_MultipleEvents(t *testing.T) {
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
chunk := []byte("data: {\"type\":\"response.created\",\"response\":{\"model\":\"gpt-5.3-codex\"}}\n\ndata: {\"type\":\"response.output_item.added\",\"item\":{\"id\":\"item_1\"}}\n\n")
result := rw.rewriteStreamChunk(chunk)
if string(result) == string(chunk) {
t.Error("expected response.model to be rewritten in SSE stream")
}
if !contains(result, []byte(`"model":"gpt-5.2-codex"`)) {
t.Errorf("expected rewritten model in output, got %s", string(result))
}
}
func TestRewriteStreamChunk_MessageModel(t *testing.T) {
rw := &ResponseRewriter{originalModel: "claude-opus-4.5"}
chunk := []byte("data: {\"message\":{\"model\":\"claude-sonnet-4\",\"role\":\"assistant\"}}\n\n")
result := rw.rewriteStreamChunk(chunk)
expected := "data: {\"message\":{\"model\":\"claude-opus-4.5\",\"role\":\"assistant\"}}\n\n"
if string(result) != expected {
t.Errorf("expected %s, got %s", expected, string(result))
}
}
func contains(data, substr []byte) bool {
for i := 0; i <= len(data)-len(substr); i++ {
if string(data[i:i+len(substr)]) == string(substr) {
return true
}
}
return false
}

View File

@@ -52,6 +52,7 @@ type serverOptionConfig struct {
keepAliveEnabled bool
keepAliveTimeout time.Duration
keepAliveOnTimeout func()
postAuthHook auth.PostAuthHook
}
// ServerOption customises HTTP server construction.
@@ -112,6 +113,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
}
}
// WithPostAuthHook registers a hook to be called after auth record creation.
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
return func(cfg *serverOptionConfig) {
cfg.postAuthHook = hook
}
}
// Server represents the main API server.
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
type Server struct {
@@ -263,6 +271,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
}
logDir := logging.ResolveLogDirectory(cfg)
s.mgmt.SetLogDirectory(logDir)
if optionState.postAuthHook != nil {
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
}
s.localPassword = optionState.localPassword
// Setup routes
@@ -285,8 +296,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
optionState.routerConfigurator(engine, s.handlers, cfg)
}
// Register management routes when configuration or environment secrets are available.
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret
// Register management routes when configuration or environment secrets are available,
// or when a local management password is provided (e.g. TUI mode).
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
s.managementRoutesEnabled.Store(hasManagementSecret)
if hasManagementSecret {
s.registerManagementRoutes()
@@ -329,6 +341,7 @@ func (s *Server) setupRoutes() {
v1.POST("/completions", openaiHandlers.Completions)
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
v1.POST("/responses", openaiResponsesHandlers.Responses)
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
}
@@ -642,6 +655,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
mgmt.PATCH("/auth-files/fields", s.mgmt.PatchAuthFileFields)
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
@@ -649,6 +663,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken)
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
@@ -683,14 +698,17 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
if _, err := os.Stat(filePath); err != nil {
if os.IsNotExist(err) {
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
c.AbortWithStatus(http.StatusNotFound)
// Synchronously ensure management.html is available with a detached context.
// Control panel bootstrap should not be canceled by client disconnects.
if !managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository) {
c.AbortWithStatus(http.StatusNotFound)
return
}
} else {
log.WithError(err).Error("failed to stat management control panel asset")
c.AbortWithStatus(http.StatusInternalServerError)
return
}
log.WithError(err).Error("failed to stat management control panel asset")
c.AbortWithStatus(http.StatusInternalServerError)
return
}
c.File(filePath)
@@ -980,10 +998,6 @@ func (s *Server) UpdateClients(cfg *config.Config) {
s.handlers.UpdateClients(&cfg.SDKConfig)
if !cfg.RemoteManagement.DisableControlPanel {
staticDir := managementasset.StaticDir(s.configFilePath)
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
}
if s.mgmt != nil {
s.mgmt.SetConfig(cfg)
s.mgmt.SetAuthManager(s.handlers.AuthManager)
@@ -1062,14 +1076,10 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
return
}
switch {
case errors.Is(err, sdkaccess.ErrNoCredentials):
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing API key"})
case errors.Is(err, sdkaccess.ErrInvalidCredential):
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
default:
statusCode := err.HTTPStatusCode()
if statusCode >= http.StatusInternalServerError {
log.Errorf("authentication middleware error: %v", err)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authentication service error"})
}
c.AbortWithStatusJSON(statusCode, gin.H{"error": err.Message})
}
}

View File

@@ -20,7 +20,7 @@ import (
// OAuth configuration constants for Claude/Anthropic
const (
AuthURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
TokenURL = "https://api.anthropic.com/v1/oauth/token"
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
RedirectURI = "http://localhost:54545/callback"
)

View File

@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Claude token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
// Encode and write the token data as JSON
if err = json.NewEncoder(f).Encode(ts); err != nil {
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
// authorization code and PKCE verifier.
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
}
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
// a caller-provided redirect URI. This supports alternate auth flows such as device
// login while preserving the existing token parsing and storage behavior.
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
if pkceCodes == nil {
return nil, fmt.Errorf("PKCE codes are required for token exchange")
}
if strings.TrimSpace(redirectURI) == "" {
return nil, fmt.Errorf("redirect URI is required for token exchange")
}
// Prepare token exchange request
data := url.Values{
"grant_type": {"authorization_code"},
"client_id": {ClientID},
"code": {code},
"redirect_uri": {RedirectURI},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"code_verifier": {pkceCodes.CodeVerifier},
}
@@ -266,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
if err == nil {
return tokenData, nil
}
if isNonRetryableRefreshErr(err) {
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
return nil, err
}
lastErr = err
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
@@ -274,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
}
func isNonRetryableRefreshErr(err error) bool {
if err == nil {
return false
}
raw := strings.ToLower(err.Error())
return strings.Contains(raw, "refresh_token_reused")
}
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
// This is typically called after a successful token refresh to persist the new credentials.
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {

View File

@@ -0,0 +1,44 @@
package codex
import (
"context"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
var calls int32
auth := &CodexAuth{
httpClient: &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
atomic.AddInt32(&calls, 1)
return &http.Response{
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
Header: make(http.Header),
Request: req,
}, nil
}),
},
}
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
if err == nil {
t.Fatalf("expected error for non-retryable refresh failure")
}
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected 1 refresh attempt, got %d", got)
}
}

View File

@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
Type string `json:"type"`
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Codex token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -82,15 +82,21 @@ func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *Devi
}
// Fetch the GitHub username
username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
userInfo, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
if err != nil {
log.Warnf("copilot: failed to fetch user info: %v", err)
username = "unknown"
}
username := userInfo.Login
if username == "" {
username = "github-user"
}
return &CopilotAuthBundle{
TokenData: tokenData,
Username: username,
Email: userInfo.Email,
Name: userInfo.Name,
}, nil
}
@@ -150,12 +156,12 @@ func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bo
return false, "", nil
}
username, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
userInfo, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
if err != nil {
return false, "", err
}
return true, username, nil
return true, userInfo.Login, nil
}
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
@@ -165,6 +171,8 @@ func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotToke
TokenType: bundle.TokenData.TokenType,
Scope: bundle.TokenData.Scope,
Username: bundle.Username,
Email: bundle.Email,
Name: bundle.Name,
Type: "github-copilot",
}
}

View File

@@ -53,7 +53,7 @@ func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
data := url.Values{}
data.Set("client_id", copilotClientID)
data.Set("scope", "user:email")
data.Set("scope", "read:user user:email")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
if err != nil {
@@ -211,15 +211,25 @@ func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode st
}, nil
}
// FetchUserInfo retrieves the GitHub username for the authenticated user.
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
// GitHubUserInfo holds GitHub user profile information.
type GitHubUserInfo struct {
// Login is the GitHub username.
Login string
// Email is the primary email address (may be empty if not public).
Email string
// Name is the display name.
Name string
}
// FetchUserInfo retrieves the GitHub user profile for the authenticated user.
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (GitHubUserInfo, error) {
if accessToken == "" {
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
if err != nil {
return "", NewAuthenticationError(ErrUserInfoFailed, err)
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
@@ -227,7 +237,7 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
resp, err := c.httpClient.Do(req)
if err != nil {
return "", NewAuthenticationError(ErrUserInfoFailed, err)
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
@@ -237,19 +247,25 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
if !isHTTPSuccess(resp.StatusCode) {
bodyBytes, _ := io.ReadAll(resp.Body)
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
}
var userInfo struct {
var raw struct {
Login string `json:"login"`
Email string `json:"email"`
Name string `json:"name"`
}
if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
return "", NewAuthenticationError(ErrUserInfoFailed, err)
if err = json.NewDecoder(resp.Body).Decode(&raw); err != nil {
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
if userInfo.Login == "" {
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
if raw.Login == "" {
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
}
return userInfo.Login, nil
return GitHubUserInfo{
Login: raw.Login,
Email: raw.Email,
Name: raw.Name,
}, nil
}

View File

@@ -0,0 +1,213 @@
package copilot
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// roundTripFunc lets us inject a custom transport for testing.
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
// newTestClient returns an *http.Client whose requests are redirected to the given test server,
// regardless of the original URL host.
func newTestClient(srv *httptest.Server) *http.Client {
return &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
req2 := req.Clone(req.Context())
req2.URL.Scheme = "http"
req2.URL.Host = strings.TrimPrefix(srv.URL, "http://")
return srv.Client().Transport.RoundTrip(req2)
}),
}
}
// TestFetchUserInfo_FullProfile verifies that FetchUserInfo returns login, email, and name.
func TestFetchUserInfo_FullProfile(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"login": "octocat",
"email": "octocat@github.com",
"name": "The Octocat",
})
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
info, err := client.FetchUserInfo(context.Background(), "test-token")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.Login != "octocat" {
t.Errorf("Login: got %q, want %q", info.Login, "octocat")
}
if info.Email != "octocat@github.com" {
t.Errorf("Email: got %q, want %q", info.Email, "octocat@github.com")
}
if info.Name != "The Octocat" {
t.Errorf("Name: got %q, want %q", info.Name, "The Octocat")
}
}
// TestFetchUserInfo_EmptyEmail verifies graceful handling when email is absent (private account).
func TestFetchUserInfo_EmptyEmail(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
// GitHub returns null for private emails.
_, _ = w.Write([]byte(`{"login":"privateuser","email":null,"name":"Private User"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
info, err := client.FetchUserInfo(context.Background(), "test-token")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.Login != "privateuser" {
t.Errorf("Login: got %q, want %q", info.Login, "privateuser")
}
if info.Email != "" {
t.Errorf("Email: got %q, want empty string", info.Email)
}
if info.Name != "Private User" {
t.Errorf("Name: got %q, want %q", info.Name, "Private User")
}
}
// TestFetchUserInfo_EmptyToken verifies error is returned for empty access token.
func TestFetchUserInfo_EmptyToken(t *testing.T) {
client := &DeviceFlowClient{httpClient: http.DefaultClient}
_, err := client.FetchUserInfo(context.Background(), "")
if err == nil {
t.Fatal("expected error for empty token, got nil")
}
}
// TestFetchUserInfo_EmptyLogin verifies error is returned when API returns no login.
func TestFetchUserInfo_EmptyLogin(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"email":"someone@example.com","name":"No Login"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
_, err := client.FetchUserInfo(context.Background(), "test-token")
if err == nil {
t.Fatal("expected error for empty login, got nil")
}
}
// TestFetchUserInfo_HTTPError verifies error is returned on non-2xx response.
func TestFetchUserInfo_HTTPError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
_, err := client.FetchUserInfo(context.Background(), "bad-token")
if err == nil {
t.Fatal("expected error for 401 response, got nil")
}
}
// TestCopilotTokenStorage_EmailNameFields verifies Email and Name serialise correctly.
func TestCopilotTokenStorage_EmailNameFields(t *testing.T) {
ts := &CopilotTokenStorage{
AccessToken: "ghu_abc",
TokenType: "bearer",
Scope: "read:user user:email",
Username: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
Type: "github-copilot",
}
data, err := json.Marshal(ts)
if err != nil {
t.Fatalf("marshal error: %v", err)
}
var out map[string]any
if err = json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
for _, key := range []string{"access_token", "username", "email", "name", "type"} {
if _, ok := out[key]; !ok {
t.Errorf("expected key %q in JSON output, not found", key)
}
}
if out["email"] != "octocat@github.com" {
t.Errorf("email: got %v, want %q", out["email"], "octocat@github.com")
}
if out["name"] != "The Octocat" {
t.Errorf("name: got %v, want %q", out["name"], "The Octocat")
}
}
// TestCopilotTokenStorage_OmitEmptyEmailName verifies email/name are omitted when empty (omitempty).
func TestCopilotTokenStorage_OmitEmptyEmailName(t *testing.T) {
ts := &CopilotTokenStorage{
AccessToken: "ghu_abc",
Username: "octocat",
Type: "github-copilot",
}
data, err := json.Marshal(ts)
if err != nil {
t.Fatalf("marshal error: %v", err)
}
var out map[string]any
if err = json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if _, ok := out["email"]; ok {
t.Error("email key should be omitted when empty (omitempty), but was present")
}
if _, ok := out["name"]; ok {
t.Error("name key should be omitted when empty (omitempty), but was present")
}
}
// TestCopilotAuthBundle_EmailNameFields verifies bundle carries email and name through the pipeline.
func TestCopilotAuthBundle_EmailNameFields(t *testing.T) {
bundle := &CopilotAuthBundle{
TokenData: &CopilotTokenData{AccessToken: "ghu_abc"},
Username: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
}
if bundle.Email != "octocat@github.com" {
t.Errorf("bundle.Email: got %q, want %q", bundle.Email, "octocat@github.com")
}
if bundle.Name != "The Octocat" {
t.Errorf("bundle.Name: got %q, want %q", bundle.Name, "The Octocat")
}
}
// TestGitHubUserInfo_Struct verifies the exported GitHubUserInfo struct fields are accessible.
func TestGitHubUserInfo_Struct(t *testing.T) {
info := GitHubUserInfo{
Login: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
}
if info.Login == "" || info.Email == "" || info.Name == "" {
t.Error("GitHubUserInfo fields should not be empty")
}
}

View File

@@ -26,6 +26,10 @@ type CopilotTokenStorage struct {
ExpiresAt string `json:"expires_at,omitempty"`
// Username is the GitHub username associated with this token.
Username string `json:"username"`
// Email is the GitHub email address associated with this token.
Email string `json:"email,omitempty"`
// Name is the GitHub display name associated with this token.
Name string `json:"name,omitempty"`
// Type indicates the authentication provider type, always "github-copilot" for this storage.
Type string `json:"type"`
}
@@ -46,6 +50,10 @@ type CopilotAuthBundle struct {
TokenData *CopilotTokenData
// Username is the GitHub username.
Username string
// Email is the GitHub email address.
Email string
// Name is the GitHub display name.
Name string
}
// DeviceCodeResponse represents GitHub's device code response.

View File

@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
// Type indicates the authentication provider type, always "gemini" for this storage.
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
ts.Type = "gemini"
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
return fmt.Errorf("failed to create directory: %v", err)
}
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
}
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
enc := json.NewEncoder(f)
enc.SetIndent("", " ")
if err := enc.Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
Scope string `json:"scope"`
Cookie string `json:"cookie"`
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serialises the token storage to disk.
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
}
defer func() { _ = f.Close() }()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("iflow token: encode token failed: %w", err)
}
return nil

View File

@@ -0,0 +1,168 @@
// Package kilo provides authentication and token management functionality
// for Kilo AI services.
package kilo
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
)
const (
// BaseURL is the base URL for the Kilo AI API.
BaseURL = "https://api.kilo.ai/api"
)
// DeviceAuthResponse represents the response from initiating device flow.
type DeviceAuthResponse struct {
Code string `json:"code"`
VerificationURL string `json:"verificationUrl"`
ExpiresIn int `json:"expiresIn"`
}
// DeviceStatusResponse represents the response when polling for device flow status.
type DeviceStatusResponse struct {
Status string `json:"status"`
Token string `json:"token"`
UserEmail string `json:"userEmail"`
}
// Profile represents the user profile from Kilo AI.
type Profile struct {
Email string `json:"email"`
Orgs []Organization `json:"organizations"`
}
// Organization represents a Kilo AI organization.
type Organization struct {
ID string `json:"id"`
Name string `json:"name"`
}
// Defaults represents default settings for an organization or user.
type Defaults struct {
Model string `json:"model"`
}
// KiloAuth provides methods for handling the Kilo AI authentication flow.
type KiloAuth struct {
client *http.Client
}
// NewKiloAuth creates a new instance of KiloAuth.
func NewKiloAuth() *KiloAuth {
return &KiloAuth{
client: &http.Client{Timeout: 30 * time.Second},
}
}
// InitiateDeviceFlow starts the device authentication flow.
func (k *KiloAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceAuthResponse, error) {
resp, err := k.client.Post(BaseURL+"/device-auth/codes", "application/json", nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to initiate device flow: status %d", resp.StatusCode)
}
var data DeviceAuthResponse
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return nil, err
}
return &data, nil
}
// PollForToken polls for the device flow completion.
func (k *KiloAuth) PollForToken(ctx context.Context, code string) (*DeviceStatusResponse, error) {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
resp, err := k.client.Get(BaseURL + "/device-auth/codes/" + code)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var data DeviceStatusResponse
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return nil, err
}
switch data.Status {
case "approved":
return &data, nil
case "denied", "expired":
return nil, fmt.Errorf("device flow %s", data.Status)
case "pending":
continue
default:
return nil, fmt.Errorf("unknown status: %s", data.Status)
}
}
}
}
// GetProfile fetches the user's profile.
func (k *KiloAuth) GetProfile(ctx context.Context, token string) (*Profile, error) {
req, err := http.NewRequestWithContext(ctx, "GET", BaseURL+"/profile", nil)
if err != nil {
return nil, fmt.Errorf("failed to create get profile request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := k.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get profile: status %d", resp.StatusCode)
}
var profile Profile
if err := json.NewDecoder(resp.Body).Decode(&profile); err != nil {
return nil, err
}
return &profile, nil
}
// GetDefaults fetches default settings for an organization.
func (k *KiloAuth) GetDefaults(ctx context.Context, token, orgID string) (*Defaults, error) {
url := BaseURL + "/defaults"
if orgID != "" {
url = BaseURL + "/organizations/" + orgID + "/defaults"
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create get defaults request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := k.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get defaults: status %d", resp.StatusCode)
}
var defaults Defaults
if err := json.NewDecoder(resp.Body).Decode(&defaults); err != nil {
return nil, err
}
return &defaults, nil
}

View File

@@ -0,0 +1,60 @@
// Package kilo provides authentication and token management functionality
// for Kilo AI services.
package kilo
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
log "github.com/sirupsen/logrus"
)
// KiloTokenStorage stores token information for Kilo AI authentication.
type KiloTokenStorage struct {
// Token is the Kilo access token.
Token string `json:"kilocodeToken"`
// OrganizationID is the Kilo organization ID.
OrganizationID string `json:"kilocodeOrganizationId"`
// Model is the default model to use.
Model string `json:"kilocodeModel"`
// Email is the email address of the authenticated user.
Email string `json:"email"`
// Type indicates the authentication provider type, always "kilo" for this storage.
Type string `json:"type"`
}
// SaveTokenToFile serializes the Kilo token storage to a JSON file.
func (ts *KiloTokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
ts.Type = "kilo"
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
return fmt.Errorf("failed to create directory: %v", err)
}
f, err := os.Create(authFilePath)
if err != nil {
return fmt.Errorf("failed to create token file: %w", err)
}
defer func() {
if errClose := f.Close(); errClose != nil {
log.Errorf("failed to close file: %v", errClose)
}
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil
}
// CredentialFileName returns the filename used to persist Kilo credentials.
func CredentialFileName(email string) string {
return fmt.Sprintf("kilo-%s.json", email)
}

View File

@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
Expired string `json:"expired,omitempty"`
// Type indicates the authentication provider type, always "kimi" for this storage.
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// KimiTokenData holds the raw OAuth token response from Kimi.
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
encoder := json.NewEncoder(f)
encoder.SetIndent("", " ")
if err = encoder.Encode(ts); err != nil {
if err = encoder.Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -7,10 +7,13 @@ import (
"encoding/json"
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
@@ -47,7 +50,7 @@ type KiroTokenData struct {
Email string `json:"email,omitempty"`
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
StartURL string `json:"startUrl,omitempty"`
// Region is the AWS region for IDC authentication (only for IDC auth method)
// Region is the OIDC region for IDC login and token refresh
Region string `json:"region,omitempty"`
}
@@ -520,3 +523,159 @@ func GenerateTokenFileName(tokenData *KiroTokenData) string {
// Priority 3: Fallback to authMethod only with sequence
return fmt.Sprintf("kiro-%s-%05d.json", authMethod, seq)
}
// DefaultKiroRegion is the fallback region when none is specified.
const DefaultKiroRegion = "us-east-1"
// GetCodeWhispererLegacyEndpoint returns the legacy CodeWhisperer JSON-RPC endpoint.
// This endpoint supports JSON-RPC style requests with x-amz-target headers.
// The Q endpoint (q.{region}.amazonaws.com) does NOT support JSON-RPC style.
func GetCodeWhispererLegacyEndpoint(region string) string {
if region == "" {
region = DefaultKiroRegion
}
return "https://codewhisperer." + region + ".amazonaws.com"
}
// ProfileARN represents a parsed AWS CodeWhisperer profile ARN.
// ARN format: arn:partition:service:region:account-id:resource-type/resource-id
// Example: arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL
type ProfileARN struct {
// Raw is the original ARN string
Raw string
// Partition is the AWS partition (aws)
Partition string
// Service is the AWS service name (codewhisperer)
Service string
// Region is the AWS region (us-east-1, ap-southeast-1, etc.)
Region string
// AccountID is the AWS account ID
AccountID string
// ResourceType is the resource type (profile)
ResourceType string
// ResourceID is the resource identifier (e.g., ABCDEFGHIJKL)
ResourceID string
}
// ParseProfileARN parses an AWS ARN string into a ProfileARN struct.
// Returns nil if the ARN is empty, invalid, or not a codewhisperer ARN.
func ParseProfileARN(arn string) *ProfileARN {
if arn == "" {
return nil
}
// ARN format: arn:partition:service:region:account-id:resource
// Minimum 6 parts separated by ":"
parts := strings.Split(arn, ":")
if len(parts) < 6 {
log.Warnf("invalid ARN format: %s", arn)
return nil
}
// Validate ARN prefix
if parts[0] != "arn" {
return nil
}
// Validate partition
partition := parts[1]
if partition == "" {
return nil
}
// Validate service is codewhisperer
service := parts[2]
if service != "codewhisperer" {
return nil
}
// Validate region format (must contain "-")
region := parts[3]
if region == "" || !strings.Contains(region, "-") {
return nil
}
// Account ID
accountID := parts[4]
// Parse resource (format: resource-type/resource-id)
// Join remaining parts in case resource contains ":"
resource := strings.Join(parts[5:], ":")
resourceType := ""
resourceID := ""
if idx := strings.Index(resource, "/"); idx > 0 {
resourceType = resource[:idx]
resourceID = resource[idx+1:]
} else {
resourceType = resource
}
return &ProfileARN{
Raw: arn,
Partition: partition,
Service: service,
Region: region,
AccountID: accountID,
ResourceType: resourceType,
ResourceID: resourceID,
}
}
// GetKiroAPIEndpoint returns the Q API endpoint for the specified region.
// If region is empty, defaults to us-east-1.
func GetKiroAPIEndpoint(region string) string {
if region == "" {
region = DefaultKiroRegion
}
return "https://q." + region + ".amazonaws.com"
}
// GetKiroAPIEndpointFromProfileArn extracts region from profileArn and returns the endpoint.
// Returns default us-east-1 endpoint if region cannot be extracted.
func GetKiroAPIEndpointFromProfileArn(profileArn string) string {
region := ExtractRegionFromProfileArn(profileArn)
return GetKiroAPIEndpoint(region)
}
// ExtractRegionFromProfileArn extracts the AWS region from a ProfileARN string.
// Returns empty string if ARN is invalid or region cannot be extracted.
func ExtractRegionFromProfileArn(profileArn string) string {
parsed := ParseProfileARN(profileArn)
if parsed == nil {
return ""
}
return parsed.Region
}
// ExtractRegionFromMetadata extracts API region from auth metadata.
// Priority: api_region > profile_arn > DefaultKiroRegion
func ExtractRegionFromMetadata(metadata map[string]interface{}) string {
if metadata == nil {
return DefaultKiroRegion
}
// Priority 1: Explicit api_region override
if r, ok := metadata["api_region"].(string); ok && r != "" {
return r
}
// Priority 2: Extract from ProfileARN
if profileArn, ok := metadata["profile_arn"].(string); ok && profileArn != "" {
if region := ExtractRegionFromProfileArn(profileArn); region != "" {
return region
}
}
return DefaultKiroRegion
}
func buildURL(endpoint, path string, queryParams map[string]string) string {
fullURL := fmt.Sprintf("%s/%s", endpoint, path)
if len(queryParams) > 0 {
values := url.Values{}
for key, value := range queryParams {
if value == "" {
continue
}
values.Set(key, value)
}
if encoded := values.Encode(); encoded != "" {
fullURL = fullURL + "?" + encoded
}
}
return fullURL
}

View File

@@ -19,15 +19,8 @@ import (
)
const (
// awsKiroEndpoint is used for CodeWhisperer management APIs (GetUsageLimits, ListProfiles, etc.)
// Note: This is different from the Amazon Q streaming endpoint (q.us-east-1.amazonaws.com)
// used in kiro_executor.go for GenerateAssistantResponse. Both endpoints are correct
// for their respective API operations.
awsKiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com"
defaultTokenFile = "~/.aws/sso/cache/kiro-auth-token.json"
targetGetUsage = "AmazonCodeWhispererService.GetUsageLimits"
targetListModels = "AmazonCodeWhispererService.ListAvailableModels"
targetGenerateChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse"
pathGetUsageLimits = "getUsageLimits"
pathListAvailableModels = "ListAvailableModels"
)
// KiroAuth handles AWS CodeWhisperer authentication and API communication.
@@ -35,7 +28,6 @@ const (
// and communicating with the CodeWhisperer API.
type KiroAuth struct {
httpClient *http.Client
endpoint string
}
// NewKiroAuth creates a new Kiro authentication service.
@@ -49,7 +41,6 @@ type KiroAuth struct {
func NewKiroAuth(cfg *config.Config) *KiroAuth {
return &KiroAuth{
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}),
endpoint: awsKiroEndpoint,
}
}
@@ -110,33 +101,30 @@ func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool {
return time.Now().After(expiresAt)
}
// makeRequest sends a request to the CodeWhisperer API.
// This is an internal method for making authenticated API calls.
// makeRequest sends a REST-style GET request to the CodeWhisperer API.
//
// Parameters:
// - ctx: The context for the request
// - target: The API target (e.g., "AmazonCodeWhispererService.GetUsageLimits")
// - accessToken: The OAuth access token
// - payload: The request payload
// - path: The API path (e.g., "getUsageLimits")
// - tokenData: The token data containing access token, refresh token, and profile ARN
// - queryParams: Query parameters to add to the URL
//
// Returns:
// - []byte: The response body
// - error: An error if the request fails
func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken string, payload interface{}) ([]byte, error) {
jsonBody, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
func (k *KiroAuth) makeRequest(ctx context.Context, path string, tokenData *KiroTokenData, queryParams map[string]string) ([]byte, error) {
// Get endpoint from profileArn (defaults to us-east-1 if empty)
profileArn := queryParams["profileArn"]
endpoint := GetKiroAPIEndpointFromProfileArn(profileArn)
url := buildURL(endpoint, path, queryParams)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, k.endpoint, strings.NewReader(string(jsonBody)))
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
req.Header.Set("x-amz-target", target)
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken)
setRuntimeHeaders(req, tokenData.AccessToken, accountKey)
resp, err := k.httpClient.Do(req)
if err != nil {
@@ -171,13 +159,13 @@ func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken s
// - *KiroUsageInfo: The usage information
// - error: An error if the request fails
func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) {
payload := map[string]interface{}{
queryParams := map[string]string{
"origin": "AI_EDITOR",
"profileArn": tokenData.ProfileArn,
"resourceType": "AGENTIC_REQUEST",
}
body, err := k.makeRequest(ctx, targetGetUsage, tokenData.AccessToken, payload)
body, err := k.makeRequest(ctx, pathGetUsageLimits, tokenData, queryParams)
if err != nil {
return nil, err
}
@@ -221,12 +209,12 @@ func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData)
// - []*KiroModel: The list of available models
// - error: An error if the request fails
func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) {
payload := map[string]interface{}{
queryParams := map[string]string{
"origin": "AI_EDITOR",
"profileArn": tokenData.ProfileArn,
}
body, err := k.makeRequest(ctx, targetListModels, tokenData.AccessToken, payload)
body, err := k.makeRequest(ctx, pathListAvailableModels, tokenData, queryParams)
if err != nil {
return nil, err
}

View File

@@ -3,6 +3,7 @@ package kiro
import (
"encoding/base64"
"encoding/json"
"strings"
"testing"
)
@@ -217,7 +218,8 @@ func TestGenerateTokenFileName(t *testing.T) {
tests := []struct {
name string
tokenData *KiroTokenData
expected string
exact string // exact match (for cases with email)
prefix string // prefix match (for cases without email, where sequence is appended)
}{
{
name: "IDC with email",
@@ -226,7 +228,7 @@ func TestGenerateTokenFileName(t *testing.T) {
Email: "user@example.com",
StartURL: "https://d-1234567890.awsapps.com/start",
},
expected: "kiro-idc-user-example-com.json",
exact: "kiro-idc-user-example-com.json",
},
{
name: "IDC without email but with startUrl",
@@ -235,7 +237,7 @@ func TestGenerateTokenFileName(t *testing.T) {
Email: "",
StartURL: "https://d-1234567890.awsapps.com/start",
},
expected: "kiro-idc-d-1234567890.json",
prefix: "kiro-idc-d-1234567890-",
},
{
name: "IDC with company name in startUrl",
@@ -244,7 +246,7 @@ func TestGenerateTokenFileName(t *testing.T) {
Email: "",
StartURL: "https://my-company.awsapps.com/start",
},
expected: "kiro-idc-my-company.json",
prefix: "kiro-idc-my-company-",
},
{
name: "IDC without email and without startUrl",
@@ -253,7 +255,7 @@ func TestGenerateTokenFileName(t *testing.T) {
Email: "",
StartURL: "",
},
expected: "kiro-idc.json",
prefix: "kiro-idc-",
},
{
name: "Builder ID with email",
@@ -262,7 +264,7 @@ func TestGenerateTokenFileName(t *testing.T) {
Email: "user@gmail.com",
StartURL: "https://view.awsapps.com/start",
},
expected: "kiro-builder-id-user-gmail-com.json",
exact: "kiro-builder-id-user-gmail-com.json",
},
{
name: "Builder ID without email",
@@ -271,7 +273,7 @@ func TestGenerateTokenFileName(t *testing.T) {
Email: "",
StartURL: "https://view.awsapps.com/start",
},
expected: "kiro-builder-id.json",
prefix: "kiro-builder-id-",
},
{
name: "Social auth with email",
@@ -279,7 +281,7 @@ func TestGenerateTokenFileName(t *testing.T) {
AuthMethod: "google",
Email: "user@gmail.com",
},
expected: "kiro-google-user-gmail-com.json",
exact: "kiro-google-user-gmail-com.json",
},
{
name: "Empty auth method",
@@ -287,7 +289,7 @@ func TestGenerateTokenFileName(t *testing.T) {
AuthMethod: "",
Email: "",
},
expected: "kiro-unknown.json",
prefix: "kiro-unknown-",
},
{
name: "Email with special characters",
@@ -296,16 +298,454 @@ func TestGenerateTokenFileName(t *testing.T) {
Email: "user.name+tag@sub.example.com",
StartURL: "https://d-1234567890.awsapps.com/start",
},
expected: "kiro-idc-user-name+tag-sub-example-com.json",
exact: "kiro-idc-user-name+tag-sub-example-com.json",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenerateTokenFileName(tt.tokenData)
if result != tt.expected {
t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.expected)
if tt.exact != "" {
if result != tt.exact {
t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.exact)
}
} else if tt.prefix != "" {
if !strings.HasPrefix(result, tt.prefix) || !strings.HasSuffix(result, ".json") {
t.Errorf("GenerateTokenFileName() = %q, want prefix %q with .json suffix", result, tt.prefix)
}
}
})
}
}
func TestParseProfileARN(t *testing.T) {
tests := []struct {
name string
arn string
expected *ProfileARN
}{
{
name: "Empty ARN",
arn: "",
expected: nil,
},
{
name: "Invalid format - too few parts",
arn: "arn:aws:codewhisperer",
expected: nil,
},
{
name: "Invalid prefix - not arn",
arn: "notarn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
expected: nil,
},
{
name: "Invalid service - not codewhisperer",
arn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket",
expected: nil,
},
{
name: "Invalid region - no hyphen",
arn: "arn:aws:codewhisperer:useast1:123456789012:profile/ABC",
expected: nil,
},
{
name: "Empty partition",
arn: "arn::codewhisperer:us-east-1:123456789012:profile/ABC",
expected: nil,
},
{
name: "Empty region",
arn: "arn:aws:codewhisperer::123456789012:profile/ABC",
expected: nil,
},
{
name: "Valid ARN - us-east-1",
arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL",
Partition: "aws",
Service: "codewhisperer",
Region: "us-east-1",
AccountID: "123456789012",
ResourceType: "profile",
ResourceID: "ABCDEFGHIJKL",
},
},
{
name: "Valid ARN - ap-southeast-1",
arn: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ",
Partition: "aws",
Service: "codewhisperer",
Region: "ap-southeast-1",
AccountID: "987654321098",
ResourceType: "profile",
ResourceID: "ZYXWVUTSRQ",
},
},
{
name: "Valid ARN - eu-west-1",
arn: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123",
Partition: "aws",
Service: "codewhisperer",
Region: "eu-west-1",
AccountID: "111222333444",
ResourceType: "profile",
ResourceID: "PROFILE123",
},
},
{
name: "Valid ARN - aws-cn partition",
arn: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID",
expected: &ProfileARN{
Raw: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID",
Partition: "aws-cn",
Service: "codewhisperer",
Region: "cn-north-1",
AccountID: "123456789012",
ResourceType: "profile",
ResourceID: "CHINAID",
},
},
{
name: "Valid ARN - resource without slash",
arn: "arn:aws:codewhisperer:us-west-2:123456789012:profile",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:us-west-2:123456789012:profile",
Partition: "aws",
Service: "codewhisperer",
Region: "us-west-2",
AccountID: "123456789012",
ResourceType: "profile",
ResourceID: "",
},
},
{
name: "Valid ARN - resource with colon",
arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra",
Partition: "aws",
Service: "codewhisperer",
Region: "us-east-1",
AccountID: "123456789012",
ResourceType: "profile",
ResourceID: "ABC:extra",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ParseProfileARN(tt.arn)
if tt.expected == nil {
if result != nil {
t.Errorf("ParseProfileARN(%q) = %+v, want nil", tt.arn, result)
}
return
}
if result == nil {
t.Errorf("ParseProfileARN(%q) = nil, want %+v", tt.arn, tt.expected)
return
}
if result.Raw != tt.expected.Raw {
t.Errorf("Raw = %q, want %q", result.Raw, tt.expected.Raw)
}
if result.Partition != tt.expected.Partition {
t.Errorf("Partition = %q, want %q", result.Partition, tt.expected.Partition)
}
if result.Service != tt.expected.Service {
t.Errorf("Service = %q, want %q", result.Service, tt.expected.Service)
}
if result.Region != tt.expected.Region {
t.Errorf("Region = %q, want %q", result.Region, tt.expected.Region)
}
if result.AccountID != tt.expected.AccountID {
t.Errorf("AccountID = %q, want %q", result.AccountID, tt.expected.AccountID)
}
if result.ResourceType != tt.expected.ResourceType {
t.Errorf("ResourceType = %q, want %q", result.ResourceType, tt.expected.ResourceType)
}
if result.ResourceID != tt.expected.ResourceID {
t.Errorf("ResourceID = %q, want %q", result.ResourceID, tt.expected.ResourceID)
}
})
}
}
func TestExtractRegionFromProfileArn(t *testing.T) {
tests := []struct {
name string
profileArn string
expected string
}{
{
name: "Empty ARN",
profileArn: "",
expected: "",
},
{
name: "Invalid ARN",
profileArn: "invalid-arn",
expected: "",
},
{
name: "Valid ARN - us-east-1",
profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
expected: "us-east-1",
},
{
name: "Valid ARN - ap-southeast-1",
profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
expected: "ap-southeast-1",
},
{
name: "Valid ARN - eu-central-1",
profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
expected: "eu-central-1",
},
{
name: "Non-codewhisperer ARN",
profileArn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ExtractRegionFromProfileArn(tt.profileArn)
if result != tt.expected {
t.Errorf("ExtractRegionFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected)
}
})
}
}
func TestGetKiroAPIEndpoint(t *testing.T) {
tests := []struct {
name string
region string
expected string
}{
{
name: "Empty region - defaults to us-east-1",
region: "",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "us-east-1",
region: "us-east-1",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "us-west-2",
region: "us-west-2",
expected: "https://q.us-west-2.amazonaws.com",
},
{
name: "ap-southeast-1",
region: "ap-southeast-1",
expected: "https://q.ap-southeast-1.amazonaws.com",
},
{
name: "eu-west-1",
region: "eu-west-1",
expected: "https://q.eu-west-1.amazonaws.com",
},
{
name: "cn-north-1",
region: "cn-north-1",
expected: "https://q.cn-north-1.amazonaws.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetKiroAPIEndpoint(tt.region)
if result != tt.expected {
t.Errorf("GetKiroAPIEndpoint(%q) = %q, want %q", tt.region, result, tt.expected)
}
})
}
}
func TestGetKiroAPIEndpointFromProfileArn(t *testing.T) {
tests := []struct {
name string
profileArn string
expected string
}{
{
name: "Empty ARN - defaults to us-east-1",
profileArn: "",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "Invalid ARN - defaults to us-east-1",
profileArn: "invalid-arn",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "Valid ARN - us-east-1",
profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "Valid ARN - ap-southeast-1",
profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
expected: "https://q.ap-southeast-1.amazonaws.com",
},
{
name: "Valid ARN - eu-central-1",
profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
expected: "https://q.eu-central-1.amazonaws.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetKiroAPIEndpointFromProfileArn(tt.profileArn)
if result != tt.expected {
t.Errorf("GetKiroAPIEndpointFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected)
}
})
}
}
func TestGetCodeWhispererLegacyEndpoint(t *testing.T) {
tests := []struct {
name string
region string
expected string
}{
{
name: "Empty region - defaults to us-east-1",
region: "",
expected: "https://codewhisperer.us-east-1.amazonaws.com",
},
{
name: "us-east-1",
region: "us-east-1",
expected: "https://codewhisperer.us-east-1.amazonaws.com",
},
{
name: "us-west-2",
region: "us-west-2",
expected: "https://codewhisperer.us-west-2.amazonaws.com",
},
{
name: "ap-northeast-1",
region: "ap-northeast-1",
expected: "https://codewhisperer.ap-northeast-1.amazonaws.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetCodeWhispererLegacyEndpoint(tt.region)
if result != tt.expected {
t.Errorf("GetCodeWhispererLegacyEndpoint(%q) = %q, want %q", tt.region, result, tt.expected)
}
})
}
}
func TestExtractRegionFromMetadata(t *testing.T) {
tests := []struct {
name string
metadata map[string]interface{}
expected string
}{
{
name: "Nil metadata - defaults to us-east-1",
metadata: nil,
expected: "us-east-1",
},
{
name: "Empty metadata - defaults to us-east-1",
metadata: map[string]interface{}{},
expected: "us-east-1",
},
{
name: "Priority 1: api_region override",
metadata: map[string]interface{}{
"api_region": "eu-west-1",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
},
expected: "eu-west-1",
},
{
name: "Priority 2: profile_arn when api_region is empty",
metadata: map[string]interface{}{
"api_region": "",
"profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
},
expected: "ap-southeast-1",
},
{
name: "Priority 2: profile_arn when api_region is missing",
metadata: map[string]interface{}{
"profile_arn": "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
},
expected: "eu-central-1",
},
{
name: "Fallback: default when profile_arn is invalid",
metadata: map[string]interface{}{
"profile_arn": "invalid-arn",
},
expected: "us-east-1",
},
{
name: "Fallback: default when profile_arn is empty",
metadata: map[string]interface{}{
"profile_arn": "",
},
expected: "us-east-1",
},
{
name: "OIDC region is NOT used for API region",
metadata: map[string]interface{}{
"region": "ap-northeast-2", // OIDC region - should be ignored
},
expected: "us-east-1",
},
{
name: "api_region takes precedence over OIDC region",
metadata: map[string]interface{}{
"api_region": "us-west-2",
"region": "ap-northeast-2", // OIDC region - should be ignored
},
expected: "us-west-2",
},
{
name: "Non-string api_region is ignored",
metadata: map[string]interface{}{
"api_region": 123, // wrong type
"profile_arn": "arn:aws:codewhisperer:ap-south-1:123456789012:profile/ABC",
},
expected: "ap-south-1",
},
{
name: "Non-string profile_arn is ignored",
metadata: map[string]interface{}{
"profile_arn": 123, // wrong type
},
expected: "us-east-1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ExtractRegionFromMetadata(tt.metadata)
if result != tt.expected {
t.Errorf("ExtractRegionFromMetadata(%v) = %q, want %q", tt.metadata, result, tt.expected)
}
})
}
}

View File

@@ -9,30 +9,23 @@ import (
"net/http"
"time"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
codeWhispererAPI = "https://codewhisperer.us-east-1.amazonaws.com"
kiroVersion = "0.6.18"
)
// CodeWhispererClient handles CodeWhisperer API calls.
type CodeWhispererClient struct {
httpClient *http.Client
machineID string
}
// UsageLimitsResponse represents the getUsageLimits API response.
type UsageLimitsResponse struct {
DaysUntilReset *int `json:"daysUntilReset,omitempty"`
NextDateReset *float64 `json:"nextDateReset,omitempty"`
UserInfo *UserInfo `json:"userInfo,omitempty"`
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"`
DaysUntilReset *int `json:"daysUntilReset,omitempty"`
NextDateReset *float64 `json:"nextDateReset,omitempty"`
UserInfo *UserInfo `json:"userInfo,omitempty"`
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"`
}
// UserInfo contains user information from the API.
@@ -49,13 +42,13 @@ type SubscriptionInfo struct {
// UsageBreakdown contains usage details.
type UsageBreakdown struct {
UsageLimit *int `json:"usageLimit,omitempty"`
CurrentUsage *int `json:"currentUsage,omitempty"`
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"`
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"`
NextDateReset *float64 `json:"nextDateReset,omitempty"`
DisplayName string `json:"displayName,omitempty"`
ResourceType string `json:"resourceType,omitempty"`
UsageLimit *int `json:"usageLimit,omitempty"`
CurrentUsage *int `json:"currentUsage,omitempty"`
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"`
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"`
NextDateReset *float64 `json:"nextDateReset,omitempty"`
DisplayName string `json:"displayName,omitempty"`
ResourceType string `json:"resourceType,omitempty"`
}
// NewCodeWhispererClient creates a new CodeWhisperer client.
@@ -64,40 +57,34 @@ func NewCodeWhispererClient(cfg *config.Config, machineID string) *CodeWhisperer
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
if machineID == "" {
machineID = uuid.New().String()
}
return &CodeWhispererClient{
httpClient: client,
machineID: machineID,
}
}
// generateInvocationID generates a unique invocation ID.
func generateInvocationID() string {
return uuid.New().String()
}
// GetUsageLimits fetches usage limits and user info from CodeWhisperer API.
// This is the recommended way to get user email after login.
func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken string) (*UsageLimitsResponse, error) {
url := fmt.Sprintf("%s/getUsageLimits?isEmailRequired=true&origin=AI_EDITOR&resourceType=AGENTIC_REQUEST", codeWhispererAPI)
func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken, clientID, refreshToken, profileArn string) (*UsageLimitsResponse, error) {
queryParams := map[string]string{
"origin": "AI_EDITOR",
"resourceType": "AGENTIC_REQUEST",
}
// Determine endpoint based on profileArn region
endpoint := GetKiroAPIEndpointFromProfileArn(profileArn)
if profileArn != "" {
queryParams["profileArn"] = profileArn
} else {
queryParams["isEmailRequired"] = "true"
}
url := buildURL(endpoint, pathGetUsageLimits, queryParams)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Set headers to match Kiro IDE
xAmzUserAgent := fmt.Sprintf("aws-sdk-js/1.0.0 KiroIDE-%s-%s", kiroVersion, c.machineID)
userAgent := fmt.Sprintf("aws-sdk-js/1.0.0 ua/2.1 os/windows lang/js md/nodejs#20.16.0 api/codewhispererruntime#1.0.0 m/E KiroIDE-%s-%s", kiroVersion, c.machineID)
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("x-amz-user-agent", xAmzUserAgent)
req.Header.Set("User-Agent", userAgent)
req.Header.Set("amz-sdk-invocation-id", generateInvocationID())
req.Header.Set("amz-sdk-request", "attempt=1; max=1")
req.Header.Set("Connection", "close")
accountKey := GetAccountKey(clientID, refreshToken)
setRuntimeHeaders(req, accessToken, accountKey)
log.Debugf("codewhisperer: GET %s", url)
@@ -128,8 +115,8 @@ func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken st
// FetchUserEmailFromAPI fetches user email using CodeWhisperer getUsageLimits API.
// This is more reliable than JWT parsing as it uses the official API.
func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken string) string {
resp, err := c.GetUsageLimits(ctx, accessToken)
func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken, clientID, refreshToken string) string {
resp, err := c.GetUsageLimits(ctx, accessToken, clientID, refreshToken, "")
if err != nil {
log.Debugf("codewhisperer: failed to get usage limits: %v", err)
return ""
@@ -146,10 +133,10 @@ func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessT
// FetchUserEmailWithFallback fetches user email with multiple fallback methods.
// Priority: 1. CodeWhisperer API 2. userinfo endpoint 3. JWT parsing
func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken string) string {
func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken, clientID, refreshToken string) string {
// Method 1: Try CodeWhisperer API (most reliable)
cwClient := NewCodeWhispererClient(cfg, "")
email := cwClient.FetchUserEmailFromAPI(ctx, accessToken)
email := cwClient.FetchUserEmailFromAPI(ctx, accessToken, clientID, refreshToken)
if email != "" {
return email
}

View File

@@ -2,77 +2,105 @@ package kiro
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"math/rand"
"net/http"
"runtime"
"slices"
"sync"
"time"
"github.com/google/uuid"
)
// Fingerprint 多维度指纹信息
// Fingerprint holds multi-dimensional fingerprint data for runtime request disguise.
type Fingerprint struct {
SDKVersion string // 1.0.20-1.0.27
OIDCSDKVersion string // 3.7xx (AWS SDK JS)
RuntimeSDKVersion string // 1.0.x (runtime API)
StreamingSDKVersion string // 1.0.x (streaming API)
OSType string // darwin/windows/linux
OSVersion string // 10.0.22621
NodeVersion string // 18.x/20.x/22.x
KiroVersion string // 0.3.x-0.8.x
OSVersion string
NodeVersion string
KiroVersion string
KiroHash string // SHA256
AcceptLanguage string
ScreenResolution string // 1920x1080
ColorDepth int // 24
HardwareConcurrency int // CPU 核心数
TimezoneOffset int
}
// FingerprintManager 指纹管理器
// FingerprintConfig holds external fingerprint overrides.
type FingerprintConfig struct {
OIDCSDKVersion string
RuntimeSDKVersion string
StreamingSDKVersion string
OSType string
OSVersion string
NodeVersion string
KiroVersion string
KiroHash string
}
// FingerprintManager manages per-account fingerprint generation and caching.
type FingerprintManager struct {
mu sync.RWMutex
fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
rng *rand.Rand
config *FingerprintConfig // External config (Optional)
}
var (
sdkVersions = []string{
"1.0.20", "1.0.21", "1.0.22", "1.0.23",
"1.0.24", "1.0.25", "1.0.26", "1.0.27",
// SDK versions
oidcSDKVersions = []string{
"3.980.0", "3.975.0", "3.972.0", "3.808.0",
"3.738.0", "3.737.0", "3.736.0", "3.735.0",
}
// SDKVersions for getUsageLimits/ListAvailableModels/GetProfile (runtime API)
runtimeSDKVersions = []string{"1.0.0"}
// SDKVersions for generateAssistantResponse (streaming API)
streamingSDKVersions = []string{"1.0.27"}
// Valid OS types
osTypes = []string{"darwin", "windows", "linux"}
// OS versions
osVersions = map[string][]string{
"darwin": {"14.0", "14.1", "14.2", "14.3", "14.4", "14.5", "15.0", "15.1"},
"windows": {"10.0.19041", "10.0.19042", "10.0.19043", "10.0.19044", "10.0.22621", "10.0.22631"},
"linux": {"5.15.0", "6.1.0", "6.2.0", "6.5.0", "6.6.0", "6.8.0"},
"darwin": {"25.2.0", "25.1.0", "25.0.0", "24.5.0", "24.4.0", "24.3.0"},
"windows": {"10.0.26200", "10.0.26100", "10.0.22631", "10.0.22621", "10.0.19045"},
"linux": {"6.12.0", "6.11.0", "6.8.0", "6.6.0", "6.5.0", "6.1.0"},
}
// Node versions
nodeVersions = []string{
"18.17.0", "18.18.0", "18.19.0", "18.20.0",
"20.9.0", "20.10.0", "20.11.0", "20.12.0", "20.13.0",
"22.0.0", "22.1.0", "22.2.0", "22.3.0",
"22.21.1", "22.21.0", "22.20.0", "22.19.0", "22.18.0",
"20.18.0", "20.17.0", "20.16.0",
}
// Kiro IDE versions
kiroVersions = []string{
"0.3.0", "0.3.1", "0.4.0", "0.4.1", "0.5.0", "0.5.1",
"0.6.0", "0.6.1", "0.7.0", "0.7.1", "0.8.0", "0.8.1",
"0.10.32", "0.10.16", "0.10.10",
"0.9.47", "0.9.40", "0.9.2",
"0.8.206", "0.8.140", "0.8.135", "0.8.86",
}
acceptLanguages = []string{
"en-US,en;q=0.9",
"en-GB,en;q=0.9",
"zh-CN,zh;q=0.9,en;q=0.8",
"zh-TW,zh;q=0.9,en;q=0.8",
"ja-JP,ja;q=0.9,en;q=0.8",
"ko-KR,ko;q=0.9,en;q=0.8",
"de-DE,de;q=0.9,en;q=0.8",
"fr-FR,fr;q=0.9,en;q=0.8",
}
screenResolutions = []string{
"1920x1080", "2560x1440", "3840x2160",
"1366x768", "1440x900", "1680x1050",
"2560x1600", "3440x1440",
}
colorDepths = []int{24, 32}
hardwareConcurrencies = []int{4, 6, 8, 10, 12, 16, 20, 24, 32}
timezoneOffsets = []int{-480, -420, -360, -300, -240, 0, 60, 120, 480, 540}
// Global singleton
globalFingerprintManager *FingerprintManager
globalFingerprintManagerOnce sync.Once
)
// NewFingerprintManager 创建指纹管理器
func GlobalFingerprintManager() *FingerprintManager {
globalFingerprintManagerOnce.Do(func() {
globalFingerprintManager = NewFingerprintManager()
})
return globalFingerprintManager
}
func SetGlobalFingerprintConfig(cfg *FingerprintConfig) {
GlobalFingerprintManager().SetConfig(cfg)
}
// SetConfig applies the config and clears the fingerprint cache.
func (fm *FingerprintManager) SetConfig(cfg *FingerprintConfig) {
fm.mu.Lock()
defer fm.mu.Unlock()
fm.config = cfg
// Clear cached fingerprints so they regenerate with the new config
fm.fingerprints = make(map[string]*Fingerprint)
}
func NewFingerprintManager() *FingerprintManager {
return &FingerprintManager{
fingerprints: make(map[string]*Fingerprint),
@@ -80,7 +108,7 @@ func NewFingerprintManager() *FingerprintManager {
}
}
// GetFingerprint 获取或生成 Token 关联的指纹
// GetFingerprint returns the fingerprint for tokenKey, creating one if it doesn't exist.
func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
fm.mu.RLock()
if fp, exists := fm.fingerprints[tokenKey]; exists {
@@ -101,97 +129,150 @@ func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
return fp
}
// generateFingerprint 生成新的指纹
func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
osType := fm.randomChoice(osTypes)
osVersion := fm.randomChoice(osVersions[osType])
kiroVersion := fm.randomChoice(kiroVersions)
if fm.config != nil {
return fm.generateFromConfig(tokenKey)
}
return fm.generateRandom(tokenKey)
}
fp := &Fingerprint{
SDKVersion: fm.randomChoice(sdkVersions),
OSType: osType,
OSVersion: osVersion,
NodeVersion: fm.randomChoice(nodeVersions),
KiroVersion: kiroVersion,
AcceptLanguage: fm.randomChoice(acceptLanguages),
ScreenResolution: fm.randomChoice(screenResolutions),
ColorDepth: fm.randomIntChoice(colorDepths),
HardwareConcurrency: fm.randomIntChoice(hardwareConcurrencies),
TimezoneOffset: fm.randomIntChoice(timezoneOffsets),
// generateFromConfig uses config values, falling back to random for empty fields.
func (fm *FingerprintManager) generateFromConfig(tokenKey string) *Fingerprint {
cfg := fm.config
// Helper: config value or random selection
configOrRandom := func(configVal string, choices []string) string {
if configVal != "" {
return configVal
}
return choices[fm.rng.Intn(len(choices))]
}
fp.KiroHash = fm.generateKiroHash(tokenKey, kiroVersion, osType)
return fp
osType := cfg.OSType
if osType == "" {
osType = runtime.GOOS
if !slices.Contains(osTypes, osType) {
osType = osTypes[fm.rng.Intn(len(osTypes))]
}
}
osVersion := cfg.OSVersion
if osVersion == "" {
if versions, ok := osVersions[osType]; ok {
osVersion = versions[fm.rng.Intn(len(versions))]
}
}
kiroHash := cfg.KiroHash
if kiroHash == "" {
hash := sha256.Sum256([]byte(tokenKey))
kiroHash = hex.EncodeToString(hash[:])
}
return &Fingerprint{
OIDCSDKVersion: configOrRandom(cfg.OIDCSDKVersion, oidcSDKVersions),
RuntimeSDKVersion: configOrRandom(cfg.RuntimeSDKVersion, runtimeSDKVersions),
StreamingSDKVersion: configOrRandom(cfg.StreamingSDKVersion, streamingSDKVersions),
OSType: osType,
OSVersion: osVersion,
NodeVersion: configOrRandom(cfg.NodeVersion, nodeVersions),
KiroVersion: configOrRandom(cfg.KiroVersion, kiroVersions),
KiroHash: kiroHash,
}
}
// generateKiroHash 生成 Kiro Hash
func (fm *FingerprintManager) generateKiroHash(tokenKey, kiroVersion, osType string) string {
data := fmt.Sprintf("%s:%s:%s:%d", tokenKey, kiroVersion, osType, time.Now().UnixNano())
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
// generateRandom generates a deterministic fingerprint seeded by accountKey hash.
func (fm *FingerprintManager) generateRandom(accountKey string) *Fingerprint {
// Use accountKey hash as seed for deterministic random selection
hash := sha256.Sum256([]byte(accountKey))
seed := int64(binary.BigEndian.Uint64(hash[:8]))
rng := rand.New(rand.NewSource(seed))
osType := runtime.GOOS
if !slices.Contains(osTypes, osType) {
osType = osTypes[rng.Intn(len(osTypes))]
}
osVersion := osVersions[osType][rng.Intn(len(osVersions[osType]))]
return &Fingerprint{
OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
OSType: osType,
OSVersion: osVersion,
NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
KiroHash: hex.EncodeToString(hash[:]),
}
}
// randomChoice 随机选择字符串
func (fm *FingerprintManager) randomChoice(choices []string) string {
return choices[fm.rng.Intn(len(choices))]
// GenerateAccountKey returns a 16-char hex key derived from SHA256(seed).
func GenerateAccountKey(seed string) string {
hash := sha256.Sum256([]byte(seed))
return hex.EncodeToString(hash[:8])
}
// randomIntChoice 随机选择整数
func (fm *FingerprintManager) randomIntChoice(choices []int) int {
return choices[fm.rng.Intn(len(choices))]
// GetAccountKey derives an account key from clientID > refreshToken > random UUID.
func GetAccountKey(clientID, refreshToken string) string {
// 1. Prefer ClientID
if clientID != "" {
return GenerateAccountKey(clientID)
}
// 2. Fallback to RefreshToken
if refreshToken != "" {
return GenerateAccountKey(refreshToken)
}
// 3. Random fallback
return GenerateAccountKey(uuid.New().String())
}
// ApplyToRequest 将指纹信息应用到 HTTP 请求头
func (fp *Fingerprint) ApplyToRequest(req *http.Request) {
req.Header.Set("X-Kiro-SDK-Version", fp.SDKVersion)
req.Header.Set("X-Kiro-OS-Type", fp.OSType)
req.Header.Set("X-Kiro-OS-Version", fp.OSVersion)
req.Header.Set("X-Kiro-Node-Version", fp.NodeVersion)
req.Header.Set("X-Kiro-Version", fp.KiroVersion)
req.Header.Set("X-Kiro-Hash", fp.KiroHash)
req.Header.Set("Accept-Language", fp.AcceptLanguage)
req.Header.Set("X-Screen-Resolution", fp.ScreenResolution)
req.Header.Set("X-Color-Depth", fmt.Sprintf("%d", fp.ColorDepth))
req.Header.Set("X-Hardware-Concurrency", fmt.Sprintf("%d", fp.HardwareConcurrency))
req.Header.Set("X-Timezone-Offset", fmt.Sprintf("%d", fp.TimezoneOffset))
}
// RemoveFingerprint 移除 Token 关联的指纹
func (fm *FingerprintManager) RemoveFingerprint(tokenKey string) {
fm.mu.Lock()
defer fm.mu.Unlock()
delete(fm.fingerprints, tokenKey)
}
// Count 返回当前管理的指纹数量
func (fm *FingerprintManager) Count() int {
fm.mu.RLock()
defer fm.mu.RUnlock()
return len(fm.fingerprints)
}
// BuildUserAgent 构建 User-Agent 字符串 (Kiro IDE 风格)
// 格式: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
// BuildUserAgent format: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
func (fp *Fingerprint) BuildUserAgent() string {
return fmt.Sprintf(
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
fp.SDKVersion,
fp.StreamingSDKVersion,
fp.OSType,
fp.OSVersion,
fp.NodeVersion,
fp.SDKVersion,
fp.StreamingSDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}
// BuildAmzUserAgent 构建 X-Amz-User-Agent 字符串
// 格式: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
// BuildAmzUserAgent format: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
func (fp *Fingerprint) BuildAmzUserAgent() string {
return fmt.Sprintf(
"aws-sdk-js/%s KiroIDE-%s-%s",
fp.SDKVersion,
fp.StreamingSDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}
func SetOIDCHeaders(req *http.Request) {
fp := GlobalFingerprintManager().GetFingerprint("oidc-session")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion))
req.Header.Set("User-Agent", fmt.Sprintf(
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/%s#%s m/E KiroIDE",
fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, "sso-oidc", fp.OIDCSDKVersion))
req.Header.Set("amz-sdk-invocation-id", uuid.New().String())
req.Header.Set("amz-sdk-request", "attempt=1; max=4")
}
func setRuntimeHeaders(req *http.Request, accessToken string, accountKey string) {
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
machineID := fp.KiroHash
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s",
fp.RuntimeSDKVersion, fp.KiroVersion, machineID))
req.Header.Set("User-Agent", fmt.Sprintf(
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererruntime#%s m/N,E KiroIDE-%s-%s",
fp.RuntimeSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.RuntimeSDKVersion,
fp.KiroVersion, machineID))
req.Header.Set("amz-sdk-invocation-id", uuid.New().String())
req.Header.Set("amz-sdk-request", "attempt=1; max=1")
}

View File

@@ -2,6 +2,8 @@ package kiro
import (
"net/http"
"runtime"
"strings"
"sync"
"testing"
)
@@ -26,8 +28,14 @@ func TestGetFingerprint_NewToken(t *testing.T) {
if fp == nil {
t.Fatal("expected non-nil Fingerprint")
}
if fp.SDKVersion == "" {
t.Error("expected non-empty SDKVersion")
if fp.OIDCSDKVersion == "" {
t.Error("expected non-empty OIDCSDKVersion")
}
if fp.RuntimeSDKVersion == "" {
t.Error("expected non-empty RuntimeSDKVersion")
}
if fp.StreamingSDKVersion == "" {
t.Error("expected non-empty StreamingSDKVersion")
}
if fp.OSType == "" {
t.Error("expected non-empty OSType")
@@ -44,18 +52,6 @@ func TestGetFingerprint_NewToken(t *testing.T) {
if fp.KiroHash == "" {
t.Error("expected non-empty KiroHash")
}
if fp.AcceptLanguage == "" {
t.Error("expected non-empty AcceptLanguage")
}
if fp.ScreenResolution == "" {
t.Error("expected non-empty ScreenResolution")
}
if fp.ColorDepth == 0 {
t.Error("expected non-zero ColorDepth")
}
if fp.HardwareConcurrency == 0 {
t.Error("expected non-zero HardwareConcurrency")
}
}
func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
@@ -78,72 +74,18 @@ func TestGetFingerprint_DifferentTokens(t *testing.T) {
}
}
func TestRemoveFingerprint(t *testing.T) {
fm := NewFingerprintManager()
fm.GetFingerprint("token1")
if fm.Count() != 1 {
t.Fatalf("expected count 1, got %d", fm.Count())
}
fm.RemoveFingerprint("token1")
if fm.Count() != 0 {
t.Errorf("expected count 0, got %d", fm.Count())
}
}
func TestRemoveFingerprint_NonExistent(t *testing.T) {
fm := NewFingerprintManager()
fm.RemoveFingerprint("nonexistent")
if fm.Count() != 0 {
t.Errorf("expected count 0, got %d", fm.Count())
}
}
func TestCount(t *testing.T) {
fm := NewFingerprintManager()
if fm.Count() != 0 {
t.Errorf("expected count 0, got %d", fm.Count())
}
fm.GetFingerprint("token1")
fm.GetFingerprint("token2")
fm.GetFingerprint("token3")
if fm.Count() != 3 {
t.Errorf("expected count 3, got %d", fm.Count())
}
}
func TestApplyToRequest(t *testing.T) {
func TestBuildUserAgent(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
req, _ := http.NewRequest("GET", "http://example.com", nil)
fp.ApplyToRequest(req)
ua := fp.BuildUserAgent()
if ua == "" {
t.Error("expected non-empty User-Agent")
}
if req.Header.Get("X-Kiro-SDK-Version") != fp.SDKVersion {
t.Error("X-Kiro-SDK-Version header mismatch")
}
if req.Header.Get("X-Kiro-OS-Type") != fp.OSType {
t.Error("X-Kiro-OS-Type header mismatch")
}
if req.Header.Get("X-Kiro-OS-Version") != fp.OSVersion {
t.Error("X-Kiro-OS-Version header mismatch")
}
if req.Header.Get("X-Kiro-Node-Version") != fp.NodeVersion {
t.Error("X-Kiro-Node-Version header mismatch")
}
if req.Header.Get("X-Kiro-Version") != fp.KiroVersion {
t.Error("X-Kiro-Version header mismatch")
}
if req.Header.Get("X-Kiro-Hash") != fp.KiroHash {
t.Error("X-Kiro-Hash header mismatch")
}
if req.Header.Get("Accept-Language") != fp.AcceptLanguage {
t.Error("Accept-Language header mismatch")
}
if req.Header.Get("X-Screen-Resolution") != fp.ScreenResolution {
t.Error("X-Screen-Resolution header mismatch")
amzUA := fp.BuildAmzUserAgent()
if amzUA == "" {
t.Error("expected non-empty X-Amz-User-Agent")
}
}
@@ -166,6 +108,33 @@ func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) {
}
}
func TestGenerateFromConfig_OSTypeFromRuntimeGOOS(t *testing.T) {
fm := NewFingerprintManager()
// Set config with empty OSType to trigger runtime.GOOS fallback
fm.SetConfig(&FingerprintConfig{
OIDCSDKVersion: "3.738.0", // Set other fields to use config path
})
fp := fm.GetFingerprint("test-token")
// Expected OS type based on runtime.GOOS mapping
var expectedOS string
switch runtime.GOOS {
case "darwin":
expectedOS = "darwin"
case "windows":
expectedOS = "windows"
default:
expectedOS = "linux"
}
if fp.OSType != expectedOS {
t.Errorf("expected OSType '%s' from runtime.GOOS '%s', got '%s'",
expectedOS, runtime.GOOS, fp.OSType)
}
}
func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
fm := NewFingerprintManager()
const numGoroutines = 100
@@ -174,22 +143,18 @@ func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
for i := range numGoroutines {
go func(id int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
for j := range numOperations {
tokenKey := "token" + string(rune('a'+id%26))
switch j % 4 {
switch j % 2 {
case 0:
fm.GetFingerprint(tokenKey)
case 1:
fm.Count()
case 2:
fp := fm.GetFingerprint(tokenKey)
req, _ := http.NewRequest("GET", "http://example.com", nil)
fp.ApplyToRequest(req)
case 3:
fm.RemoveFingerprint(tokenKey)
_ = fp.BuildUserAgent()
_ = fp.BuildAmzUserAgent()
}
}
}(i)
@@ -198,16 +163,20 @@ func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
wg.Wait()
}
func TestKiroHashUniqueness(t *testing.T) {
func TestKiroHashStability(t *testing.T) {
fm := NewFingerprintManager()
hashes := make(map[string]bool)
for i := 0; i < 100; i++ {
fp := fm.GetFingerprint("token" + string(rune(i)))
if hashes[fp.KiroHash] {
t.Errorf("duplicate KiroHash detected: %s", fp.KiroHash)
}
hashes[fp.KiroHash] = true
// Same token should always return same hash
fp1 := fm.GetFingerprint("token1")
fp2 := fm.GetFingerprint("token1")
if fp1.KiroHash != fp2.KiroHash {
t.Errorf("same token should have same hash: %s vs %s", fp1.KiroHash, fp2.KiroHash)
}
// Different tokens should have different hashes
fp3 := fm.GetFingerprint("token2")
if fp1.KiroHash == fp3.KiroHash {
t.Errorf("different tokens should have different hashes")
}
}
@@ -220,8 +189,590 @@ func TestKiroHashFormat(t *testing.T) {
}
for _, c := range fp.KiroHash {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
t.Errorf("invalid hex character in KiroHash: %c", c)
}
}
}
func TestGlobalFingerprintManager(t *testing.T) {
fm1 := GlobalFingerprintManager()
fm2 := GlobalFingerprintManager()
if fm1 == nil {
t.Fatal("expected non-nil GlobalFingerprintManager")
}
if fm1 != fm2 {
t.Error("expected GlobalFingerprintManager to return same instance")
}
}
func TestSetOIDCHeaders(t *testing.T) {
req, _ := http.NewRequest("GET", "http://example.com", nil)
SetOIDCHeaders(req)
if req.Header.Get("Content-Type") != "application/json" {
t.Error("expected Content-Type header to be set")
}
amzUA := req.Header.Get("x-amz-user-agent")
if amzUA == "" {
t.Error("expected x-amz-user-agent header to be set")
}
if !strings.Contains(amzUA, "aws-sdk-js/") {
t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA)
}
if !strings.Contains(amzUA, "KiroIDE") {
t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA)
}
ua := req.Header.Get("User-Agent")
if ua == "" {
t.Error("expected User-Agent header to be set")
}
if !strings.Contains(ua, "api/sso-oidc") {
t.Errorf("User-Agent should contain api name: %s", ua)
}
if req.Header.Get("amz-sdk-invocation-id") == "" {
t.Error("expected amz-sdk-invocation-id header to be set")
}
if req.Header.Get("amz-sdk-request") != "attempt=1; max=4" {
t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request"))
}
}
func TestBuildURL(t *testing.T) {
tests := []struct {
name string
endpoint string
path string
queryParams map[string]string
want string
wantContains []string
}{
{
name: "no query params",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: nil,
want: "https://api.example.com/getUsageLimits",
},
{
name: "empty query params",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: map[string]string{},
want: "https://api.example.com/getUsageLimits",
},
{
name: "single query param",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: map[string]string{
"origin": "AI_EDITOR",
},
want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR",
},
{
name: "multiple query params",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: map[string]string{
"origin": "AI_EDITOR",
"resourceType": "AGENTIC_REQUEST",
"profileArn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEF",
},
wantContains: []string{
"https://api.example.com/getUsageLimits?",
"origin=AI_EDITOR",
"profileArn=arn%3Aaws%3Acodewhisperer%3Aus-east-1%3A123456789012%3Aprofile%2FABCDEF",
"resourceType=AGENTIC_REQUEST",
},
},
{
name: "omit empty params",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: map[string]string{
"origin": "AI_EDITOR",
"profileArn": "",
},
want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := buildURL(tt.endpoint, tt.path, tt.queryParams)
if tt.want != "" {
if got != tt.want {
t.Errorf("buildURL() = %v, want %v", got, tt.want)
}
}
if tt.wantContains != nil {
for _, substr := range tt.wantContains {
if !strings.Contains(got, substr) {
t.Errorf("buildURL() = %v, want to contain %v", got, substr)
}
}
}
})
}
}
func TestBuildUserAgentFormat(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
ua := fp.BuildUserAgent()
requiredParts := []string{
"aws-sdk-js/",
"ua/2.1",
"os/",
"lang/js",
"md/nodejs#",
"api/codewhispererstreaming#",
"m/E",
"KiroIDE-",
}
for _, part := range requiredParts {
if !strings.Contains(ua, part) {
t.Errorf("User-Agent missing required part %q: %s", part, ua)
}
}
}
func TestBuildAmzUserAgentFormat(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
amzUA := fp.BuildAmzUserAgent()
requiredParts := []string{
"aws-sdk-js/",
"KiroIDE-",
}
for _, part := range requiredParts {
if !strings.Contains(amzUA, part) {
t.Errorf("X-Amz-User-Agent missing required part %q: %s", part, amzUA)
}
}
// Amz-User-Agent should be shorter than User-Agent
ua := fp.BuildUserAgent()
if len(amzUA) >= len(ua) {
t.Error("X-Amz-User-Agent should be shorter than User-Agent")
}
}
func TestSetRuntimeHeaders(t *testing.T) {
req, _ := http.NewRequest("GET", "http://example.com", nil)
accessToken := "test-access-token-1234567890"
clientID := "test-client-id-12345"
accountKey := GenerateAccountKey(clientID)
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
machineID := fp.KiroHash
setRuntimeHeaders(req, accessToken, accountKey)
// Check Authorization header
if req.Header.Get("Authorization") != "Bearer "+accessToken {
t.Errorf("expected Authorization header 'Bearer %s', got '%s'", accessToken, req.Header.Get("Authorization"))
}
// Check x-amz-user-agent header
amzUA := req.Header.Get("x-amz-user-agent")
if amzUA == "" {
t.Error("expected x-amz-user-agent header to be set")
}
if !strings.Contains(amzUA, "aws-sdk-js/") {
t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA)
}
if !strings.Contains(amzUA, "KiroIDE-") {
t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA)
}
if !strings.Contains(amzUA, machineID) {
t.Errorf("x-amz-user-agent should contain machineID: %s", amzUA)
}
// Check User-Agent header
ua := req.Header.Get("User-Agent")
if ua == "" {
t.Error("expected User-Agent header to be set")
}
if !strings.Contains(ua, "api/codewhispererruntime#") {
t.Errorf("User-Agent should contain api/codewhispererruntime: %s", ua)
}
if !strings.Contains(ua, "m/N,E") {
t.Errorf("User-Agent should contain m/N,E: %s", ua)
}
// Check amz-sdk-invocation-id (should be a UUID)
invocationID := req.Header.Get("amz-sdk-invocation-id")
if invocationID == "" {
t.Error("expected amz-sdk-invocation-id header to be set")
}
if len(invocationID) != 36 {
t.Errorf("expected amz-sdk-invocation-id to be UUID (36 chars), got %d", len(invocationID))
}
// Check amz-sdk-request
if req.Header.Get("amz-sdk-request") != "attempt=1; max=1" {
t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request"))
}
}
func TestSDKVersionsAreValid(t *testing.T) {
// Verify all OIDC SDK versions match expected format (3.xxx.x)
for _, v := range oidcSDKVersions {
if !strings.HasPrefix(v, "3.") {
t.Errorf("OIDC SDK version should start with 3.: %s", v)
}
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("OIDC SDK version should have 3 parts: %s", v)
}
}
for _, v := range runtimeSDKVersions {
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("Runtime SDK version should have 3 parts: %s", v)
}
}
for _, v := range streamingSDKVersions {
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("Streaming SDK version should have 3 parts: %s", v)
}
}
}
func TestKiroVersionsAreValid(t *testing.T) {
// Verify all Kiro versions match expected format (0.x.xxx)
for _, v := range kiroVersions {
if !strings.HasPrefix(v, "0.") {
t.Errorf("Kiro version should start with 0.: %s", v)
}
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("Kiro version should have 3 parts: %s", v)
}
}
}
func TestNodeVersionsAreValid(t *testing.T) {
// Verify all Node versions match expected format (xx.xx.x)
for _, v := range nodeVersions {
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("Node version should have 3 parts: %s", v)
}
// Should be Node 20.x or 22.x
if !strings.HasPrefix(v, "20.") && !strings.HasPrefix(v, "22.") {
t.Errorf("Node version should be 20.x or 22.x LTS: %s", v)
}
}
}
func TestFingerprintManager_SetConfig(t *testing.T) {
fm := NewFingerprintManager()
// Without config, should generate random fingerprint
fp1 := fm.GetFingerprint("token1")
if fp1 == nil {
t.Fatal("expected non-nil fingerprint")
}
// Set config with all fields
cfg := &FingerprintConfig{
OIDCSDKVersion: "3.999.0",
RuntimeSDKVersion: "9.9.9",
StreamingSDKVersion: "8.8.8",
OSType: "darwin",
OSVersion: "99.0.0",
NodeVersion: "99.99.99",
KiroVersion: "9.9.999",
KiroHash: "customhash123",
}
fm.SetConfig(cfg)
// After setting config, should use config values
fp2 := fm.GetFingerprint("token2")
if fp2.OIDCSDKVersion != "3.999.0" {
t.Errorf("expected OIDCSDKVersion '3.999.0', got '%s'", fp2.OIDCSDKVersion)
}
if fp2.RuntimeSDKVersion != "9.9.9" {
t.Errorf("expected RuntimeSDKVersion '9.9.9', got '%s'", fp2.RuntimeSDKVersion)
}
if fp2.StreamingSDKVersion != "8.8.8" {
t.Errorf("expected StreamingSDKVersion '8.8.8', got '%s'", fp2.StreamingSDKVersion)
}
if fp2.OSType != "darwin" {
t.Errorf("expected OSType 'darwin', got '%s'", fp2.OSType)
}
if fp2.OSVersion != "99.0.0" {
t.Errorf("expected OSVersion '99.0.0', got '%s'", fp2.OSVersion)
}
if fp2.NodeVersion != "99.99.99" {
t.Errorf("expected NodeVersion '99.99.99', got '%s'", fp2.NodeVersion)
}
if fp2.KiroVersion != "9.9.999" {
t.Errorf("expected KiroVersion '9.9.999', got '%s'", fp2.KiroVersion)
}
if fp2.KiroHash != "customhash123" {
t.Errorf("expected KiroHash 'customhash123', got '%s'", fp2.KiroHash)
}
}
func TestFingerprintManager_SetConfig_PartialFields(t *testing.T) {
fm := NewFingerprintManager()
// Set config with only some fields
cfg := &FingerprintConfig{
KiroVersion: "1.2.345",
KiroHash: "myhash",
// Other fields empty - should use random
}
fm.SetConfig(cfg)
fp := fm.GetFingerprint("token1")
// Configured fields should use config values
if fp.KiroVersion != "1.2.345" {
t.Errorf("expected KiroVersion '1.2.345', got '%s'", fp.KiroVersion)
}
if fp.KiroHash != "myhash" {
t.Errorf("expected KiroHash 'myhash', got '%s'", fp.KiroHash)
}
// Empty fields should be randomly selected (non-empty)
if fp.OIDCSDKVersion == "" {
t.Error("expected non-empty OIDCSDKVersion")
}
if fp.OSType == "" {
t.Error("expected non-empty OSType")
}
if fp.NodeVersion == "" {
t.Error("expected non-empty NodeVersion")
}
}
func TestFingerprintManager_SetConfig_ClearsCache(t *testing.T) {
fm := NewFingerprintManager()
// Get fingerprint before config
fp1 := fm.GetFingerprint("token1")
originalHash := fp1.KiroHash
// Set config
cfg := &FingerprintConfig{
KiroHash: "newcustomhash",
}
fm.SetConfig(cfg)
// Same token should now return different fingerprint (cache cleared)
fp2 := fm.GetFingerprint("token1")
if fp2.KiroHash == originalHash {
t.Error("expected cache to be cleared after SetConfig")
}
if fp2.KiroHash != "newcustomhash" {
t.Errorf("expected KiroHash 'newcustomhash', got '%s'", fp2.KiroHash)
}
}
func TestGenerateAccountKey(t *testing.T) {
tests := []struct {
name string
seed string
check func(t *testing.T, result string)
}{
{
name: "Empty seed",
seed: "",
check: func(t *testing.T, result string) {
if result == "" {
t.Error("expected non-empty result for empty seed")
}
if len(result) != 16 {
t.Errorf("expected 16 char hex string, got %d chars", len(result))
}
},
},
{
name: "Simple seed",
seed: "test-client-id",
check: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char hex string, got %d chars", len(result))
}
// Verify it's valid hex
for _, c := range result {
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
t.Errorf("invalid hex character: %c", c)
}
}
},
},
{
name: "Same seed produces same result",
seed: "deterministic-seed",
check: func(t *testing.T, result string) {
result2 := GenerateAccountKey("deterministic-seed")
if result != result2 {
t.Errorf("same seed should produce same result: %s vs %s", result, result2)
}
},
},
{
name: "Different seeds produce different results",
seed: "seed-one",
check: func(t *testing.T, result string) {
result2 := GenerateAccountKey("seed-two")
if result == result2 {
t.Errorf("different seeds should produce different results: %s vs %s", result, result2)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenerateAccountKey(tt.seed)
tt.check(t, result)
})
}
}
func TestGetAccountKey(t *testing.T) {
tests := []struct {
name string
clientID string
refreshToken string
check func(t *testing.T, result string)
}{
{
name: "Priority 1: clientID when both provided",
clientID: "client-id-123",
refreshToken: "refresh-token-456",
check: func(t *testing.T, result string) {
expected := GenerateAccountKey("client-id-123")
if result != expected {
t.Errorf("expected clientID-based key %s, got %s", expected, result)
}
},
},
{
name: "Priority 2: refreshToken when clientID is empty",
clientID: "",
refreshToken: "refresh-token-789",
check: func(t *testing.T, result string) {
expected := GenerateAccountKey("refresh-token-789")
if result != expected {
t.Errorf("expected refreshToken-based key %s, got %s", expected, result)
}
},
},
{
name: "Priority 3: random when both empty",
clientID: "",
refreshToken: "",
check: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
// Should be different each time (random UUID)
result2 := GetAccountKey("", "")
if result == result2 {
t.Log("warning: random keys are the same (possible but unlikely)")
}
},
},
{
name: "clientID only",
clientID: "solo-client-id",
refreshToken: "",
check: func(t *testing.T, result string) {
expected := GenerateAccountKey("solo-client-id")
if result != expected {
t.Errorf("expected clientID-based key %s, got %s", expected, result)
}
},
},
{
name: "refreshToken only",
clientID: "",
refreshToken: "solo-refresh-token",
check: func(t *testing.T, result string) {
expected := GenerateAccountKey("solo-refresh-token")
if result != expected {
t.Errorf("expected refreshToken-based key %s, got %s", expected, result)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetAccountKey(tt.clientID, tt.refreshToken)
tt.check(t, result)
})
}
}
func TestGetAccountKey_Deterministic(t *testing.T) {
// Verify that GetAccountKey produces deterministic results for same inputs
clientID := "test-client-id-abc"
refreshToken := "test-refresh-token-xyz"
// Call multiple times with same inputs
results := make([]string, 10)
for i := range 10 {
results[i] = GetAccountKey(clientID, refreshToken)
}
// All results should be identical
for i := 1; i < 10; i++ {
if results[i] != results[0] {
t.Errorf("GetAccountKey should be deterministic: got %s and %s", results[0], results[i])
}
}
}
func TestFingerprintDeterministic(t *testing.T) {
// Verify that fingerprints are deterministic based on accountKey
fm := NewFingerprintManager()
accountKey := GenerateAccountKey("test-client-id")
// Get fingerprint multiple times
fp1 := fm.GetFingerprint(accountKey)
fp2 := fm.GetFingerprint(accountKey)
// Should be the same pointer (cached)
if fp1 != fp2 {
t.Error("expected same fingerprint pointer for same key")
}
// Create new manager and verify same values
fm2 := NewFingerprintManager()
fp3 := fm2.GetFingerprint(accountKey)
// Values should be identical (deterministic generation)
if fp1.KiroHash != fp3.KiroHash {
t.Errorf("KiroHash should be deterministic: %s vs %s", fp1.KiroHash, fp3.KiroHash)
}
if fp1.OSType != fp3.OSType {
t.Errorf("OSType should be deterministic: %s vs %s", fp1.OSType, fp3.OSType)
}
if fp1.OSVersion != fp3.OSVersion {
t.Errorf("OSVersion should be deterministic: %s vs %s", fp1.OSVersion, fp3.OSVersion)
}
if fp1.KiroVersion != fp3.KiroVersion {
t.Errorf("KiroVersion should be deterministic: %s vs %s", fp1.KiroVersion, fp3.KiroVersion)
}
if fp1.NodeVersion != fp3.NodeVersion {
t.Errorf("NodeVersion should be deterministic: %s vs %s", fp1.NodeVersion, fp3.NodeVersion)
}
}

View File

@@ -23,10 +23,10 @@ import (
const (
// Kiro auth endpoint
kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
// Default callback port
defaultCallbackPort = 9876
// Auth timeout
authTimeout = 10 * time.Minute
)
@@ -41,8 +41,10 @@ type KiroTokenResponse struct {
// KiroOAuth handles the OAuth flow for Kiro authentication.
type KiroOAuth struct {
httpClient *http.Client
cfg *config.Config
httpClient *http.Client
cfg *config.Config
machineID string
kiroVersion string
}
// NewKiroOAuth creates a new Kiro OAuth handler.
@@ -51,9 +53,12 @@ func NewKiroOAuth(cfg *config.Config) *KiroOAuth {
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
fp := GlobalFingerprintManager().GetFingerprint("login")
return &KiroOAuth{
httpClient: client,
cfg: cfg,
httpClient: client,
cfg: cfg,
machineID: fp.KiroHash,
kiroVersion: fp.KiroVersion,
}
}
@@ -190,7 +195,8 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api")
req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID))
req.Header.Set("Accept", "application/json, text/plain, */*")
resp, err := o.httpClient.Do(req)
if err != nil {
@@ -256,11 +262,8 @@ func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToke
}
req.Header.Set("Content-Type", "application/json")
// Use KiroIDE-style User-Agent to match official Kiro IDE behavior
// This helps avoid 403 errors from server-side User-Agent validation
userAgent := buildKiroUserAgent(tokenKey)
req.Header.Set("User-Agent", userAgent)
req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID))
req.Header.Set("Accept", "application/json, text/plain, */*")
resp, err := o.httpClient.Do(req)
if err != nil {
@@ -301,19 +304,6 @@ func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToke
}, nil
}
// buildKiroUserAgent builds a KiroIDE-style User-Agent string.
// If tokenKey is provided, uses fingerprint manager for consistent fingerprint.
// Otherwise generates a simple KiroIDE User-Agent.
func buildKiroUserAgent(tokenKey string) string {
if tokenKey != "" {
fm := NewFingerprintManager()
fp := fm.GetFingerprint(tokenKey)
return fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash[:16])
}
// Default KiroIDE User-Agent matching kiro-openai-gateway format
return "KiroIDE-0.7.45-cli-proxy-api"
}
// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
// This uses a custom protocol handler (kiro://) to receive the callback.
func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {

View File

@@ -35,35 +35,35 @@ const (
)
type webAuthSession struct {
stateID string
deviceCode string
userCode string
authURL string
verificationURI string
expiresIn int
interval int
status authSessionStatus
startedAt time.Time
completedAt time.Time
expiresAt time.Time
error string
tokenData *KiroTokenData
ssoClient *SSOOIDCClient
clientID string
clientSecret string
region string
cancelFunc context.CancelFunc
authMethod string // "google", "github", "builder-id", "idc"
startURL string // Used for IDC
codeVerifier string // Used for social auth PKCE
codeChallenge string // Used for social auth PKCE
stateID string
deviceCode string
userCode string
authURL string
verificationURI string
expiresIn int
interval int
status authSessionStatus
startedAt time.Time
completedAt time.Time
expiresAt time.Time
error string
tokenData *KiroTokenData
ssoClient *SSOOIDCClient
clientID string
clientSecret string
region string
cancelFunc context.CancelFunc
authMethod string // "google", "github", "builder-id", "idc"
startURL string // Used for IDC
codeVerifier string // Used for social auth PKCE
codeChallenge string // Used for social auth PKCE
}
type OAuthWebHandler struct {
cfg *config.Config
sessions map[string]*webAuthSession
mu sync.RWMutex
onTokenObtained func(*KiroTokenData)
cfg *config.Config
sessions map[string]*webAuthSession
mu sync.RWMutex
onTokenObtained func(*KiroTokenData)
}
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
@@ -104,7 +104,7 @@ func (h *OAuthWebHandler) handleSelect(c *gin.Context) {
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
method := c.Query("method")
if method == "" {
c.Redirect(http.StatusFound, "/v0/oauth/kiro")
return
@@ -138,7 +138,7 @@ func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) {
}
socialClient := NewSocialAuthClient(h.cfg)
var provider string
if method == "google" {
provider = string(ProviderGoogle)
@@ -373,22 +373,28 @@ func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSess
}
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken)
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken)
// Fetch profileArn for IDC
var profileArn string
if session.authMethod == "idc" {
profileArn = session.ssoClient.FetchProfileArn(ctx, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken)
}
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken)
tokenData := &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: profileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: session.authMethod,
Provider: "AWS",
ClientID: session.clientID,
ClientSecret: session.clientSecret,
Email: email,
Region: session.region,
StartURL: session.startURL,
}
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: profileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: session.authMethod,
Provider: "AWS",
ClientID: session.clientID,
ClientSecret: session.clientSecret,
Email: email,
Region: session.region,
StartURL: session.startURL,
}
h.mu.Lock()
session.status = statusSuccess
@@ -442,7 +448,7 @@ func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
fileName := GenerateTokenFileName(tokenData)
authFilePath := filepath.Join(authDir, fileName)
// Convert to storage format and save
storage := &KiroTokenStorage{
Type: "kiro",
@@ -459,12 +465,12 @@ func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
StartURL: tokenData.StartURL,
Email: tokenData.Email,
}
if err := storage.SaveTokenToFile(authFilePath); err != nil {
log.Errorf("OAuth Web: failed to save token to file: %v", err)
return
}
log.Infof("OAuth Web: token saved to %s", authFilePath)
}

View File

@@ -10,14 +10,14 @@ import (
log "github.com/sirupsen/logrus"
)
// RefreshManager 是后台刷新器的单例管理器
// RefreshManager is a singleton manager for background token refreshing.
type RefreshManager struct {
mu sync.Mutex
refresher *BackgroundRefresher
ctx context.Context
cancel context.CancelFunc
started bool
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData)
}
var (
@@ -25,7 +25,7 @@ var (
managerOnce sync.Once
)
// GetRefreshManager 获取全局刷新管理器实例
// GetRefreshManager returns the global RefreshManager singleton.
func GetRefreshManager() *RefreshManager {
managerOnce.Do(func() {
globalRefreshManager = &RefreshManager{}
@@ -33,9 +33,7 @@ func GetRefreshManager() *RefreshManager {
return globalRefreshManager
}
// Initialize 初始化后台刷新器
// baseDir: token 文件所在的目录
// cfg: 应用配置
// Initialize sets up the background refresher.
func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
m.mu.Lock()
defer m.mu.Unlock()
@@ -58,18 +56,16 @@ func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
baseDir = resolvedBaseDir
}
// 创建 token 存储库
repo := NewFileTokenRepository(baseDir)
// 创建后台刷新器,配置参数
opts := []RefresherOption{
WithInterval(time.Minute), // 每分钟检查一次
WithBatchSize(50), // 每批最多处理 50 个 token
WithConcurrency(10), // 最多 10 个并发刷新
WithConfig(cfg), // 设置 OAuth 和 SSO 客户端
WithInterval(time.Minute),
WithBatchSize(50),
WithConcurrency(10),
WithConfig(cfg),
}
// 如果已设置回调,传递给 BackgroundRefresher
// Pass callback to BackgroundRefresher if already set
if m.onTokenRefreshed != nil {
opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed))
}
@@ -80,7 +76,7 @@ func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
return nil
}
// Start 启动后台刷新
// Start begins background token refreshing.
func (m *RefreshManager) Start() {
m.mu.Lock()
defer m.mu.Unlock()
@@ -102,7 +98,7 @@ func (m *RefreshManager) Start() {
log.Info("refresh manager: background refresh started")
}
// Stop 停止后台刷新
// Stop halts background token refreshing.
func (m *RefreshManager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
@@ -123,14 +119,14 @@ func (m *RefreshManager) Stop() {
log.Info("refresh manager: background refresh stopped")
}
// IsRunning 检查后台刷新是否正在运行
// IsRunning reports whether background refreshing is active.
func (m *RefreshManager) IsRunning() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.started
}
// UpdateBaseDir 更新 token 目录(用于运行时配置更改)
// UpdateBaseDir changes the token directory at runtime.
func (m *RefreshManager) UpdateBaseDir(baseDir string) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -143,16 +139,15 @@ func (m *RefreshManager) UpdateBaseDir(baseDir string) {
}
}
// SetOnTokenRefreshed 设置 token 刷新成功后的回调函数
// 可以在任何时候调用,支持运行时更新回调
// callback: 回调函数,接收 tokenID文件名和新的 token 数据
// SetOnTokenRefreshed registers a callback invoked after a successful token refresh.
// Can be called at any time; supports runtime callback updates.
func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) {
m.mu.Lock()
defer m.mu.Unlock()
m.onTokenRefreshed = callback
// 如果 refresher 已经创建,使用并发安全的方式更新它的回调
// Update the refresher's callback in a thread-safe manner if already created
if m.refresher != nil {
m.refresher.callbackMu.Lock()
m.refresher.onTokenRefreshed = callback
@@ -162,8 +157,11 @@ func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, token
log.Debug("refresh manager: token refresh callback registered")
}
// InitializeAndStart 初始化并启动后台刷新(便捷方法)
// InitializeAndStart initializes and starts background refreshing (convenience method).
func InitializeAndStart(baseDir string, cfg *config.Config) {
// Initialize global fingerprint config
initGlobalFingerprintConfig(cfg)
manager := GetRefreshManager()
if err := manager.Initialize(baseDir, cfg); err != nil {
log.Errorf("refresh manager: initialization failed: %v", err)
@@ -172,7 +170,31 @@ func InitializeAndStart(baseDir string, cfg *config.Config) {
manager.Start()
}
// StopGlobalRefreshManager 停止全局刷新管理器
// initGlobalFingerprintConfig loads fingerprint settings from application config.
func initGlobalFingerprintConfig(cfg *config.Config) {
if cfg == nil || cfg.KiroFingerprint == nil {
return
}
fpCfg := cfg.KiroFingerprint
SetGlobalFingerprintConfig(&FingerprintConfig{
OIDCSDKVersion: fpCfg.OIDCSDKVersion,
RuntimeSDKVersion: fpCfg.RuntimeSDKVersion,
StreamingSDKVersion: fpCfg.StreamingSDKVersion,
OSType: fpCfg.OSType,
OSVersion: fpCfg.OSVersion,
NodeVersion: fpCfg.NodeVersion,
KiroVersion: fpCfg.KiroVersion,
KiroHash: fpCfg.KiroHash,
})
log.Debug("kiro: global fingerprint config loaded")
}
// InitFingerprintConfig initializes the global fingerprint config from application config.
func InitFingerprintConfig(cfg *config.Config) {
initGlobalFingerprintConfig(cfg)
}
// StopGlobalRefreshManager stops the global refresh manager.
func StopGlobalRefreshManager() {
if globalRefreshManager != nil {
globalRefreshManager.Stop()

View File

@@ -84,6 +84,8 @@ type SocialAuthClient struct {
httpClient *http.Client
cfg *config.Config
protocolHandler *ProtocolHandler
machineID string
kiroVersion string
}
// NewSocialAuthClient creates a new social auth client.
@@ -92,10 +94,13 @@ func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
fp := GlobalFingerprintManager().GetFingerprint("login")
return &SocialAuthClient{
httpClient: client,
cfg: cfg,
protocolHandler: NewProtocolHandler(),
machineID: fp.KiroHash,
kiroVersion: fp.KiroVersion,
}
}
@@ -229,7 +234,8 @@ func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequ
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api")
httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID))
httpReq.Header.Set("Accept", "application/json, text/plain, */*")
resp, err := c.httpClient.Do(httpReq)
if err != nil {
@@ -269,7 +275,8 @@ func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID))
httpReq.Header.Set("Accept", "application/json, text/plain, */*")
resp, err := c.httpClient.Do(httpReq)
if err != nil {
@@ -466,7 +473,7 @@ func forceDefaultProtocolHandler() {
if runtime.GOOS != "linux" {
return // Non-Linux platforms use different handler mechanisms
}
// Set our handler as default using xdg-mime
cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
if err := cmd.Run(); err != nil {

View File

@@ -14,6 +14,7 @@ import (
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
@@ -40,21 +41,13 @@ const (
// Authorization code flow callback
authCodeCallbackPath = "/oauth/callback"
authCodeCallbackPort = 19877
// User-Agent to match official Kiro IDE
kiroUserAgent = "KiroIDE"
// IDC token refresh headers (matching Kiro IDE behavior)
idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE"
)
// Sentinel errors for OIDC token polling
var (
ErrAuthorizationPending = errors.New("authorization_pending")
ErrSlowDown = errors.New("slow_down")
)
// SSOOIDCClient handles AWS SSO OIDC authentication.
type SSOOIDCClient struct {
httpClient *http.Client
cfg *config.Config
@@ -74,10 +67,10 @@ func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient {
// RegisterClientResponse from AWS SSO OIDC.
type RegisterClientResponse struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
ClientIDIssuedAt int64 `json:"clientIdIssuedAt"`
ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"`
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
ClientIDIssuedAt int64 `json:"clientIdIssuedAt"`
ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"`
}
// StartDeviceAuthResponse from AWS SSO OIDC.
@@ -174,8 +167,7 @@ func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region str
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -220,8 +212,7 @@ func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, cli
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -267,8 +258,7 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -311,8 +301,11 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli
return &result, nil
}
// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region.
// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific OIDC region.
func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) {
if region == "" {
region = defaultIDCRegion
}
endpoint := getOIDCEndpoint(region)
payload := map[string]string{
@@ -331,18 +324,7 @@ func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, cl
if err != nil {
return nil, err
}
// Set headers matching kiro2api's IDC token refresh
// These headers are required for successful IDC token refresh
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region))
req.Header.Set("Connection", "keep-alive")
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
req.Header.Set("Accept", "*/*")
req.Header.Set("Accept-Language", "*")
req.Header.Set("sec-fetch-mode", "cors")
req.Header.Set("User-Agent", "node")
req.Header.Set("Accept-Encoding", "br, gzip, deflate")
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -469,10 +451,10 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin
// Step 5: Get profile ARN from CodeWhisperer API
fmt.Println("Fetching profile information...")
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
profileArn := c.FetchProfileArn(ctx, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
// Fetch user email
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
if email != "" {
fmt.Printf(" Logged in as: %s\n", email)
}
@@ -502,12 +484,36 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin
return nil, fmt.Errorf("authorization timed out")
}
// IDCLoginOptions holds optional parameters for IDC login.
type IDCLoginOptions struct {
StartURL string // Pre-configured start URL (skips prompt if set)
Region string // OIDC region for login and token refresh (defaults to us-east-1)
UseDeviceCode bool // Use Device Code flow instead of Auth Code flow
}
// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login.
func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) {
// Options can be provided to pre-configure IDC parameters (startURL, region).
// If StartURL is provided in opts, IDC flow is used directly without prompting.
func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context, opts *IDCLoginOptions) (*KiroTokenData, error) {
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
fmt.Println("║ Kiro Authentication (AWS) ║")
fmt.Println("╚══════════════════════════════════════════════════════════╝")
// If IDC options with StartURL are provided, skip method selection and use IDC directly
if opts != nil && opts.StartURL != "" {
region := opts.Region
if region == "" {
region = defaultIDCRegion
}
fmt.Printf("\n Using IDC with Start URL: %s\n", opts.StartURL)
fmt.Printf(" Region: %s\n", region)
if opts.UseDeviceCode {
return c.LoginWithIDCAndOptions(ctx, opts.StartURL, region)
}
return c.LoginWithIDCAuthCode(ctx, opts.StartURL, region)
}
// Prompt for login method
options := []string{
"Use with Builder ID (personal AWS account)",
@@ -520,15 +526,41 @@ func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroToke
return c.LoginWithBuilderID(ctx)
}
// IDC flow - prompt for start URL and region
fmt.Println()
startURL := promptInput("? Enter Start URL", "")
if startURL == "" {
return nil, fmt.Errorf("start URL is required for IDC login")
// IDC flow - use pre-configured values or prompt
var startURL, region string
if opts != nil {
startURL = opts.StartURL
region = opts.Region
}
region := promptInput("? Enter Region", defaultIDCRegion)
fmt.Println()
// Use pre-configured startURL or prompt
if startURL == "" {
startURL = promptInput("? Enter Start URL", "")
if startURL == "" {
return nil, fmt.Errorf("start URL is required for IDC login")
}
} else {
fmt.Printf(" Using pre-configured Start URL: %s\n", startURL)
}
// Use pre-configured region or prompt
if region == "" {
region = promptInput("? Enter Region", defaultIDCRegion)
} else {
fmt.Printf(" Using pre-configured Region: %s\n", region)
}
if opts != nil && opts.UseDeviceCode {
return c.LoginWithIDCAndOptions(ctx, startURL, region)
}
return c.LoginWithIDCAuthCode(ctx, startURL, region)
}
// LoginWithIDCAndOptions performs IDC login with the specified region.
func (c *SSOOIDCClient) LoginWithIDCAndOptions(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
return c.LoginWithIDC(ctx, startURL, region)
}
@@ -550,8 +582,7 @@ func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResp
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -594,8 +625,7 @@ func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID,
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -639,8 +669,7 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -702,13 +731,7 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
if err != nil {
return nil, err
}
// Set headers matching Kiro IDE behavior for better compatibility
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Host", "oidc.us-east-1.amazonaws.com")
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
req.Header.Set("User-Agent", "node")
req.Header.Set("Accept", "*/*")
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -835,12 +858,8 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
log.Debugf("Failed to close browser: %v", err)
}
// Step 5: Get profile ARN from CodeWhisperer API
fmt.Println("Fetching profile information...")
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
if email != "" {
fmt.Printf(" Logged in as: %s\n", email)
}
@@ -850,7 +869,7 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: profileArn,
ProfileArn: "", // Builder ID has no profile
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "builder-id",
Provider: "AWS",
@@ -859,15 +878,15 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
Email: email,
Region: defaultIDCRegion,
}, nil
}
}
}
}
// Close browser on timeout for better UX
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser on timeout: %v", err)
}
return nil, fmt.Errorf("authorization timed out")
}
// Close browser on timeout for better UX
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser on timeout: %v", err)
}
return nil, fmt.Errorf("authorization timed out")
}
// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
// Falls back to JWT parsing if userinfo fails.
@@ -931,20 +950,64 @@ func (c *SSOOIDCClient) tryUserInfoEndpoint(ctx context.Context, accessToken str
return ""
}
// fetchProfileArn retrieves the profile ARN from CodeWhisperer API.
// This is needed for file naming since AWS SSO OIDC doesn't return profile info.
func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string {
// Try ListProfiles API first
profileArn := c.tryListProfiles(ctx, accessToken)
// FetchProfileArn fetches the profile ARN from ListAvailableProfiles API.
// This is used to get profileArn for imported accounts that may not have it.
func (c *SSOOIDCClient) FetchProfileArn(ctx context.Context, accessToken, clientID, refreshToken string) string {
profileArn := c.tryListAvailableProfiles(ctx, accessToken, clientID, refreshToken)
if profileArn != "" {
return profileArn
}
// Fallback: Try ListAvailableCustomizations
return c.tryListCustomizations(ctx, accessToken)
return c.tryListProfilesLegacy(ctx, accessToken)
}
func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) string {
func (c *SSOOIDCClient) tryListAvailableProfiles(ctx context.Context, accessToken, clientID, refreshToken string) string {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, GetKiroAPIEndpoint("")+"/ListAvailableProfiles", strings.NewReader("{}"))
if err != nil {
return ""
}
req.Header.Set("Content-Type", "application/json")
accountKey := GetAccountKey(clientID, refreshToken)
setRuntimeHeaders(req, accessToken, accountKey)
resp, err := c.httpClient.Do(req)
if err != nil {
log.Debugf("ListAvailableProfiles request failed: %v", err)
return ""
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
log.Debugf("ListAvailableProfiles failed (status %d): %s", resp.StatusCode, string(respBody))
return ""
}
log.Debugf("ListAvailableProfiles response: %s", string(respBody))
var result struct {
Profiles []struct {
Arn string `json:"arn"`
ProfileName string `json:"profileName"`
} `json:"profiles"`
NextToken *string `json:"nextToken"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
log.Debugf("ListAvailableProfiles parse error: %v", err)
return ""
}
if len(result.Profiles) > 0 {
log.Debugf("Found profile: %s (%s)", result.Profiles[0].ProfileName, result.Profiles[0].Arn)
return result.Profiles[0].Arn
}
return ""
}
func (c *SSOOIDCClient) tryListProfilesLegacy(ctx context.Context, accessToken string) string {
payload := map[string]interface{}{
"origin": "AI_EDITOR",
}
@@ -954,7 +1017,9 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string)
return ""
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
// Use the legacy CodeWhisperer endpoint for JSON-RPC style requests.
// The Q endpoint (q.{region}.amazonaws.com) does NOT support x-amz-target headers.
req, err := http.NewRequestWithContext(ctx, http.MethodPost, GetCodeWhispererLegacyEndpoint(""), strings.NewReader(string(body)))
if err != nil {
return ""
}
@@ -973,11 +1038,11 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string)
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
log.Debugf("ListProfiles failed (status %d): %s", resp.StatusCode, string(respBody))
log.Debugf("ListProfiles (legacy) failed (status %d): %s", resp.StatusCode, string(respBody))
return ""
}
log.Debugf("ListProfiles response: %s", string(respBody))
log.Debugf("ListProfiles (legacy) response: %s", string(respBody))
var result struct {
Profiles []struct {
@@ -1001,63 +1066,6 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string)
return ""
}
func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken string) string {
payload := map[string]interface{}{
"origin": "AI_EDITOR",
}
body, err := json.Marshal(payload)
if err != nil {
return ""
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
if err != nil {
return ""
}
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListAvailableCustomizations")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
log.Debugf("ListAvailableCustomizations failed (status %d): %s", resp.StatusCode, string(respBody))
return ""
}
log.Debugf("ListAvailableCustomizations response: %s", string(respBody))
var result struct {
Customizations []struct {
Arn string `json:"arn"`
} `json:"customizations"`
ProfileArn string `json:"profileArn"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return ""
}
if result.ProfileArn != "" {
return result.ProfileArn
}
if len(result.Customizations) > 0 {
return result.Customizations[0].Arn
}
return ""
}
// RegisterClientForAuthCode registers a new OIDC client for authorization code flow.
func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectURI string) (*RegisterClientResponse, error) {
payload := map[string]interface{}{
@@ -1078,8 +1086,7 @@ func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectU
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -1105,6 +1112,53 @@ func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectU
return &result, nil
}
func (c *SSOOIDCClient) RegisterClientForAuthCodeWithIDC(ctx context.Context, redirectURI, issuerUrl, region string) (*RegisterClientResponse, error) {
endpoint := getOIDCEndpoint(region)
payload := map[string]interface{}{
"clientName": "Kiro IDE",
"clientType": "public",
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
"grantTypes": []string{"authorization_code", "refresh_token"},
"redirectUris": []string{redirectURI},
"issuerUrl": issuerUrl,
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
log.Debugf("register client for auth code with IDC failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
}
var result RegisterClientResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
return &result, nil
}
// AuthCodeCallbackResult contains the result from authorization code callback.
type AuthCodeCallbackResult struct {
Code string
@@ -1128,6 +1182,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
port := listener.Addr().(*net.TCPAddr).Port
redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, authCodeCallbackPath)
resultChan := make(chan AuthCodeCallbackResult, 1)
doneChan := make(chan struct{})
server := &http.Server{
ReadHeaderTimeout: 10 * time.Second,
@@ -1147,6 +1202,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
<html><head><title>Login Failed</title></head>
<body><h1>Login Failed</h1><p>Error: %s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
resultChan <- AuthCodeCallbackResult{Error: errParam}
close(doneChan)
return
}
@@ -1156,6 +1212,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
<html><head><title>Login Failed</title></head>
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
resultChan <- AuthCodeCallbackResult{Error: "state mismatch"}
close(doneChan)
return
}
@@ -1164,6 +1221,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
<script>window.close();</script></body></html>`)
resultChan <- AuthCodeCallbackResult{Code: code, State: state}
close(doneChan)
})
server.Handler = mux
@@ -1178,7 +1236,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
select {
case <-ctx.Done():
case <-time.After(10 * time.Minute):
case <-resultChan:
case <-doneChan:
}
_ = server.Shutdown(context.Background())
}()
@@ -1227,8 +1285,54 @@ func (c *SSOOIDCClient) CreateTokenWithAuthCode(ctx context.Context, clientID, c
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
log.Debugf("create token with auth code failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
}
var result CreateTokenResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
return &result, nil
}
func (c *SSOOIDCClient) CreateTokenWithAuthCodeAndRegion(ctx context.Context, clientID, clientSecret, code, codeVerifier, redirectURI, region string) (*CreateTokenResponse, error) {
endpoint := getOIDCEndpoint(region)
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"code": code,
"codeVerifier": codeVerifier,
"redirectUri": redirectURI,
"grantType": "authorization_code",
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -1352,12 +1456,118 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo
fmt.Println("\n✓ Authentication successful!")
// Step 8: Get profile ARN
fmt.Println("Fetching profile information...")
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
if email != "" {
fmt.Printf(" Logged in as: %s\n", email)
}
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: "", // Builder ID has no profile
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "builder-id",
Provider: "AWS",
ClientID: regResp.ClientID,
ClientSecret: regResp.ClientSecret,
Email: email,
Region: defaultIDCRegion,
}, nil
}
}
func (c *SSOOIDCClient) LoginWithIDCAuthCode(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
fmt.Println("║ Kiro Authentication (AWS IDC - Auth Code) ║")
fmt.Println("╚══════════════════════════════════════════════════════════╝")
if region == "" {
region = defaultIDCRegion
}
codeVerifier, codeChallenge, err := generatePKCEForAuthCode()
if err != nil {
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
}
state, err := generateStateForAuthCode()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
}
fmt.Println("\nStarting callback server...")
redirectURI, resultChan, err := c.startAuthCodeCallbackServer(ctx, state)
if err != nil {
return nil, fmt.Errorf("failed to start callback server: %w", err)
}
log.Debugf("Callback server started, redirect URI: %s", redirectURI)
fmt.Println("Registering client...")
regResp, err := c.RegisterClientForAuthCodeWithIDC(ctx, redirectURI, startURL, region)
if err != nil {
return nil, fmt.Errorf("failed to register client: %w", err)
}
log.Debugf("Client registered: %s", regResp.ClientID)
endpoint := getOIDCEndpoint(region)
scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations,codewhisperer:transformations,codewhisperer:taskassist"
authURL := buildAuthorizationURL(endpoint, regResp.ClientID, redirectURI, scopes, state, codeChallenge)
fmt.Println("\n════════════════════════════════════════════════════════════")
fmt.Println(" Opening browser for authentication...")
fmt.Println("════════════════════════════════════════════════════════════")
fmt.Printf("\n URL: %s\n\n", authURL)
if c.cfg != nil {
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
} else {
browser.SetIncognitoMode(true)
}
if err := browser.OpenURL(authURL); err != nil {
log.Warnf("Could not open browser automatically: %v", err)
fmt.Println(" ⚠ Could not open browser automatically.")
fmt.Println(" Please open the URL above in your browser manually.")
} else {
fmt.Println(" (Browser opened automatically)")
}
fmt.Println("\n Waiting for authorization callback...")
select {
case <-ctx.Done():
browser.CloseBrowser()
return nil, ctx.Err()
case <-time.After(10 * time.Minute):
browser.CloseBrowser()
return nil, fmt.Errorf("authorization timed out")
case result := <-resultChan:
if result.Error != "" {
browser.CloseBrowser()
return nil, fmt.Errorf("authorization failed: %s", result.Error)
}
fmt.Println("\n✓ Authorization received!")
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser: %v", err)
}
fmt.Println("Exchanging code for tokens...")
tokenResp, err := c.CreateTokenWithAuthCodeAndRegion(ctx, regResp.ClientID, regResp.ClientSecret, result.Code, codeVerifier, redirectURI, region)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
}
fmt.Println("\n✓ Authentication successful!")
fmt.Println("Fetching profile information...")
profileArn := c.FetchProfileArn(ctx, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
if email != "" {
fmt.Printf(" Logged in as: %s\n", email)
}
@@ -1369,12 +1579,25 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo
RefreshToken: tokenResp.RefreshToken,
ProfileArn: profileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "builder-id",
AuthMethod: "idc",
Provider: "AWS",
ClientID: regResp.ClientID,
ClientSecret: regResp.ClientSecret,
Email: email,
Region: defaultIDCRegion,
StartURL: startURL,
Region: region,
}, nil
}
}
func buildAuthorizationURL(endpoint, clientID, redirectURI, scopes, state, codeChallenge string) string {
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", clientID)
params.Set("redirect_uri", redirectURI)
params.Set("scopes", scopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
return fmt.Sprintf("%s/authorize?%s", endpoint, params.Encode())
}

View File

@@ -0,0 +1,261 @@
package kiro
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
type recordingRoundTripper struct {
lastReq *http.Request
}
func (rt *recordingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
rt.lastReq = req
body := `{"nextToken":null,"profiles":[{"arn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC","profileName":"test"}]}`
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(body)),
Header: make(http.Header),
}, nil
}
func TestTryListAvailableProfiles_UsesClientIDForAccountKey(t *testing.T) {
rt := &recordingRoundTripper{}
client := &SSOOIDCClient{
httpClient: &http.Client{Transport: rt},
}
profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "client-id-123", "refresh-token-456")
if profileArn == "" {
t.Fatal("expected profileArn, got empty result")
}
accountKey := GetAccountKey("client-id-123", "refresh-token-456")
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash)
got := rt.lastReq.Header.Get("X-Amz-User-Agent")
if got != expected {
t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected)
}
}
func TestTryListAvailableProfiles_UsesRefreshTokenWhenClientIDMissing(t *testing.T) {
rt := &recordingRoundTripper{}
client := &SSOOIDCClient{
httpClient: &http.Client{Transport: rt},
}
profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "", "refresh-token-789")
if profileArn == "" {
t.Fatal("expected profileArn, got empty result")
}
accountKey := GetAccountKey("", "refresh-token-789")
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash)
got := rt.lastReq.Header.Get("X-Amz-User-Agent")
if got != expected {
t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected)
}
}
func TestRegisterClientForAuthCodeWithIDC(t *testing.T) {
var capturedReq struct {
Method string
Path string
Headers http.Header
Body map[string]interface{}
}
mockResp := RegisterClientResponse{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
ClientIDIssuedAt: 1700000000,
ClientSecretExpiresAt: 1700086400,
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedReq.Method = r.Method
capturedReq.Path = r.URL.Path
capturedReq.Headers = r.Header.Clone()
bodyBytes, _ := io.ReadAll(r.Body)
json.Unmarshal(bodyBytes, &capturedReq.Body)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(mockResp)
}))
defer ts.Close()
// Extract host to build a region that resolves to our test server.
// Override getOIDCEndpoint by passing region="" and patching the endpoint.
// Since getOIDCEndpoint builds "https://oidc.{region}.amazonaws.com", we
// instead inject the test server URL directly via a custom HTTP client transport.
client := &SSOOIDCClient{
httpClient: ts.Client(),
}
// We need to route the request to our test server. Use a transport that rewrites the URL.
client.httpClient.Transport = &rewriteTransport{
base: ts.Client().Transport,
targetURL: ts.URL,
}
resp, err := client.RegisterClientForAuthCodeWithIDC(
context.Background(),
"http://127.0.0.1:19877/oauth/callback",
"https://my-idc-instance.awsapps.com/start",
"us-east-1",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify request method and path
if capturedReq.Method != http.MethodPost {
t.Errorf("method = %q, want POST", capturedReq.Method)
}
if capturedReq.Path != "/client/register" {
t.Errorf("path = %q, want /client/register", capturedReq.Path)
}
// Verify headers
if ct := capturedReq.Headers.Get("Content-Type"); ct != "application/json" {
t.Errorf("Content-Type = %q, want application/json", ct)
}
ua := capturedReq.Headers.Get("User-Agent")
if !strings.Contains(ua, "KiroIDE") {
t.Errorf("User-Agent %q does not contain KiroIDE", ua)
}
if !strings.Contains(ua, "sso-oidc") {
t.Errorf("User-Agent %q does not contain sso-oidc", ua)
}
xua := capturedReq.Headers.Get("X-Amz-User-Agent")
if !strings.Contains(xua, "KiroIDE") {
t.Errorf("x-amz-user-agent %q does not contain KiroIDE", xua)
}
// Verify body fields
if v, _ := capturedReq.Body["clientName"].(string); v != "Kiro IDE" {
t.Errorf("clientName = %q, want %q", v, "Kiro IDE")
}
if v, _ := capturedReq.Body["clientType"].(string); v != "public" {
t.Errorf("clientType = %q, want %q", v, "public")
}
if v, _ := capturedReq.Body["issuerUrl"].(string); v != "https://my-idc-instance.awsapps.com/start" {
t.Errorf("issuerUrl = %q, want %q", v, "https://my-idc-instance.awsapps.com/start")
}
// Verify scopes array
scopesRaw, ok := capturedReq.Body["scopes"].([]interface{})
if !ok || len(scopesRaw) != 5 {
t.Fatalf("scopes: got %v, want 5-element array", capturedReq.Body["scopes"])
}
expectedScopes := []string{
"codewhisperer:completions", "codewhisperer:analysis",
"codewhisperer:conversations", "codewhisperer:transformations",
"codewhisperer:taskassist",
}
for i, s := range expectedScopes {
if scopesRaw[i].(string) != s {
t.Errorf("scopes[%d] = %q, want %q", i, scopesRaw[i], s)
}
}
// Verify grantTypes
grantTypesRaw, ok := capturedReq.Body["grantTypes"].([]interface{})
if !ok || len(grantTypesRaw) != 2 {
t.Fatalf("grantTypes: got %v, want 2-element array", capturedReq.Body["grantTypes"])
}
if grantTypesRaw[0].(string) != "authorization_code" || grantTypesRaw[1].(string) != "refresh_token" {
t.Errorf("grantTypes = %v, want [authorization_code, refresh_token]", grantTypesRaw)
}
// Verify redirectUris
redirectRaw, ok := capturedReq.Body["redirectUris"].([]interface{})
if !ok || len(redirectRaw) != 1 {
t.Fatalf("redirectUris: got %v, want 1-element array", capturedReq.Body["redirectUris"])
}
if redirectRaw[0].(string) != "http://127.0.0.1:19877/oauth/callback" {
t.Errorf("redirectUris[0] = %q, want %q", redirectRaw[0], "http://127.0.0.1:19877/oauth/callback")
}
// Verify response parsing
if resp.ClientID != "test-client-id" {
t.Errorf("ClientID = %q, want %q", resp.ClientID, "test-client-id")
}
if resp.ClientSecret != "test-client-secret" {
t.Errorf("ClientSecret = %q, want %q", resp.ClientSecret, "test-client-secret")
}
if resp.ClientIDIssuedAt != 1700000000 {
t.Errorf("ClientIDIssuedAt = %d, want %d", resp.ClientIDIssuedAt, 1700000000)
}
if resp.ClientSecretExpiresAt != 1700086400 {
t.Errorf("ClientSecretExpiresAt = %d, want %d", resp.ClientSecretExpiresAt, 1700086400)
}
}
// rewriteTransport redirects all requests to the test server URL.
type rewriteTransport struct {
base http.RoundTripper
targetURL string
}
func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
target, _ := url.Parse(t.targetURL)
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
if t.base != nil {
return t.base.RoundTrip(req)
}
return http.DefaultTransport.RoundTrip(req)
}
func TestBuildAuthorizationURL(t *testing.T) {
scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations,codewhisperer:transformations,codewhisperer:taskassist"
endpoint := "https://oidc.us-east-1.amazonaws.com"
redirectURI := "http://127.0.0.1:19877/oauth/callback"
authURL := buildAuthorizationURL(endpoint, "test-client-id", redirectURI, scopes, "random-state", "test-challenge")
// Verify colons and commas in scopes are percent-encoded
if !strings.Contains(authURL, "codewhisperer%3Acompletions") {
t.Errorf("expected colons in scopes to be percent-encoded, got: %s", authURL)
}
if !strings.Contains(authURL, "completions%2Ccodewhisperer") {
t.Errorf("expected commas in scopes to be percent-encoded, got: %s", authURL)
}
// Parse back and verify all parameters round-trip correctly
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("failed to parse auth URL: %v", err)
}
if !strings.HasPrefix(authURL, endpoint+"/authorize?") {
t.Errorf("expected URL to start with %s/authorize?, got: %s", endpoint, authURL)
}
q := parsed.Query()
checks := map[string]string{
"response_type": "code",
"client_id": "test-client-id",
"redirect_uri": redirectURI,
"scopes": scopes,
"state": "random-state",
"code_challenge": "test-challenge",
"code_challenge_method": "S256",
}
for key, want := range checks {
if got := q.Get(key); got != want {
t.Errorf("%s = %q, want %q", key, got, want)
}
}
}

View File

@@ -29,7 +29,7 @@ type KiroTokenStorage struct {
ClientID string `json:"client_id,omitempty"`
// ClientSecret is the OAuth client secret (required for token refresh)
ClientSecret string `json:"client_secret,omitempty"`
// Region is the AWS region
// Region is the OIDC region for IDC login and token refresh
Region string `json:"region,omitempty"`
// StartURL is the AWS Identity Center start URL (for IDC auth)
StartURL string `json:"start_url,omitempty"`

View File

@@ -200,36 +200,22 @@ func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) {
}
// 解析各字段
if v, ok := metadata["access_token"].(string); ok {
token.AccessToken = v
}
if v, ok := metadata["refresh_token"].(string); ok {
token.RefreshToken = v
}
if v, ok := metadata["client_id"].(string); ok {
token.ClientID = v
}
if v, ok := metadata["client_secret"].(string); ok {
token.ClientSecret = v
}
if v, ok := metadata["region"].(string); ok {
token.Region = v
}
if v, ok := metadata["start_url"].(string); ok {
token.StartURL = v
}
if v, ok := metadata["provider"].(string); ok {
token.Provider = v
}
token.AccessToken, _ = metadata["access_token"].(string)
token.RefreshToken, _ = metadata["refresh_token"].(string)
token.ClientID, _ = metadata["client_id"].(string)
token.ClientSecret, _ = metadata["client_secret"].(string)
token.Region, _ = metadata["region"].(string)
token.StartURL, _ = metadata["start_url"].(string)
token.Provider, _ = metadata["provider"].(string)
// 解析时间字段
if v, ok := metadata["expires_at"].(string); ok {
if t, err := time.Parse(time.RFC3339, v); err == nil {
if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" {
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
token.ExpiresAt = t
}
}
if v, ok := metadata["last_refresh"].(string); ok {
if t, err := time.Parse(time.RFC3339, v); err == nil {
if lastRefreshStr, ok := metadata["last_refresh"].(string); ok && lastRefreshStr != "" {
if t, err := time.Parse(time.RFC3339, lastRefreshStr); err == nil {
token.LastVerified = t
}
}

View File

@@ -8,7 +8,6 @@ import (
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -51,14 +50,12 @@ type QuotaStatus struct {
// UsageChecker provides methods for checking token quota usage.
type UsageChecker struct {
httpClient *http.Client
endpoint string
}
// NewUsageChecker creates a new UsageChecker instance.
func NewUsageChecker(cfg *config.Config) *UsageChecker {
return &UsageChecker{
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
endpoint: awsKiroEndpoint,
}
}
@@ -66,7 +63,6 @@ func NewUsageChecker(cfg *config.Config) *UsageChecker {
func NewUsageCheckerWithClient(client *http.Client) *UsageChecker {
return &UsageChecker{
httpClient: client,
endpoint: awsKiroEndpoint,
}
}
@@ -80,26 +76,23 @@ func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData)
return nil, fmt.Errorf("access token is empty")
}
payload := map[string]interface{}{
queryParams := map[string]string{
"origin": "AI_EDITOR",
"profileArn": tokenData.ProfileArn,
"resourceType": "AGENTIC_REQUEST",
}
jsonBody, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
// Use endpoint from profileArn if available
endpoint := GetKiroAPIEndpointFromProfileArn(tokenData.ProfileArn)
url := buildURL(endpoint, pathGetUsageLimits, queryParams)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(string(jsonBody)))
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
req.Header.Set("x-amz-target", targetGetUsage)
req.Header.Set("Authorization", "Bearer "+tokenData.AccessToken)
req.Header.Set("Accept", "application/json")
accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken)
setRuntimeHeaders(req, tokenData.AccessToken, accountKey)
resp, err := c.httpClient.Do(req)
if err != nil {

View File

@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
Type string `json:"type"`
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -40,8 +40,7 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
if err != nil {
var authErr *claude.AuthenticationError
if errors.As(err, &authErr) {
if authErr, ok := errors.AsType[*claude.AuthenticationError](err); ok {
log.Error(claude.GetUserFriendlyMessage(authErr))
if authErr.Type == claude.ErrPortInUse.Type {
os.Exit(claude.ErrPortInUse.Code)

View File

@@ -22,6 +22,7 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewKimiAuthenticator(),
sdkAuth.NewKiroAuthenticator(),
sdkAuth.NewGitHubCopilotAuthenticator(),
sdkAuth.NewKiloAuthenticator(),
)
return manager
}

View File

@@ -32,8 +32,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)
if err != nil {
var emailErr *sdkAuth.EmailRequiredError
if errors.As(err, &emailErr) {
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
log.Error(emailErr.Error())
return
}

View File

@@ -0,0 +1,54 @@
package cmd
import (
"context"
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
)
// DoKiloLogin handles the Kilo device flow using the shared authentication manager.
// It initiates the device-based authentication process for Kilo AI services and saves
// the authentication tokens to the configured auth directory.
//
// Parameters:
// - cfg: The application configuration
// - options: Login options including browser behavior and prompts
func DoKiloLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
manager := newAuthManager()
promptFn := options.Prompt
if promptFn == nil {
promptFn = func(prompt string) (string, error) {
fmt.Print(prompt)
var value string
fmt.Scanln(&value)
return strings.TrimSpace(value), nil
}
}
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "kilo", cfg, authOpts)
if err != nil {
fmt.Printf("Kilo authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("Kilo authentication successful!")
}

View File

@@ -206,3 +206,52 @@ func DoKiroImport(cfg *config.Config, options *LoginOptions) {
}
fmt.Println("Kiro token import successful!")
}
func DoKiroIDCLogin(cfg *config.Config, options *LoginOptions, startURL, region, flow string) {
if options == nil {
options = &LoginOptions{}
}
if startURL == "" {
log.Errorf("Kiro IDC login requires --kiro-idc-start-url")
fmt.Println("\nUsage: --kiro-idc-login --kiro-idc-start-url https://d-xxx.awsapps.com/start")
return
}
manager := newAuthManager()
authenticator := sdkAuth.NewKiroAuthenticator()
metadata := map[string]string{
"start-url": startURL,
"region": region,
"flow": flow,
}
record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: metadata,
Prompt: options.Prompt,
})
if err != nil {
log.Errorf("Kiro IDC authentication failed: %v", err)
fmt.Println("\nTroubleshooting:")
fmt.Println("1. Make sure your IDC Start URL is correct")
fmt.Println("2. Complete the authorization in the browser")
fmt.Println("3. If auth code flow fails, try: --kiro-idc-flow device")
return
}
savedPath, err := manager.SaveAuth(record, cfg)
if err != nil {
log.Errorf("Failed to save auth: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("Kiro IDC authentication successful!")
}

View File

@@ -100,49 +100,74 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
log.Info("Authentication successful.")
projects, errProjects := fetchGCPProjects(ctx, httpClient)
if errProjects != nil {
log.Errorf("Failed to get project list: %v", errProjects)
return
var activatedProjects []string
useGoogleOne := false
if trimmedProjectID == "" && promptFn != nil {
fmt.Println("\nSelect login mode:")
fmt.Println(" 1. Code Assist (GCP project, manual selection)")
fmt.Println(" 2. Google One (personal account, auto-discover project)")
choice, errPrompt := promptFn("Enter choice [1/2] (default: 1): ")
if errPrompt == nil && strings.TrimSpace(choice) == "2" {
useGoogleOne = true
}
}
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
if errSelection != nil {
log.Errorf("Invalid project selection: %v", errSelection)
return
}
if len(projectSelections) == 0 {
log.Error("No project selected; aborting login.")
return
}
activatedProjects := make([]string, 0, len(projectSelections))
seenProjects := make(map[string]bool)
for _, candidateID := range projectSelections {
log.Infof("Activating project %s", candidateID)
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
var projectErr *projectSelectionRequiredError
if errors.As(errSetup, &projectErr) {
log.Error("Failed to start user onboarding: A project ID is required.")
showProjectSelectionHelp(storage.Email, projects)
return
}
log.Errorf("Failed to complete user setup: %v", errSetup)
if useGoogleOne {
log.Info("Google One mode: auto-discovering project...")
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, ""); errSetup != nil {
log.Errorf("Google One auto-discovery failed: %v", errSetup)
return
}
finalID := strings.TrimSpace(storage.ProjectID)
if finalID == "" {
finalID = candidateID
autoProject := strings.TrimSpace(storage.ProjectID)
if autoProject == "" {
log.Error("Google One auto-discovery returned empty project ID")
return
}
log.Infof("Auto-discovered project: %s", autoProject)
activatedProjects = []string{autoProject}
} else {
projects, errProjects := fetchGCPProjects(ctx, httpClient)
if errProjects != nil {
log.Errorf("Failed to get project list: %v", errProjects)
return
}
// Skip duplicates
if seenProjects[finalID] {
log.Infof("Project %s already activated, skipping", finalID)
continue
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
if errSelection != nil {
log.Errorf("Invalid project selection: %v", errSelection)
return
}
if len(projectSelections) == 0 {
log.Error("No project selected; aborting login.")
return
}
seenProjects := make(map[string]bool)
for _, candidateID := range projectSelections {
log.Infof("Activating project %s", candidateID)
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
if _, ok := errors.AsType[*projectSelectionRequiredError](errSetup); ok {
log.Error("Failed to start user onboarding: A project ID is required.")
showProjectSelectionHelp(storage.Email, projects)
return
}
log.Errorf("Failed to complete user setup: %v", errSetup)
return
}
finalID := strings.TrimSpace(storage.ProjectID)
if finalID == "" {
finalID = candidateID
}
if seenProjects[finalID] {
log.Infof("Project %s already activated, skipping", finalID)
continue
}
seenProjects[finalID] = true
activatedProjects = append(activatedProjects, finalID)
}
seenProjects[finalID] = true
activatedProjects = append(activatedProjects, finalID)
}
storage.Auto = false
@@ -235,7 +260,48 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
}
}
if projectID == "" {
return &projectSelectionRequiredError{}
// Auto-discovery: try onboardUser without specifying a project
// to let Google auto-provision one (matches Gemini CLI headless behavior
// and Antigravity's FetchProjectID pattern).
autoOnboardReq := map[string]any{
"tierId": tierID,
"metadata": metadata,
}
autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
defer autoCancel()
for attempt := 1; ; attempt++ {
var onboardResp map[string]any
if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
}
if done, okDone := onboardResp["done"].(bool); okDone && done {
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
switch v := resp["cloudaicompanionProject"].(type) {
case string:
projectID = strings.TrimSpace(v)
case map[string]any:
if id, okID := v["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
}
}
break
}
log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
select {
case <-autoCtx.Done():
return &projectSelectionRequiredError{}
case <-time.After(2 * time.Second):
}
}
if projectID == "" {
return &projectSelectionRequiredError{}
}
log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
}
onboardReqBody := map[string]any{
@@ -617,7 +683,7 @@ func updateAuthRecord(record *cliproxyauth.Auth, storage *gemini.GeminiTokenStor
return
}
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, false)
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, true)
if record.Metadata == nil {
record.Metadata = make(map[string]any)

View File

@@ -0,0 +1,60 @@
package cmd
import (
"context"
"errors"
"fmt"
"os"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
const (
codexLoginModeMetadataKey = "codex_login_mode"
codexLoginModeDevice = "device"
)
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
// existing codex-login OAuth callback flow intact.
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{
codexLoginModeMetadataKey: codexLoginModeDevice,
},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
if err != nil {
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
log.Error(codex.GetUserFriendlyMessage(authErr))
if authErr.Type == codex.ErrPortInUse.Type {
os.Exit(codex.ErrPortInUse.Code)
}
return
}
fmt.Printf("Codex device authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("Codex device authentication successful!")
}

View File

@@ -54,8 +54,7 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
if err != nil {
var authErr *codex.AuthenticationError
if errors.As(err, &authErr) {
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
log.Error(codex.GetUserFriendlyMessage(authErr))
if authErr.Type == codex.ErrPortInUse.Type {
os.Exit(codex.ErrPortInUse.Code)

View File

@@ -44,8 +44,7 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
if err != nil {
var emailErr *sdkAuth.EmailRequiredError
if errors.As(err, &emailErr) {
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
log.Error(emailErr.Error())
return
}

View File

@@ -55,6 +55,34 @@ func StartService(cfg *config.Config, configPath string, localPassword string) {
}
}
// StartServiceBackground starts the proxy service in a background goroutine
// and returns a cancel function for shutdown and a done channel.
func StartServiceBackground(cfg *config.Config, configPath string, localPassword string) (cancel func(), done <-chan struct{}) {
builder := cliproxy.NewBuilder().
WithConfig(cfg).
WithConfigPath(configPath).
WithLocalManagementPassword(localPassword)
ctx, cancelFn := context.WithCancel(context.Background())
doneCh := make(chan struct{})
service, err := builder.Build()
if err != nil {
log.Errorf("failed to build proxy service: %v", err)
close(doneCh)
return cancelFn, doneCh
}
go func() {
defer close(doneCh)
if err := service.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
log.Errorf("proxy service exited with error: %v", err)
}
}()
return cancelFn, doneCh
}
// WaitForCloudDeploy waits indefinitely for shutdown signals in cloud deploy mode
// when no configuration file is available.
func WaitForCloudDeploy() {

View File

@@ -87,6 +87,10 @@ type Config struct {
// KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations.
KiroKey []KiroKey `yaml:"kiro" json:"kiro"`
// KiroFingerprint defines a global fingerprint configuration for all Kiro requests.
// When set, all Kiro requests will use this fixed fingerprint instead of random generation.
KiroFingerprint *KiroFingerprintConfig `yaml:"kiro-fingerprint,omitempty" json:"kiro-fingerprint,omitempty"`
// KiroPreferredEndpoint sets the global default preferred endpoint for all Kiro providers.
// Values: "ide" (default, CodeWhisperer) or "cli" (Amazon Q).
KiroPreferredEndpoint string `yaml:"kiro-preferred-endpoint" json:"kiro-preferred-endpoint"`
@@ -97,6 +101,10 @@ type Config struct {
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
// ClaudeHeaderDefaults configures default header values for Claude API requests.
// These are used as fallbacks when the client does not send its own headers.
ClaudeHeaderDefaults ClaudeHeaderDefaults `yaml:"claude-header-defaults" json:"claude-header-defaults"`
// OpenAICompatibility defines OpenAI API compatibility configurations for external providers.
OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"`
@@ -130,6 +138,15 @@ type Config struct {
legacyMigrationPending bool `yaml:"-" json:"-"`
}
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
// when the client does not send them. Update these when Claude Code releases a new version.
type ClaudeHeaderDefaults struct {
UserAgent string `yaml:"user-agent" json:"user-agent"`
PackageVersion string `yaml:"package-version" json:"package-version"`
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
Timeout string `yaml:"timeout" json:"timeout"`
}
// TLSConfig holds HTTPS server settings.
type TLSConfig struct {
// Enable toggles HTTPS server mode.
@@ -301,6 +318,10 @@ type CloakConfig struct {
// SensitiveWords is a list of words to obfuscate with zero-width characters.
// This can help bypass certain content filters.
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
// CacheUserID controls whether Claude user_id values are cached per API key.
// When false, a fresh random user_id is generated for every request.
CacheUserID *bool `yaml:"cache-user-id,omitempty" json:"cache-user-id,omitempty"`
}
// ClaudeKey represents the configuration for a Claude API key,
@@ -368,6 +389,9 @@ type CodexKey struct {
// If empty, the default Codex API URL will be used.
BaseURL string `yaml:"base-url" json:"base-url"`
// Websockets enables the Responses API websocket transport for this credential.
Websockets bool `yaml:"websockets,omitempty" json:"websockets,omitempty"`
// ProxyURL overrides the global proxy setting for this API key if provided.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
@@ -457,6 +481,9 @@ type KiroKey struct {
// Region is the AWS region (default: us-east-1).
Region string `yaml:"region,omitempty" json:"region,omitempty"`
// StartURL is the IAM Identity Center (IDC) start URL for SSO login.
StartURL string `yaml:"start-url,omitempty" json:"start-url,omitempty"`
// ProxyURL optionally overrides the global proxy for this configuration.
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
@@ -469,6 +496,20 @@ type KiroKey struct {
PreferredEndpoint string `yaml:"preferred-endpoint,omitempty" json:"preferred-endpoint,omitempty"`
}
// KiroFingerprintConfig defines a global fingerprint configuration for Kiro requests.
// When configured, all Kiro requests will use this fixed fingerprint instead of random generation.
// Empty fields will fall back to random selection from built-in pools.
type KiroFingerprintConfig struct {
OIDCSDKVersion string `yaml:"oidc-sdk-version,omitempty" json:"oidc-sdk-version,omitempty"`
RuntimeSDKVersion string `yaml:"runtime-sdk-version,omitempty" json:"runtime-sdk-version,omitempty"`
StreamingSDKVersion string `yaml:"streaming-sdk-version,omitempty" json:"streaming-sdk-version,omitempty"`
OSType string `yaml:"os-type,omitempty" json:"os-type,omitempty"`
OSVersion string `yaml:"os-version,omitempty" json:"os-version,omitempty"`
NodeVersion string `yaml:"node-version,omitempty" json:"node-version,omitempty"`
KiroVersion string `yaml:"kiro-version,omitempty" json:"kiro-version,omitempty"`
KiroHash string `yaml:"kiro-hash,omitempty" json:"kiro-hash,omitempty"`
}
// OpenAICompatibility represents the configuration for OpenAI API compatibility
// with external providers, allowing model aliases to be routed through OpenAI API format.
type OpenAICompatibility struct {
@@ -632,9 +673,6 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.ErrorLogsMaxFiles = 10
}
// Sync request authentication providers with inline API keys for backwards compatibility.
syncInlineAccessProvider(&cfg)
// Sanitize Gemini API key configuration and migrate legacy entries.
cfg.SanitizeGeminiKeys()
@@ -739,14 +777,46 @@ func payloadRawString(value any) ([]byte, bool) {
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
// It also injects default aliases for channels that have built-in defaults (e.g., kiro)
// when no user-configured aliases exist for those channels.
func (cfg *Config) SanitizeOAuthModelAlias() {
if cfg == nil || len(cfg.OAuthModelAlias) == 0 {
if cfg == nil {
return
}
// Inject channel defaults when the channel is absent in user config.
// Presence is checked case-insensitively and includes explicit nil/empty markers.
if cfg.OAuthModelAlias == nil {
cfg.OAuthModelAlias = make(map[string][]OAuthModelAlias)
}
hasChannel := func(channel string) bool {
for k := range cfg.OAuthModelAlias {
if strings.EqualFold(strings.TrimSpace(k), channel) {
return true
}
}
return false
}
if !hasChannel("kiro") {
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
}
if !hasChannel("github-copilot") {
cfg.OAuthModelAlias["github-copilot"] = defaultGitHubCopilotAliases()
}
if len(cfg.OAuthModelAlias) == 0 {
return
}
out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias))
for rawChannel, aliases := range cfg.OAuthModelAlias {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(aliases) == 0 {
if channel == "" {
continue
}
// Preserve channels that were explicitly set to empty/nil they act
// as "disabled" markers so default injection won't re-add them (#222).
if len(aliases) == 0 {
out[channel] = nil
continue
}
seenAlias := make(map[string]struct{}, len(aliases))
@@ -888,18 +958,6 @@ func normalizeModelPrefix(prefix string) string {
return trimmed
}
func syncInlineAccessProvider(cfg *Config) {
if cfg == nil {
return
}
if len(cfg.APIKeys) == 0 {
if provider := cfg.ConfigAPIKeyProvider(); provider != nil && len(provider.APIKeys) > 0 {
cfg.APIKeys = append([]string(nil), provider.APIKeys...)
}
}
cfg.Access.Providers = nil
}
// looksLikeBcrypt returns true if the provided string appears to be a bcrypt hash.
func looksLikeBcrypt(s string) bool {
return len(s) > 4 && (s[:4] == "$2a$" || s[:4] == "$2b$" || s[:4] == "$2y$")
@@ -987,7 +1045,7 @@ func hashSecret(secret string) (string, error) {
// SaveConfigPreserveComments writes the config back to YAML while preserving existing comments
// and key ordering by loading the original file into a yaml.Node tree and updating values in-place.
func SaveConfigPreserveComments(configFile string, cfg *Config) error {
persistCfg := sanitizeConfigForPersist(cfg)
persistCfg := cfg
// Load original YAML as a node tree to preserve comments and ordering.
data, err := os.ReadFile(configFile)
if err != nil {
@@ -1055,16 +1113,6 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error {
return err
}
func sanitizeConfigForPersist(cfg *Config) *Config {
if cfg == nil {
return nil
}
clone := *cfg
clone.SDKConfig = cfg.SDKConfig
clone.SDKConfig.Access = AccessConfig{}
return &clone
}
// SaveConfigPreserveCommentsUpdateNestedScalar updates a nested scalar key path like ["a","b"]
// while preserving comments and positions.
func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []string, value string) error {
@@ -1161,8 +1209,13 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
// key order and comments of existing keys in dst. New keys are only added if their
// value is non-zero to avoid polluting the config with defaults.
func mergeMappingPreserve(dst, src *yaml.Node) {
// value is non-zero and not a known default to avoid polluting the config with defaults.
func mergeMappingPreserve(dst, src *yaml.Node, path ...[]string) {
var currentPath []string
if len(path) > 0 {
currentPath = path[0]
}
if dst == nil || src == nil {
return
}
@@ -1176,16 +1229,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
sk := src.Content[i]
sv := src.Content[i+1]
idx := findMapKeyIndex(dst, sk.Value)
childPath := appendPath(currentPath, sk.Value)
if idx >= 0 {
// Merge into existing value node (always update, even to zero values)
dv := dst.Content[idx+1]
mergeNodePreserve(dv, sv)
mergeNodePreserve(dv, sv, childPath)
} else {
// New key: only add if value is non-zero to avoid polluting config with defaults
if isZeroValueNode(sv) {
// New key: only add if value is non-zero and not a known default
candidate := deepCopyNode(sv)
pruneKnownDefaultsInNewNode(childPath, candidate)
if isKnownDefaultValue(childPath, candidate) {
continue
}
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
dst.Content = append(dst.Content, deepCopyNode(sk), candidate)
}
}
}
@@ -1193,7 +1249,12 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
// mergeNodePreserve merges src into dst for scalars, mappings and sequences while
// reusing destination nodes to keep comments and anchors. For sequences, it updates
// in-place by index.
func mergeNodePreserve(dst, src *yaml.Node) {
func mergeNodePreserve(dst, src *yaml.Node, path ...[]string) {
var currentPath []string
if len(path) > 0 {
currentPath = path[0]
}
if dst == nil || src == nil {
return
}
@@ -1202,7 +1263,7 @@ func mergeNodePreserve(dst, src *yaml.Node) {
if dst.Kind != yaml.MappingNode {
copyNodeShallow(dst, src)
}
mergeMappingPreserve(dst, src)
mergeMappingPreserve(dst, src, currentPath)
case yaml.SequenceNode:
// Preserve explicit null style if dst was null and src is empty sequence
if dst.Kind == yaml.ScalarNode && dst.Tag == "!!null" && len(src.Content) == 0 {
@@ -1225,7 +1286,7 @@ func mergeNodePreserve(dst, src *yaml.Node) {
dst.Content[i] = deepCopyNode(src.Content[i])
continue
}
mergeNodePreserve(dst.Content[i], src.Content[i])
mergeNodePreserve(dst.Content[i], src.Content[i], currentPath)
if dst.Content[i] != nil && src.Content[i] != nil &&
dst.Content[i].Kind == yaml.MappingNode && src.Content[i].Kind == yaml.MappingNode {
pruneMissingMapKeys(dst.Content[i], src.Content[i])
@@ -1267,6 +1328,94 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
return -1
}
// appendPath appends a key to the path, returning a new slice to avoid modifying the original.
func appendPath(path []string, key string) []string {
if len(path) == 0 {
return []string{key}
}
newPath := make([]string, len(path)+1)
copy(newPath, path)
newPath[len(path)] = key
return newPath
}
// isKnownDefaultValue returns true if the given node at the specified path
// represents a known default value that should not be written to the config file.
// This prevents non-zero defaults from polluting the config.
func isKnownDefaultValue(path []string, node *yaml.Node) bool {
// First check if it's a zero value
if isZeroValueNode(node) {
return true
}
// Match known non-zero defaults by exact dotted path.
if len(path) == 0 {
return false
}
fullPath := strings.Join(path, ".")
// Check string defaults
if node.Kind == yaml.ScalarNode && node.Tag == "!!str" {
switch fullPath {
case "pprof.addr":
return node.Value == DefaultPprofAddr
case "remote-management.panel-github-repository":
return node.Value == DefaultPanelGitHubRepository
case "routing.strategy":
return node.Value == "round-robin"
}
}
// Check integer defaults
if node.Kind == yaml.ScalarNode && node.Tag == "!!int" {
switch fullPath {
case "error-logs-max-files":
return node.Value == "10"
}
}
return false
}
// pruneKnownDefaultsInNewNode removes default-valued descendants from a new node
// before it is appended into the destination YAML tree.
func pruneKnownDefaultsInNewNode(path []string, node *yaml.Node) {
if node == nil {
return
}
switch node.Kind {
case yaml.MappingNode:
filtered := make([]*yaml.Node, 0, len(node.Content))
for i := 0; i+1 < len(node.Content); i += 2 {
keyNode := node.Content[i]
valueNode := node.Content[i+1]
if keyNode == nil || valueNode == nil {
continue
}
childPath := appendPath(path, keyNode.Value)
if isKnownDefaultValue(childPath, valueNode) {
continue
}
pruneKnownDefaultsInNewNode(childPath, valueNode)
if (valueNode.Kind == yaml.MappingNode || valueNode.Kind == yaml.SequenceNode) &&
len(valueNode.Content) == 0 {
continue
}
filtered = append(filtered, keyNode, valueNode)
}
node.Content = filtered
case yaml.SequenceNode:
for _, child := range node.Content {
pruneKnownDefaultsInNewNode(path, child)
}
}
}
// isZeroValueNode returns true if the YAML node represents a zero/default value
// that should not be written as a new key to preserve config cleanliness.
// For mappings and sequences, recursively checks if all children are zero values.

View File

@@ -20,6 +20,45 @@ var antigravityModelConversionTable = map[string]string{
"gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
}
// defaultKiroAliases returns the default oauth-model-alias configuration
// for the kiro channel. Maps kiro-prefixed model names to standard Claude model
// names so that clients like Claude Code can use standard names directly.
func defaultKiroAliases() []OAuthModelAlias {
return []OAuthModelAlias{
// Sonnet 4.6
{Name: "kiro-claude-sonnet-4-6", Alias: "claude-sonnet-4-6", Fork: true},
// Sonnet 4.5
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
// Sonnet 4
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
// Opus 4.6
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
// Opus 4.5
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
// Haiku 4.5
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
}
}
// defaultGitHubCopilotAliases returns default oauth-model-alias entries that
// expose Claude hyphen-style IDs for GitHub Copilot Claude models.
// This keeps compatibility with clients (e.g. Claude Code) that use
// Anthropic-style model IDs like "claude-opus-4-6".
func defaultGitHubCopilotAliases() []OAuthModelAlias {
return []OAuthModelAlias{
{Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true},
{Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true},
{Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true},
{Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true},
{Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true},
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
}
}
// defaultAntigravityAliases returns the default oauth-model-alias configuration
// for the antigravity channel when neither field exists.
func defaultAntigravityAliases() []OAuthModelAlias {

View File

@@ -54,3 +54,208 @@ func TestSanitizeOAuthModelAlias_AllowsMultipleAliasesForSameName(t *testing.T)
}
}
}
func TestSanitizeOAuthModelAlias_InjectsDefaultKiroAliases(t *testing.T) {
// When no kiro aliases are configured, defaults should be injected
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"codex": {
{Name: "gpt-5", Alias: "g5"},
},
},
}
cfg.SanitizeOAuthModelAlias()
kiroAliases := cfg.OAuthModelAlias["kiro"]
if len(kiroAliases) == 0 {
t.Fatal("expected default kiro aliases to be injected")
}
// Check that standard Claude model names are present
aliasSet := make(map[string]bool)
for _, a := range kiroAliases {
aliasSet[a.Alias] = true
}
expectedAliases := []string{
"claude-sonnet-4-5-20250929",
"claude-sonnet-4-5",
"claude-sonnet-4-20250514",
"claude-sonnet-4",
"claude-opus-4-6",
"claude-opus-4-5-20251101",
"claude-opus-4-5",
"claude-haiku-4-5-20251001",
"claude-haiku-4-5",
}
for _, expected := range expectedAliases {
if !aliasSet[expected] {
t.Fatalf("expected default kiro alias %q to be present", expected)
}
}
// All should have fork=true
for _, a := range kiroAliases {
if !a.Fork {
t.Fatalf("expected all default kiro aliases to have fork=true, got fork=false for %q", a.Alias)
}
}
// Codex aliases should still be preserved
if len(cfg.OAuthModelAlias["codex"]) != 1 {
t.Fatal("expected codex aliases to be preserved")
}
}
func TestSanitizeOAuthModelAlias_InjectsDefaultGitHubCopilotAliases(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"codex": {
{Name: "gpt-5", Alias: "g5"},
},
},
}
cfg.SanitizeOAuthModelAlias()
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
if len(copilotAliases) == 0 {
t.Fatal("expected default github-copilot aliases to be injected")
}
aliasSet := make(map[string]bool, len(copilotAliases))
for _, a := range copilotAliases {
aliasSet[a.Alias] = true
if !a.Fork {
t.Fatalf("expected all default github-copilot aliases to have fork=true, got fork=false for %q", a.Alias)
}
}
expectedAliases := []string{
"claude-haiku-4-5",
"claude-opus-4-1",
"claude-opus-4-5",
"claude-opus-4-6",
"claude-sonnet-4-5",
"claude-sonnet-4-6",
}
for _, expected := range expectedAliases {
if !aliasSet[expected] {
t.Fatalf("expected default github-copilot alias %q to be present", expected)
}
}
}
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
// When user has configured kiro aliases, defaults should NOT be injected
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"kiro": {
{Name: "kiro-claude-sonnet-4", Alias: "my-custom-sonnet", Fork: true},
},
},
}
cfg.SanitizeOAuthModelAlias()
kiroAliases := cfg.OAuthModelAlias["kiro"]
if len(kiroAliases) != 1 {
t.Fatalf("expected 1 user-configured kiro alias, got %d", len(kiroAliases))
}
if kiroAliases[0].Alias != "my-custom-sonnet" {
t.Fatalf("expected user alias to be preserved, got %q", kiroAliases[0].Alias)
}
}
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserGitHubCopilotAliases(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"github-copilot": {
{Name: "claude-opus-4.6", Alias: "my-opus", Fork: true},
},
},
}
cfg.SanitizeOAuthModelAlias()
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
if len(copilotAliases) != 1 {
t.Fatalf("expected 1 user-configured github-copilot alias, got %d", len(copilotAliases))
}
if copilotAliases[0].Alias != "my-opus" {
t.Fatalf("expected user alias to be preserved, got %q", copilotAliases[0].Alias)
}
}
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) {
// When user explicitly deletes kiro aliases (key exists with nil value),
// defaults should NOT be re-injected on subsequent sanitize calls (#222).
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"kiro": nil, // explicitly deleted
"codex": {{Name: "gpt-5", Alias: "g5"}},
},
}
cfg.SanitizeOAuthModelAlias()
kiroAliases := cfg.OAuthModelAlias["kiro"]
if len(kiroAliases) != 0 {
t.Fatalf("expected kiro aliases to remain empty after explicit deletion, got %d aliases", len(kiroAliases))
}
// The key itself must still be present to prevent re-injection on next reload
if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
t.Fatal("expected kiro key to be preserved as nil marker after sanitization")
}
// Other channels should be unaffected
if len(cfg.OAuthModelAlias["codex"]) != 1 {
t.Fatal("expected codex aliases to be preserved")
}
}
func TestSanitizeOAuthModelAlias_GitHubCopilotDoesNotReinjectAfterExplicitDeletion(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"github-copilot": nil, // explicitly deleted
},
}
cfg.SanitizeOAuthModelAlias()
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
if len(copilotAliases) != 0 {
t.Fatalf("expected github-copilot aliases to remain empty after explicit deletion, got %d aliases", len(copilotAliases))
}
if _, exists := cfg.OAuthModelAlias["github-copilot"]; !exists {
t.Fatal("expected github-copilot key to be preserved as nil marker after sanitization")
}
}
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) {
// Same as above but with empty slice instead of nil (PUT with empty body).
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"kiro": {}, // explicitly set to empty
},
}
cfg.SanitizeOAuthModelAlias()
if len(cfg.OAuthModelAlias["kiro"]) != 0 {
t.Fatalf("expected kiro aliases to remain empty, got %d aliases", len(cfg.OAuthModelAlias["kiro"]))
}
if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
t.Fatal("expected kiro key to be preserved")
}
}
func TestSanitizeOAuthModelAlias_InjectsDefaultKiroWhenEmpty(t *testing.T) {
// When OAuthModelAlias is nil, kiro defaults should still be injected
cfg := &Config{}
cfg.SanitizeOAuthModelAlias()
kiroAliases := cfg.OAuthModelAlias["kiro"]
if len(kiroAliases) == 0 {
t.Fatal("expected default kiro aliases to be injected when OAuthModelAlias is nil")
}
}

View File

@@ -20,8 +20,9 @@ type SDKConfig struct {
// APIKeys is a list of keys for authenticating clients to this proxy server.
APIKeys []string `yaml:"api-keys" json:"api-keys"`
// Access holds request authentication provider configuration.
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
// PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients.
// Default is false (disabled).
PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"`
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
@@ -42,65 +43,3 @@ type StreamingConfig struct {
// <= 0 disables bootstrap retries. Default is 0.
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
}
// AccessConfig groups request authentication providers.
type AccessConfig struct {
// Providers lists configured authentication providers.
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
}
// AccessProvider describes a request authentication provider entry.
type AccessProvider struct {
// Name is the instance identifier for the provider.
Name string `yaml:"name" json:"name"`
// Type selects the provider implementation registered via the SDK.
Type string `yaml:"type" json:"type"`
// SDK optionally names a third-party SDK module providing this provider.
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
// APIKeys lists inline keys for providers that require them.
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
// Config passes provider-specific options to the implementation.
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
}
const (
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
AccessProviderTypeConfigAPIKey = "config-api-key"
// DefaultAccessProviderName is applied when no provider name is supplied.
DefaultAccessProviderName = "config-inline"
)
// ConfigAPIKeyProvider returns the first inline API key provider if present.
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider {
if c == nil {
return nil
}
for i := range c.Access.Providers {
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
if c.Access.Providers[i].Name == "" {
c.Access.Providers[i].Name = DefaultAccessProviderName
}
return &c.Access.Providers[i]
}
}
return nil
}
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
// It returns nil when no keys are supplied.
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
if len(keys) == 0 {
return nil
}
provider := &AccessProvider{
Name: DefaultAccessProviderName,
Type: AccessProviderTypeConfigAPIKey,
APIKeys: append([]string(nil), keys...),
}
return provider
}

View File

@@ -27,4 +27,7 @@ const (
// Kiro represents the AWS CodeWhisperer (Kiro) provider identifier.
Kiro = "kiro"
// Kilo represents the Kilo AI provider identifier.
Kilo = "kilo"
)

View File

@@ -21,6 +21,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
)
const (
@@ -28,6 +29,7 @@ const (
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
managementAssetName = "management.html"
httpUserAgent = "CLIProxyAPI-management-updater"
managementSyncMinInterval = 30 * time.Second
updateCheckInterval = 3 * time.Hour
)
@@ -37,11 +39,10 @@ const ManagementFileName = managementAssetName
var (
lastUpdateCheckMu sync.Mutex
lastUpdateCheckTime time.Time
currentConfigPtr atomic.Pointer[config.Config]
disableControlPanel atomic.Bool
schedulerOnce sync.Once
schedulerConfigPath atomic.Value
sfGroup singleflight.Group
)
// SetCurrentConfig stores the latest configuration snapshot for management asset decisions.
@@ -50,16 +51,7 @@ func SetCurrentConfig(cfg *config.Config) {
currentConfigPtr.Store(nil)
return
}
prevDisabled := disableControlPanel.Load()
currentConfigPtr.Store(cfg)
disableControlPanel.Store(cfg.RemoteManagement.DisableControlPanel)
if prevDisabled && !cfg.RemoteManagement.DisableControlPanel {
lastUpdateCheckMu.Lock()
lastUpdateCheckTime = time.Time{}
lastUpdateCheckMu.Unlock()
}
}
// StartAutoUpdater launches a background goroutine that periodically ensures the management asset is up to date.
@@ -92,7 +84,7 @@ func runAutoUpdater(ctx context.Context) {
log.Debug("management asset auto-updater skipped: config not yet available")
return
}
if disableControlPanel.Load() {
if cfg.RemoteManagement.DisableControlPanel {
log.Debug("management asset auto-updater skipped: control panel disabled")
return
}
@@ -181,103 +173,106 @@ func FilePath(configFilePath string) string {
}
// EnsureLatestManagementHTML checks the latest management.html asset and updates the local copy when needed.
// The function is designed to run in a background goroutine and will never panic.
// It enforces a 3-hour rate limit to avoid frequent checks on config/auth file changes.
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) {
// It coalesces concurrent sync attempts and returns whether the asset exists after the sync attempt.
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) bool {
if ctx == nil {
ctx = context.Background()
}
if disableControlPanel.Load() {
log.Debug("management asset sync skipped: control panel disabled by configuration")
return
}
staticDir = strings.TrimSpace(staticDir)
if staticDir == "" {
log.Debug("management asset sync skipped: empty static directory")
return
return false
}
localPath := filepath.Join(staticDir, managementAssetName)
localFileMissing := false
if _, errStat := os.Stat(localPath); errStat != nil {
if errors.Is(errStat, os.ErrNotExist) {
localFileMissing = true
} else {
log.WithError(errStat).Debug("failed to stat local management asset")
}
}
// Rate limiting: check only once every 3 hours
lastUpdateCheckMu.Lock()
now := time.Now()
timeSinceLastCheck := now.Sub(lastUpdateCheckTime)
if timeSinceLastCheck < updateCheckInterval {
_, _, _ = sfGroup.Do(localPath, func() (interface{}, error) {
lastUpdateCheckMu.Lock()
now := time.Now()
timeSinceLastAttempt := now.Sub(lastUpdateCheckTime)
if !lastUpdateCheckTime.IsZero() && timeSinceLastAttempt < managementSyncMinInterval {
lastUpdateCheckMu.Unlock()
log.Debugf(
"management asset sync skipped by throttle: last attempt %v ago (interval %v)",
timeSinceLastAttempt.Round(time.Second),
managementSyncMinInterval,
)
return nil, nil
}
lastUpdateCheckTime = now
lastUpdateCheckMu.Unlock()
log.Debugf("management asset update check skipped: last check was %v ago (interval: %v)", timeSinceLastCheck.Round(time.Second), updateCheckInterval)
return
}
lastUpdateCheckTime = now
lastUpdateCheckMu.Unlock()
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
return
}
releaseURL := resolveReleaseURL(panelRepository)
client := newHTTPClient(proxyURL)
localHash, err := fileSHA256(localPath)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
log.WithError(err).Debug("failed to read local management asset hash")
}
localHash = ""
}
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return
localFileMissing := false
if _, errStat := os.Stat(localPath); errStat != nil {
if errors.Is(errStat, os.ErrNotExist) {
localFileMissing = true
} else {
log.WithError(errStat).Debug("failed to stat local management asset")
}
return
}
log.WithError(err).Warn("failed to fetch latest management release information")
return
}
if remoteHash != "" && localHash != "" && strings.EqualFold(remoteHash, localHash) {
log.Debug("management asset is already up to date")
return
}
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
return nil, nil
}
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to download management asset, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return
releaseURL := resolveReleaseURL(panelRepository)
client := newHTTPClient(proxyURL)
localHash, err := fileSHA256(localPath)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
log.WithError(err).Debug("failed to read local management asset hash")
}
return
localHash = ""
}
log.WithError(err).Warn("failed to download management asset")
return
}
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
}
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return nil, nil
}
return nil, nil
}
log.WithError(err).Warn("failed to fetch latest management release information")
return nil, nil
}
if err = atomicWriteFile(localPath, data); err != nil {
log.WithError(err).Warn("failed to update management asset on disk")
return
}
if remoteHash != "" && localHash != "" && strings.EqualFold(remoteHash, localHash) {
log.Debug("management asset is already up to date")
return nil, nil
}
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to download management asset, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return nil, nil
}
return nil, nil
}
log.WithError(err).Warn("failed to download management asset")
return nil, nil
}
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
}
if err = atomicWriteFile(localPath, data); err != nil {
log.WithError(err).Warn("failed to update management asset on disk")
return nil, nil
}
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
return nil, nil
})
_, err := os.Stat(localPath)
return err == nil
}
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {

View File

@@ -1,6 +1,7 @@
package misc
import (
"encoding/json"
"fmt"
"path/filepath"
"strings"
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
func LogCredentialSeparator() {
log.Debug(credentialSeparator)
}
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
var data map[string]any
// Fast path: if source is already a map, just copy it to avoid mutation of original
if srcMap, ok := source.(map[string]any); ok {
data = make(map[string]any, len(srcMap)+len(metadata))
for k, v := range srcMap {
data[k] = v
}
} else {
// Slow path: marshal to JSON and back to map to respect JSON tags
temp, err := json.Marshal(source)
if err != nil {
return nil, fmt.Errorf("failed to marshal source: %w", err)
}
if err := json.Unmarshal(temp, &data); err != nil {
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
}
}
// Merge extra metadata
if metadata != nil {
if data == nil {
data = make(map[string]any)
}
for k, v := range metadata {
data[k] = v
}
}
return data, nil
}

View File

@@ -0,0 +1,21 @@
// Package registry provides model definitions for various AI service providers.
package registry
// GetKiloModels returns the Kilo model definitions
func GetKiloModels() []*ModelInfo {
return []*ModelInfo{
// --- Base Models ---
{
ID: "kilo/auto",
Object: "model",
Created: 1732752000,
OwnedBy: "kilo",
Type: "kilo",
DisplayName: "Kilo Auto",
Description: "Automatic model selection by Kilo",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
}
}

View File

@@ -19,7 +19,9 @@ import (
// - codex
// - qwen
// - iflow
// - kimi
// - kiro
// - kilo
// - github-copilot
// - kiro
// - amazonq
@@ -43,10 +45,14 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
return GetQwenModels()
case "iflow":
return GetIFlowModels()
case "kimi":
return GetKimiModels()
case "github-copilot":
return GetGitHubCopilotModels()
case "kiro":
return GetKiroModels()
case "kilo":
return GetKiloModels()
case "amazonq":
return GetAmazonQModels()
case "antigravity":
@@ -93,8 +99,10 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
GetOpenAIModels(),
GetQwenModels(),
GetIFlowModels(),
GetKimiModels(),
GetGitHubCopilotModels(),
GetKiroModels(),
GetKiloModels(),
GetAmazonQModels(),
}
for _, models := range allModels {
@@ -121,7 +129,19 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
func GetGitHubCopilotModels() []*ModelInfo {
now := int64(1732752000) // 2024-11-27
return []*ModelInfo{
gpt4oEntries := []struct {
ID string
DisplayName string
Description string
}{
{ID: "gpt-4o-2024-11-20", DisplayName: "GPT-4o (2024-11-20)", Description: "OpenAI GPT-4o 2024-11-20 via GitHub Copilot"},
{ID: "gpt-4o-2024-08-06", DisplayName: "GPT-4o (2024-08-06)", Description: "OpenAI GPT-4o 2024-08-06 via GitHub Copilot"},
{ID: "gpt-4o-2024-05-13", DisplayName: "GPT-4o (2024-05-13)", Description: "OpenAI GPT-4o 2024-05-13 via GitHub Copilot"},
{ID: "gpt-4o", DisplayName: "GPT-4o", Description: "OpenAI GPT-4o via GitHub Copilot"},
{ID: "gpt-4-o-preview", DisplayName: "GPT-4-o Preview", Description: "OpenAI GPT-4-o Preview via GitHub Copilot"},
}
models := []*ModelInfo{
{
ID: "gpt-4.1",
Object: "model",
@@ -133,6 +153,23 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 128000,
MaxCompletionTokens: 16384,
},
}
for _, entry := range gpt4oEntries {
models = append(models, &ModelInfo{
ID: entry.ID,
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: entry.DisplayName,
Description: entry.Description,
ContextLength: 128000,
MaxCompletionTokens: 16384,
})
}
return append(models, []*ModelInfo{
{
ID: "gpt-5",
Object: "model",
@@ -144,6 +181,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "gpt-5-mini",
@@ -156,6 +194,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 128000,
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "gpt-5-codex",
@@ -168,6 +207,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "gpt-5.1",
@@ -180,6 +220,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
},
{
ID: "gpt-5.1-codex",
@@ -192,6 +233,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
},
{
ID: "gpt-5.1-codex-mini",
@@ -204,6 +246,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 128000,
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
},
{
ID: "gpt-5.1-codex-max",
@@ -216,6 +259,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.2",
@@ -228,6 +272,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.2-codex",
@@ -240,6 +285,20 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.3-codex",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.3 Codex",
Description: "OpenAI GPT-5.3 Codex via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "claude-haiku-4.5",
@@ -277,6 +336,18 @@ func GetGitHubCopilotModels() []*ModelInfo {
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-opus-4.6",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Claude Opus 4.6",
Description: "Anthropic Claude Opus 4.6 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-sonnet-4",
Object: "model",
@@ -301,6 +372,18 @@ func GetGitHubCopilotModels() []*ModelInfo {
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-sonnet-4.6",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Claude Sonnet 4.6",
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "gemini-2.5-pro",
Object: "model",
@@ -323,6 +406,17 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 1048576,
MaxCompletionTokens: 65536,
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Gemini 3.1 Pro (Preview)",
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -357,7 +451,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
}
}...)
}
// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
@@ -388,6 +482,18 @@ func GetKiroModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-sonnet-4-6",
Object: "model",
Created: 1739836800, // 2025-02-18
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4.6",
Description: "Claude Sonnet 4.6 via Kiro (1.3x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-opus-4-5",
Object: "model",
@@ -436,6 +542,87 @@ func GetKiroModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
// --- 第三方模型 (通过 Kiro 接入) ---
{
ID: "kiro-deepseek-3-2",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro DeepSeek 3.2",
Description: "DeepSeek 3.2 via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-minimax-m2-1",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro MiniMax M2.1",
Description: "MiniMax M2.1 via Kiro",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-qwen3-coder-next",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Qwen3 Coder Next",
Description: "Qwen3 Coder Next via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-gpt-4o",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro GPT-4o",
Description: "OpenAI GPT-4o via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 16384,
},
{
ID: "kiro-gpt-4",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro GPT-4",
Description: "OpenAI GPT-4 via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 8192,
},
{
ID: "kiro-gpt-4-turbo",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro GPT-4 Turbo",
Description: "OpenAI GPT-4 Turbo via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 16384,
},
{
ID: "kiro-gpt-3-5-turbo",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro GPT-3.5 Turbo",
Description: "OpenAI GPT-3.5 Turbo via Kiro",
ContextLength: 16384,
MaxCompletionTokens: 4096,
},
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
{
ID: "kiro-claude-opus-4-6-agentic",
@@ -449,6 +636,18 @@ func GetKiroModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-sonnet-4-6-agentic",
Object: "model",
Created: 1739836800, // 2025-02-18
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4.6 (Agentic)",
Description: "Claude Sonnet 4.6 optimized for coding agents (chunked writes)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-opus-4-5-agentic",
Object: "model",
@@ -497,6 +696,42 @@ func GetKiroModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-deepseek-3-2-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro DeepSeek 3.2 (Agentic)",
Description: "DeepSeek 3.2 optimized for coding agents (chunked writes)",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-minimax-m2-1-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro MiniMax M2.1 (Agentic)",
Description: "MiniMax M2.1 optimized for coding agents (chunked writes)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-qwen3-coder-next-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Qwen3 Coder Next (Agentic)",
Description: "Qwen3 Coder Next optimized for coding agents (chunked writes)",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
}
}

View File

@@ -28,6 +28,17 @@ func GetClaudeModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-sonnet-4-6",
Object: "model",
Created: 1771372800, // 2026-02-17
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.6 Sonnet",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-opus-4-6",
Object: "model",
@@ -40,6 +51,18 @@ func GetClaudeModels() []*ModelInfo {
MaxCompletionTokens: 128000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-sonnet-4-6",
Object: "model",
Created: 1771286400, // 2026-02-17
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.6 Sonnet",
Description: "Best combination of speed and intelligence",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-opus-4-5-20251101",
Object: "model",
@@ -173,6 +196,21 @@ func GetGeminiModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -283,6 +321,21 @@ func GetGeminiVertexModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-pro-image-preview",
Object: "model",
@@ -425,6 +478,21 @@ func GetGeminiCLIModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -506,6 +574,21 @@ func GetAIStudioModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -742,6 +825,20 @@ func GetOpenAIModels() []*ModelInfo {
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.3-codex-spark",
Object: "model",
Created: 1770912000,
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.3",
DisplayName: "GPT 5.3 Codex Spark",
Description: "Ultra-fast coding model.",
ContextLength: 128000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
},
}
}
@@ -774,6 +871,19 @@ func GetQwenModels() []*ModelInfo {
MaxCompletionTokens: 2048,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
{
ID: "coder-model",
Object: "model",
Created: 1771171200,
OwnedBy: "qwen",
Type: "qwen",
Version: "3.5",
DisplayName: "Qwen 3.5 Plus",
Description: "efficient hybrid model with leading coding performance",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
{
ID: "vision-model",
Object: "model",
@@ -806,18 +916,12 @@ func GetIFlowModels() []*ModelInfo {
Created int64
Thinking *ThinkingSupport
}{
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
@@ -826,10 +930,7 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
}
models := make([]*ModelInfo, 0, len(entries))
for _, entry := range entries {
@@ -863,11 +964,15 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3.1-flash-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}},
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 128000},
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"gpt-oss-120b-medium": {},
"tab_flash_lite_preview": {},
}

View File

@@ -601,8 +601,7 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
defer r.mutex.Unlock()
if registration, exists := r.models[modelID]; exists {
now := time.Now()
registration.QuotaExceededClients[clientID] = &now
registration.QuotaExceededClients[clientID] = new(time.Now())
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
}
}

View File

@@ -164,12 +164,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, &param)
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming request to the AI Studio API.
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -254,7 +254,6 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(first wsrelay.StreamEvent) {
defer close(out)
var param any
@@ -318,7 +317,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
}
}(firstEvent)
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: firstEvent.Headers.Clone(), Chunks: out}, nil
}
// CountTokens counts tokens for the given request using the AI Studio API.

View File

@@ -54,8 +54,78 @@ const (
var (
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
// from any antigravity auth. Empty fetches never overwrite this cache.
antigravityPrimaryModelsCache struct {
mu sync.RWMutex
models []*registry.ModelInfo
}
)
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
if len(models) == 0 {
return nil
}
out := make([]*registry.ModelInfo, 0, len(models))
for _, model := range models {
if model == nil || strings.TrimSpace(model.ID) == "" {
continue
}
out = append(out, cloneAntigravityModelInfo(model))
}
if len(out) == 0 {
return nil
}
return out
}
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
if model == nil {
return nil
}
clone := *model
if len(model.SupportedGenerationMethods) > 0 {
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
}
if len(model.SupportedParameters) > 0 {
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
}
if model.Thinking != nil {
thinkingClone := *model.Thinking
if len(model.Thinking.Levels) > 0 {
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
}
clone.Thinking = &thinkingClone
}
return &clone
}
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
cloned := cloneAntigravityModels(models)
if len(cloned) == 0 {
return false
}
antigravityPrimaryModelsCache.mu.Lock()
antigravityPrimaryModelsCache.models = cloned
antigravityPrimaryModelsCache.mu.Unlock()
return true
}
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
antigravityPrimaryModelsCache.mu.RLock()
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
antigravityPrimaryModelsCache.mu.RUnlock()
return cloned
}
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
models := loadAntigravityPrimaryModels()
if len(models) > 0 {
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
}
return models
}
// AntigravityExecutor proxies requests to the antigravity upstream.
type AntigravityExecutor struct {
cfg *config.Config
@@ -232,7 +302,7 @@ attemptLoop:
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
}
@@ -436,7 +506,7 @@ attemptLoop:
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
@@ -645,7 +715,7 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
}
// ExecuteStream performs a streaming request to the Antigravity API.
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -775,7 +845,6 @@ attemptLoop:
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(resp *http.Response) {
defer close(out)
defer func() {
@@ -820,7 +889,7 @@ attemptLoop:
reporter.ensurePublished(ctx)
}
}(httpResp)
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
switch {
@@ -968,7 +1037,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
}
lastStatus = httpResp.StatusCode
@@ -1008,8 +1077,8 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
exec := &AntigravityExecutor{cfg: cfg}
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
if errToken != nil || token == "" {
return nil
}
return fallbackAntigravityPrimaryModels()
}
if updatedAuth != nil {
auth = updatedAuth
}
@@ -1021,7 +1090,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
modelsURL := baseURL + antigravityModelsPath
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
if errReq != nil {
return nil
return fallbackAntigravityPrimaryModels()
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+token)
@@ -1033,13 +1102,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return nil
return fallbackAntigravityPrimaryModels()
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return nil
return fallbackAntigravityPrimaryModels()
}
bodyBytes, errRead := io.ReadAll(httpResp.Body)
@@ -1051,19 +1120,27 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return nil
return fallbackAntigravityPrimaryModels()
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return nil
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
return nil
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
now := time.Now().Unix()
@@ -1108,9 +1185,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
}
models = append(models, modelInfo)
}
if len(models) == 0 {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
return fallbackAntigravityPrimaryModels()
}
storeAntigravityPrimaryModels(models)
return models
}
return nil
return fallbackAntigravityPrimaryModels()
}
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {

View File

@@ -0,0 +1,159 @@
package executor
import (
"context"
"encoding/json"
"io"
"testing"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestAntigravityBuildRequest_SanitizesGeminiToolSchema(t *testing.T) {
body := buildRequestBodyFromPayload(t, "gemini-2.5-pro")
decl := extractFirstFunctionDeclaration(t, body)
if _, ok := decl["parametersJsonSchema"]; ok {
t.Fatalf("parametersJsonSchema should be renamed to parameters")
}
params, ok := decl["parameters"].(map[string]any)
if !ok {
t.Fatalf("parameters missing or invalid type")
}
assertSchemaSanitizedAndPropertyPreserved(t, params)
}
func TestAntigravityBuildRequest_SanitizesAntigravityToolSchema(t *testing.T) {
body := buildRequestBodyFromPayload(t, "claude-opus-4-6")
decl := extractFirstFunctionDeclaration(t, body)
params, ok := decl["parameters"].(map[string]any)
if !ok {
t.Fatalf("parameters missing or invalid type")
}
assertSchemaSanitizedAndPropertyPreserved(t, params)
}
func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any {
t.Helper()
executor := &AntigravityExecutor{}
auth := &cliproxyauth.Auth{}
payload := []byte(`{
"request": {
"tools": [
{
"function_declarations": [
{
"name": "tool_1",
"parametersJsonSchema": {
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "root-schema",
"type": "object",
"properties": {
"$id": {"type": "string"},
"arg": {
"type": "object",
"prefill": "hello",
"properties": {
"mode": {
"type": "string",
"enum": ["a", "b"],
"enumTitles": ["A", "B"]
}
}
}
},
"patternProperties": {
"^x-": {"type": "string"}
}
}
}
]
}
]
}
}`)
req, err := executor.buildRequest(context.Background(), auth, "token", modelName, payload, false, "", "https://example.com")
if err != nil {
t.Fatalf("buildRequest error: %v", err)
}
raw, err := io.ReadAll(req.Body)
if err != nil {
t.Fatalf("read request body error: %v", err)
}
var body map[string]any
if err := json.Unmarshal(raw, &body); err != nil {
t.Fatalf("unmarshal request body error: %v, body=%s", err, string(raw))
}
return body
}
func extractFirstFunctionDeclaration(t *testing.T, body map[string]any) map[string]any {
t.Helper()
request, ok := body["request"].(map[string]any)
if !ok {
t.Fatalf("request missing or invalid type")
}
tools, ok := request["tools"].([]any)
if !ok || len(tools) == 0 {
t.Fatalf("tools missing or empty")
}
tool, ok := tools[0].(map[string]any)
if !ok {
t.Fatalf("first tool invalid type")
}
decls, ok := tool["function_declarations"].([]any)
if !ok || len(decls) == 0 {
t.Fatalf("function_declarations missing or empty")
}
decl, ok := decls[0].(map[string]any)
if !ok {
t.Fatalf("first function declaration invalid type")
}
return decl
}
func assertSchemaSanitizedAndPropertyPreserved(t *testing.T, params map[string]any) {
t.Helper()
if _, ok := params["$id"]; ok {
t.Fatalf("root $id should be removed from schema")
}
if _, ok := params["patternProperties"]; ok {
t.Fatalf("patternProperties should be removed from schema")
}
props, ok := params["properties"].(map[string]any)
if !ok {
t.Fatalf("properties missing or invalid type")
}
if _, ok := props["$id"]; !ok {
t.Fatalf("property named $id should be preserved")
}
arg, ok := props["arg"].(map[string]any)
if !ok {
t.Fatalf("arg property missing or invalid type")
}
if _, ok := arg["prefill"]; ok {
t.Fatalf("prefill should be removed from nested schema")
}
argProps, ok := arg["properties"].(map[string]any)
if !ok {
t.Fatalf("arg.properties missing or invalid type")
}
mode, ok := argProps["mode"].(map[string]any)
if !ok {
t.Fatalf("mode property missing or invalid type")
}
if _, ok := mode["enumTitles"]; ok {
t.Fatalf("enumTitles should be removed from nested schema")
}
}

View File

@@ -0,0 +1,90 @@
package executor
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
func resetAntigravityPrimaryModelsCacheForTest() {
antigravityPrimaryModelsCache.mu.Lock()
antigravityPrimaryModelsCache.models = nil
antigravityPrimaryModelsCache.mu.Unlock()
}
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
resetAntigravityPrimaryModelsCacheForTest()
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
seed := []*registry.ModelInfo{
{ID: "claude-sonnet-4-5"},
{ID: "gemini-2.5-pro"},
}
if updated := storeAntigravityPrimaryModels(seed); !updated {
t.Fatal("expected non-empty model list to update primary cache")
}
if updated := storeAntigravityPrimaryModels(nil); updated {
t.Fatal("expected nil model list not to overwrite primary cache")
}
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
t.Fatal("expected empty model list not to overwrite primary cache")
}
got := loadAntigravityPrimaryModels()
if len(got) != 2 {
t.Fatalf("expected cached model count 2, got %d", len(got))
}
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
}
}
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
resetAntigravityPrimaryModelsCacheForTest()
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
ID: "gpt-5",
DisplayName: "GPT-5",
SupportedGenerationMethods: []string{"generateContent"},
SupportedParameters: []string{"temperature"},
Thinking: &registry.ThinkingSupport{
Levels: []string{"high"},
},
}}); !updated {
t.Fatal("expected model cache update")
}
got := loadAntigravityPrimaryModels()
if len(got) != 1 {
t.Fatalf("expected one cached model, got %d", len(got))
}
got[0].ID = "mutated-id"
if len(got[0].SupportedGenerationMethods) > 0 {
got[0].SupportedGenerationMethods[0] = "mutated-method"
}
if len(got[0].SupportedParameters) > 0 {
got[0].SupportedParameters[0] = "mutated-parameter"
}
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
got[0].Thinking.Levels[0] = "mutated-level"
}
again := loadAntigravityPrimaryModels()
if len(again) != 1 {
t.Fatalf("expected one cached model after mutation, got %d", len(again))
}
if again[0].ID != "gpt-5" {
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
}
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
}
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
}
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
}
}

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net/http"
"runtime"
"strings"
"time"
@@ -116,7 +117,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
@@ -134,7 +135,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -143,7 +144,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
if err != nil {
return resp, err
}
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -208,7 +209,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} else {
reporter.publish(ctx, parseClaudeUsage(data))
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
}
var param any
@@ -222,11 +223,11 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
data,
&param,
)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -257,7 +258,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
@@ -275,7 +276,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -284,7 +285,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if err != nil {
return nil, err
}
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -329,7 +330,6 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -348,7 +348,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
// Forward the line as-is to preserve SSE format
@@ -375,7 +375,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
chunks := sdktranslator.TranslateStream(
@@ -398,7 +398,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
@@ -423,7 +423,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
body = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -432,7 +432,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
if err != nil {
return cliproxyexecutor.Response{}, err
}
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -487,7 +487,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "input_tokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
}
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
@@ -638,7 +638,49 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
return body, nil
}
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
func mapStainlessOS() string {
switch runtime.GOOS {
case "darwin":
return "MacOS"
case "windows":
return "Windows"
case "linux":
return "Linux"
case "freebsd":
return "FreeBSD"
default:
return "Other::" + runtime.GOOS
}
}
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
func mapStainlessArch() string {
switch runtime.GOARCH {
case "amd64":
return "x64"
case "arm64":
return "arm64"
case "386":
return "x86"
default:
return "other::" + runtime.GOARCH
}
}
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
hdrDefault := func(cfgVal, fallback string) string {
if cfgVal != "" {
return cfgVal
}
return fallback
}
var hd config.ClaudeHeaderDefaults
if cfg != nil {
hd = cfg.ClaudeHeaderDefaults
}
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
if isAnthropicBase && useAPIKey {
@@ -685,16 +727,17 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
// Values below match Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (captured 2026-02-17).
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", "v24.3.0")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", "0.55.1")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", "arm64")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", "MacOS")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", "60")
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "claude-cli/1.0.83 (external, cli)")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.44 (external, sdk-cli)"))
r.Header.Set("Connection", "keep-alive")
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
if stream {
@@ -702,6 +745,8 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
} else {
r.Header.Set("Accept", "application/json")
}
// Keep OS/Arch mapping dynamic (not configurable).
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
@@ -753,11 +798,21 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
return body
}
// Collect built-in tool names (those with a non-empty "type" field) so we can
// skip them consistently in both tools and message history.
builtinTools := map[string]bool{}
for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} {
builtinTools[name] = true
}
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
tools.ForEach(func(index, tool gjson.Result) bool {
// Skip built-in tools (web_search, code_execution, etc.) which have
// a "type" field and require their name to remain unchanged.
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
if n := tool.Get("name").String(); n != "" {
builtinTools[n] = true
}
return true
}
name := tool.Get("name").String()
@@ -772,7 +827,7 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
name := gjson.GetBytes(body, "tool_choice.name").String()
if name != "" && !strings.HasPrefix(name, prefix) {
if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] {
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
}
}
@@ -784,15 +839,38 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
return true
}
content.ForEach(func(contentIndex, part gjson.Result) bool {
if part.Get("type").String() != "tool_use" {
return true
partType := part.Get("type").String()
switch partType {
case "tool_use":
name := part.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) || builtinTools[name] {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+name)
case "tool_reference":
toolName := part.Get("tool_name").String()
if toolName == "" || strings.HasPrefix(toolName, prefix) || builtinTools[toolName] {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+toolName)
case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[]
nestedContent := part.Get("content")
if nestedContent.Exists() && nestedContent.IsArray() {
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
if nestedPart.Get("type").String() == "tool_reference" {
nestedToolName := nestedPart.Get("tool_name").String()
if nestedToolName != "" && !strings.HasPrefix(nestedToolName, prefix) && !builtinTools[nestedToolName] {
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, prefix+nestedToolName)
}
}
return true
})
}
}
name := part.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+name)
return true
})
return true
@@ -811,15 +889,38 @@ func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
return body
}
content.ForEach(func(index, part gjson.Result) bool {
if part.Get("type").String() != "tool_use" {
return true
partType := part.Get("type").String()
switch partType {
case "tool_use":
name := part.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
case "tool_reference":
toolName := part.Get("tool_name").String()
if !strings.HasPrefix(toolName, prefix) {
return true
}
path := fmt.Sprintf("content.%d.tool_name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(toolName, prefix))
case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[]
nestedContent := part.Get("content")
if nestedContent.Exists() && nestedContent.IsArray() {
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
if nestedPart.Get("type").String() == "tool_reference" {
nestedToolName := nestedPart.Get("tool_name").String()
if strings.HasPrefix(nestedToolName, prefix) {
nestedPath := fmt.Sprintf("content.%d.content.%d.tool_name", index.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, strings.TrimPrefix(nestedToolName, prefix))
}
}
return true
})
}
}
name := part.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
return true
})
return body
@@ -834,15 +935,34 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
return line
}
contentBlock := gjson.GetBytes(payload, "content_block")
if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" {
if !contentBlock.Exists() {
return line
}
name := contentBlock.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return line
}
updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
if err != nil {
blockType := contentBlock.Get("type").String()
var updated []byte
var err error
switch blockType {
case "tool_use":
name := contentBlock.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return line
}
updated, err = sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
if err != nil {
return line
}
case "tool_reference":
toolName := contentBlock.Get("tool_name").String()
if !strings.HasPrefix(toolName, prefix) {
return line
}
updated, err = sjson.SetBytes(payload, "content_block.tool_name", strings.TrimPrefix(toolName, prefix))
if err != nil {
return line
}
default:
return line
}
@@ -862,10 +982,10 @@ func getClientUserAgent(ctx context.Context) string {
}
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
// Returns (cloakMode, strictMode, sensitiveWords).
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
// Returns (cloakMode, strictMode, sensitiveWords, cacheUserID).
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bool) {
if auth == nil || auth.Attributes == nil {
return "auto", false, nil
return "auto", false, nil, false
}
cloakMode := auth.Attributes["cloak_mode"]
@@ -883,7 +1003,9 @@ func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
}
}
return cloakMode, strictMode, sensitiveWords
cacheUserID := strings.EqualFold(strings.TrimSpace(auth.Attributes["cloak_cache_user_id"]), "true")
return cloakMode, strictMode, sensitiveWords, cacheUserID
}
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
@@ -916,16 +1038,24 @@ func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *c
}
// injectFakeUserID generates and injects a fake user ID into the request metadata.
func injectFakeUserID(payload []byte) []byte {
// When useCache is false, a new user ID is generated for every call.
func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
generateID := func() string {
if useCache {
return cachedUserID(apiKey)
}
return generateFakeUserID()
}
metadata := gjson.GetBytes(payload, "metadata")
if !metadata.Exists() {
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
return payload
}
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
if existingUserID == "" || !isValidUserID(existingUserID) {
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
}
return payload
}
@@ -962,7 +1092,7 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
// applyCloaking applies cloaking transformations to the payload based on config and client.
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string) []byte {
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string, apiKey string) []byte {
clientUserAgent := getClientUserAgent(ctx)
// Get cloak config from ClaudeKey configuration
@@ -972,16 +1102,20 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
var cloakMode string
var strictMode bool
var sensitiveWords []string
var cacheUserID bool
if cloakCfg != nil {
cloakMode = cloakCfg.Mode
strictMode = cloakCfg.StrictMode
sensitiveWords = cloakCfg.SensitiveWords
if cloakCfg.CacheUserID != nil {
cacheUserID = *cloakCfg.CacheUserID
}
}
// Fallback to auth attributes if no config found
if cloakMode == "" {
attrMode, attrStrict, attrWords := getCloakConfigFromAuth(auth)
attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
cloakMode = attrMode
if !strictMode {
strictMode = attrStrict
@@ -989,6 +1123,12 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
if len(sensitiveWords) == 0 {
sensitiveWords = attrWords
}
if cloakCfg == nil || cloakCfg.CacheUserID == nil {
cacheUserID = attrCache
}
} else if cloakCfg == nil || cloakCfg.CacheUserID == nil {
_, _, _, attrCache := getCloakConfigFromAuth(auth)
cacheUserID = attrCache
}
// Determine if cloaking should be applied
@@ -1002,7 +1142,7 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
}
// Inject fake user ID
payload = injectFakeUserID(payload)
payload = injectFakeUserID(payload, apiKey, cacheUserID)
// Apply sensitive word obfuscation
if len(sensitiveWords) > 0 {

View File

@@ -2,9 +2,18 @@ package executor
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
func TestApplyClaudeToolPrefix(t *testing.T) {
@@ -25,6 +34,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
}
}
func TestApplyClaudeToolPrefix_WithToolReference(t *testing.T) {
input := []byte(`{"tools":[{"name":"alpha"}],"messages":[{"role":"user","content":[{"type":"tool_reference","tool_name":"beta"},{"type":"tool_reference","tool_name":"proxy_gamma"}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
if got := gjson.GetBytes(out, "messages.0.content.0.tool_name").String(); got != "proxy_beta" {
t.Fatalf("messages.0.content.0.tool_name = %q, want %q", got, "proxy_beta")
}
if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != "proxy_gamma" {
t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, "proxy_gamma")
}
}
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
@@ -37,6 +58,97 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
}
}
func TestApplyClaudeToolPrefix_BuiltinToolSkipped(t *testing.T) {
body := []byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search", "max_uses": 5},
{"name": "Read"}
],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}},
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}}
]}
]
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
t.Fatalf("tools.0.name = %q, want %q", got, "web_search")
}
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
}
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Read" {
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Read")
}
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Read" {
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Read")
}
}
func TestApplyClaudeToolPrefix_KnownBuiltinInHistoryOnly(t *testing.T) {
body := []byte(`{
"tools": [
{"name": "Read"}
],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}}
]}
]
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
}
}
func TestApplyClaudeToolPrefix_CustomToolsPrefixed(t *testing.T) {
body := []byte(`{
"tools": [{"name": "Read"}, {"name": "Write"}],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}},
{"type": "tool_use", "name": "Write", "id": "w1", "input": {}}
]}
]
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
}
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Write" {
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Write")
}
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Read" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Read")
}
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Write" {
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Write")
}
}
func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
body := []byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search"},
{"name": "Read"}
],
"tool_choice": {"type": "tool", "name": "web_search"}
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "web_search" {
t.Fatalf("tool_choice.name = %q, want %q", got, "web_search")
}
}
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
@@ -49,6 +161,18 @@ func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
}
}
func TestStripClaudeToolPrefixFromResponse_WithToolReference(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_reference","tool_name":"proxy_alpha"},{"type":"tool_reference","tool_name":"bravo"}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
if got := gjson.GetBytes(out, "content.0.tool_name").String(); got != "alpha" {
t.Fatalf("content.0.tool_name = %q, want %q", got, "alpha")
}
if got := gjson.GetBytes(out, "content.1.tool_name").String(); got != "bravo" {
t.Fatalf("content.1.tool_name = %q, want %q", got, "bravo")
}
}
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
@@ -61,3 +185,166 @@ func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
}
}
func TestStripClaudeToolPrefixFromStreamLine_WithToolReference(t *testing.T) {
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_reference","tool_name":"proxy_beta"},"index":0}`)
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
payload := bytes.TrimSpace(out)
if bytes.HasPrefix(payload, []byte("data:")) {
payload = bytes.TrimSpace(payload[len("data:"):])
}
if got := gjson.GetBytes(payload, "content_block.tool_name").String(); got != "beta" {
t.Fatalf("content_block.tool_name = %q, want %q", got, "beta")
}
}
func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"mcp__nia__manage_resource"}]}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
if got != "proxy_mcp__nia__manage_resource" {
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "proxy_mcp__nia__manage_resource")
}
}
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
resetUserIDCache()
var userIDs []string
var requestModels []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
userID := gjson.GetBytes(body, "metadata.user_id").String()
model := gjson.GetBytes(body, "model").String()
userIDs = append(userIDs, userID)
requestModels = append(requestModels, model)
t.Logf("HTTP Server received request: model=%s, user_id=%s, url=%s", model, userID, r.URL.String())
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
}))
defer server.Close()
t.Logf("End-to-end test: Fake HTTP server started at %s", server.URL)
cacheEnabled := true
executor := NewClaudeExecutor(&config.Config{
ClaudeKey: []config.ClaudeKey{
{
APIKey: "key-123",
BaseURL: server.URL,
Cloak: &config.CloakConfig{
CacheUserID: &cacheEnabled,
},
},
},
})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
models := []string{"claude-3-5-sonnet", "claude-3-5-haiku"}
for _, model := range models {
t.Logf("Sending request for model: %s", model)
modelPayload, _ := sjson.SetBytes(payload, "model", model)
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: model,
Payload: modelPayload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
}); err != nil {
t.Fatalf("Execute(%s) error: %v", model, err)
}
}
if len(userIDs) != 2 {
t.Fatalf("expected 2 requests, got %d", len(userIDs))
}
if userIDs[0] == "" || userIDs[1] == "" {
t.Fatal("expected user_id to be populated")
}
t.Logf("user_id[0] (model=%s): %s", requestModels[0], userIDs[0])
t.Logf("user_id[1] (model=%s): %s", requestModels[1], userIDs[1])
if userIDs[0] != userIDs[1] {
t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1])
}
if !isValidUserID(userIDs[0]) {
t.Fatalf("user_id %q is not valid", userIDs[0])
}
t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0])
}
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
resetUserIDCache()
var userIDs []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
userIDs = append(userIDs, gjson.GetBytes(body, "metadata.user_id").String())
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
for i := 0; i < 2; i++ {
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
}); err != nil {
t.Fatalf("Execute call %d error: %v", i, err)
}
}
if len(userIDs) != 2 {
t.Fatalf("expected 2 requests, got %d", len(userIDs))
}
if userIDs[0] == "" || userIDs[1] == "" {
t.Fatal("expected user_id to be populated")
}
if userIDs[0] == userIDs[1] {
t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0])
}
if !isValidUserID(userIDs[0]) || !isValidUserID(userIDs[1]) {
t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1])
}
}
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
got := gjson.GetBytes(out, "content.0.content.0.tool_name").String()
if got != "mcp__nia__manage_resource" {
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "mcp__nia__manage_resource")
}
}
func TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent(t *testing.T) {
// tool_result.content can be a string - should not be processed
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"plain string result"}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
got := gjson.GetBytes(out, "messages.0.content.0.content").String()
if got != "plain string result" {
t.Fatalf("string content should remain unchanged = %q", got)
}
}
func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"}],"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_reference","tool_name":"web_search"}]}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
if got != "web_search" {
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
}
}

View File

@@ -28,8 +28,8 @@ import (
)
const (
codexClientVersion = "0.98.0"
codexUserAgent = "codex_cli_rs/0.98.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
codexClientVersion = "0.101.0"
codexUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
)
var dataTag = []byte("data:")
@@ -156,7 +156,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
err = newCodexStatusErr(httpResp.StatusCode, b)
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
@@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
@@ -260,7 +260,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
err = newCodexStatusErr(httpResp.StatusCode, b)
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
@@ -273,11 +273,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
}
@@ -358,11 +358,10 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
appendAPIResponseChunk(ctx, e.cfg, data)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
err = newCodexStatusErr(httpResp.StatusCode, data)
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -397,7 +396,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
@@ -643,7 +642,6 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
}
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
misc.EnsureHeader(r.Header, ginHeaders, "Openai-Beta", "responses=experimental")
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
@@ -675,6 +673,35 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
util.ApplyCustomHeadersFromAttrs(r, attrs)
}
func newCodexStatusErr(statusCode int, body []byte) statusErr {
err := statusErr{code: statusCode, msg: string(body)}
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
err.retryAfter = retryAfter
}
return err
}
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
return nil
}
if strings.TrimSpace(gjson.GetBytes(errorBody, "error.type").String()) != "usage_limit_reached" {
return nil
}
if resetsAt := gjson.GetBytes(errorBody, "error.resets_at").Int(); resetsAt > 0 {
resetAtTime := time.Unix(resetsAt, 0)
if resetAtTime.After(now) {
retryAfter := resetAtTime.Sub(now)
return &retryAfter
}
}
if resetsInSeconds := gjson.GetBytes(errorBody, "error.resets_in_seconds").Int(); resetsInSeconds > 0 {
retryAfter := time.Duration(resetsInSeconds) * time.Second
return &retryAfter
}
return nil
}
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
if a == nil {
return "", ""

View File

@@ -0,0 +1,65 @@
package executor
import (
"net/http"
"strconv"
"testing"
"time"
)
func TestParseCodexRetryAfter(t *testing.T) {
now := time.Unix(1_700_000_000, 0)
t.Run("resets_in_seconds", func(t *testing.T) {
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":123}}`)
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
if retryAfter == nil {
t.Fatalf("expected retryAfter, got nil")
}
if *retryAfter != 123*time.Second {
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 123*time.Second)
}
})
t.Run("prefers resets_at", func(t *testing.T) {
resetAt := now.Add(5 * time.Minute).Unix()
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":1}}`)
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
if retryAfter == nil {
t.Fatalf("expected retryAfter, got nil")
}
if *retryAfter != 5*time.Minute {
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 5*time.Minute)
}
})
t.Run("fallback when resets_at is past", func(t *testing.T) {
resetAt := now.Add(-1 * time.Minute).Unix()
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":77}}`)
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
if retryAfter == nil {
t.Fatalf("expected retryAfter, got nil")
}
if *retryAfter != 77*time.Second {
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 77*time.Second)
}
})
t.Run("non-429 status code", func(t *testing.T) {
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":30}}`)
if got := parseCodexRetryAfter(http.StatusBadRequest, body, now); got != nil {
t.Fatalf("expected nil for non-429, got %v", *got)
}
})
t.Run("non usage_limit_reached error type", func(t *testing.T) {
body := []byte(`{"error":{"type":"server_error","resets_in_seconds":30}}`)
if got := parseCodexRetryAfter(http.StatusTooManyRequests, body, now); got != nil {
t.Fatalf("expected nil for non-usage_limit_reached, got %v", *got)
}
})
}
func itoa(v int64) string {
return strconv.FormatInt(v, 10)
}

File diff suppressed because it is too large Load Diff

View File

@@ -225,7 +225,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
reporter.publish(ctx, parseGeminiCLIUsage(data))
var param any
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -256,7 +256,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
}
// ExecuteStream performs a streaming request to the Gemini CLI API.
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -382,7 +382,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(resp *http.Response, reqBody []byte, attemptModel string) {
defer close(out)
defer func() {
@@ -441,7 +440,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}
}(httpResp, append([]byte(nil), payload...), attemptModel)
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
if len(lastBody) > 0 {
@@ -546,7 +545,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
}
lastStatus = resp.StatusCode
lastBody = append([]byte(nil), data...)
@@ -899,8 +898,7 @@ func parseRetryDelay(errorBody []byte) (*time.Duration, error) {
if matches := re.FindStringSubmatch(message); len(matches) > 1 {
seconds, err := strconv.Atoi(matches[1])
if err == nil {
duration := time.Duration(seconds) * time.Second
return &duration, nil
return new(time.Duration(seconds) * time.Second), nil
}
}
}

View File

@@ -205,12 +205,12 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming request to the Gemini API.
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -298,7 +298,6 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -335,7 +334,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens counts tokens for the given request using the Gemini API.
@@ -416,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
}
// Refresh refreshes the authentication credentials (no-op for Gemini API key).

View File

@@ -253,7 +253,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
}
// ExecuteStream performs a streaming request to the Vertex AI API.
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
to := sdktranslator.FromString("gemini")
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -524,12 +524,12 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -618,7 +618,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -650,11 +649,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -743,7 +742,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -775,7 +773,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// countTokensWithServiceAccount counts tokens using service account credentials.
@@ -859,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
}
// countTokensWithAPIKey handles token counting using API key credentials.
@@ -943,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
}
// vertexCreds extracts project, location and raw service account JSON from auth metadata.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,333 @@
package executor
import (
"net/http"
"strings"
"testing"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) {
t.Parallel()
tests := []struct {
name string
model string
wantModel string
}{
{
name: "suffix stripped",
model: "claude-opus-4.6(medium)",
wantModel: "claude-opus-4.6",
},
{
name: "no suffix unchanged",
model: "claude-opus-4.6",
wantModel: "claude-opus-4.6",
},
{
name: "different suffix stripped",
model: "gpt-4o(high)",
wantModel: "gpt-4o",
},
{
name: "numeric suffix stripped",
model: "gemini-2.5-pro(8192)",
wantModel: "gemini-2.5-pro",
},
}
e := &GitHubCopilotExecutor{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
body := []byte(`{"model":"` + tt.model + `","messages":[]}`)
got := e.normalizeModel(tt.model, body)
gotModel := gjson.GetBytes(got, "model").String()
if gotModel != tt.wantModel {
t.Fatalf("normalizeModel() model = %q, want %q", gotModel, tt.wantModel)
}
})
}
}
func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) {
t.Parallel()
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") {
t.Fatal("expected openai-response source to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
t.Parallel()
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") {
t.Fatal("expected codex model to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
t.Parallel()
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
t.Fatal("expected default openai source with non-codex model to use /chat/completions")
}
}
func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`)
got := normalizeGitHubCopilotChatTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 1 {
t.Fatalf("tools len = %d, want 1", len(tools))
}
if tools[0].Get("type").String() != "function" {
t.Fatalf("tool type = %q, want function", tools[0].Get("type").String())
}
}
func TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`)
got := normalizeGitHubCopilotChatTools(body)
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
}
}
func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) {
t.Parallel()
body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`)
got := normalizeGitHubCopilotResponsesInput(body)
in := gjson.GetBytes(got, "input")
if !in.IsArray() {
t.Fatalf("input type = %v, want array", in.Type)
}
raw := in.Raw
if !strings.Contains(raw, "sys text") || !strings.Contains(raw, "user text") || !strings.Contains(raw, "assistant text") {
t.Fatalf("input = %s, want structured array with all texts", raw)
}
if gjson.GetBytes(got, "messages").Exists() {
t.Fatal("messages should be removed after conversion")
}
if gjson.GetBytes(got, "system").Exists() {
t.Fatal("system should be removed after conversion")
}
}
func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) {
t.Parallel()
body := []byte(`{"input":{"foo":"bar"}}`)
got := normalizeGitHubCopilotResponsesInput(body)
in := gjson.GetBytes(got, "input")
if in.Type != gjson.String {
t.Fatalf("input type = %v, want string", in.Type)
}
if !strings.Contains(in.String(), "foo") {
t.Fatalf("input = %q, want stringified object", in.String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
got := normalizeGitHubCopilotResponsesTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 1 {
t.Fatalf("tools len = %d, want 1", len(tools))
}
if tools[0].Get("name").String() != "sum" {
t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String())
}
if !tools[0].Get("parameters").Exists() {
t.Fatal("expected parameters to be preserved")
}
}
func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`)
got := normalizeGitHubCopilotResponsesTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 2 {
t.Fatalf("tools len = %d, want 2", len(tools))
}
if tools[0].Get("type").String() != "function" {
t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String())
}
if tools[0].Get("name").String() != "Bash" {
t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String())
}
if tools[0].Get("description").String() != "Run commands" {
t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String())
}
if !tools[0].Get("parameters").Exists() {
t.Fatal("expected parameters to be set from input_schema")
}
if tools[0].Get("parameters.properties.command").Exists() != true {
t.Fatal("expected parameters.properties.command to exist")
}
if tools[1].Get("name").String() != "Read" {
t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) {
t.Parallel()
body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`)
got := normalizeGitHubCopilotResponsesTools(body)
if gjson.GetBytes(got, "tool_choice.type").String() != "function" {
t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String())
}
if gjson.GetBytes(got, "tool_choice.name").String() != "sum" {
t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
t.Parallel()
body := []byte(`{"tool_choice":{"type":"function"}}`)
got := normalizeGitHubCopilotResponsesTools(body)
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
}
}
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) {
t.Parallel()
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
if gjson.Get(out, "type").String() != "message" {
t.Fatalf("type = %q, want message", gjson.Get(out, "type").String())
}
if gjson.Get(out, "content.0.type").String() != "text" {
t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String())
}
if gjson.Get(out, "content.0.text").String() != "hello" {
t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String())
}
}
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) {
t.Parallel()
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
if gjson.Get(out, "content.0.type").String() != "tool_use" {
t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String())
}
if gjson.Get(out, "content.0.name").String() != "sum" {
t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String())
}
if gjson.Get(out, "stop_reason").String() != "tool_use" {
t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String())
}
}
func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) {
t.Parallel()
var param any
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), &param)
if len(created) == 0 || !strings.Contains(created[0], "message_start") {
t.Fatalf("created events = %#v, want message_start", created)
}
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), &param)
joinedDelta := strings.Join(delta, "")
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
}
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), &param)
joinedCompleted := strings.Join(completed, "")
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
}
}
// --- Tests for X-Initiator detection logic (Problem L) ---
func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
body := []byte(`{"messages":[{"role":"system","content":"sys"},{"role":"user","content":"hello"}]}`)
e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "user" {
t.Fatalf("X-Initiator = %q, want user", got)
}
}
func TestApplyHeaders_XInitiator_AgentWithAssistantAndUserToolResult(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
// Claude Code typical flow: last message is user (tool result), but has assistant in history
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "agent" {
t.Fatalf("X-Initiator = %q, want agent (assistant exists in messages)", got)
}
}
func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "agent" {
t.Fatalf("X-Initiator = %q, want agent (tool role exists)", got)
}
}
// --- Tests for x-github-api-version header (Problem M) ---
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
e.applyHeaders(req, "token", nil)
if got := req.Header.Get("X-Github-Api-Version"); got != "2025-04-01" {
t.Fatalf("X-Github-Api-Version = %q, want 2025-04-01", got)
}
}
// --- Tests for vision detection (Problem P) ---
func TestDetectVisionContent_WithImageURL(t *testing.T) {
t.Parallel()
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
if !detectVisionContent(body) {
t.Fatal("expected vision content to be detected")
}
}
func TestDetectVisionContent_WithImageType(t *testing.T) {
t.Parallel()
body := []byte(`{"messages":[{"role":"user","content":[{"type":"image","source":{"data":"abc","media_type":"image/png"}}]}]}`)
if !detectVisionContent(body) {
t.Fatal("expected image type to be detected")
}
}
func TestDetectVisionContent_NoVision(t *testing.T) {
t.Parallel()
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
if detectVisionContent(body) {
t.Fatal("expected no vision content")
}
}
func TestDetectVisionContent_NoMessages(t *testing.T) {
t.Parallel()
// After Responses API normalization, messages is removed — detection should return false
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`)
if detectVisionContent(body) {
t.Fatal("expected no vision content when messages field is absent")
}
}

View File

@@ -4,12 +4,16 @@ import (
"bufio"
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/google/uuid"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
@@ -165,12 +169,12 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming chat completion request.
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -258,7 +262,6 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -290,7 +293,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
reporter.ensurePublished(ctx)
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
@@ -453,6 +456,20 @@ func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) {
r.Header.Set("Content-Type", "application/json")
r.Header.Set("Authorization", "Bearer "+apiKey)
r.Header.Set("User-Agent", iflowUserAgent)
// Generate session-id
sessionID := "session-" + generateUUID()
r.Header.Set("session-id", sessionID)
// Generate timestamp and signature
timestamp := time.Now().UnixMilli()
r.Header.Set("x-iflow-timestamp", fmt.Sprintf("%d", timestamp))
signature := createIFlowSignature(iflowUserAgent, sessionID, timestamp, apiKey)
if signature != "" {
r.Header.Set("x-iflow-signature", signature)
}
if stream {
r.Header.Set("Accept", "text/event-stream")
} else {
@@ -460,6 +477,23 @@ func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) {
}
}
// createIFlowSignature generates HMAC-SHA256 signature for iFlow API requests.
// The signature payload format is: userAgent:sessionId:timestamp
func createIFlowSignature(userAgent, sessionID string, timestamp int64, apiKey string) string {
if apiKey == "" {
return ""
}
payload := fmt.Sprintf("%s:%s:%d", userAgent, sessionID, timestamp)
h := hmac.New(sha256.New, []byte(apiKey))
h.Write([]byte(payload))
return hex.EncodeToString(h.Sum(nil))
}
// generateUUID generates a random UUID v4 string.
func generateUUID() string {
return uuid.New().String()
}
func iflowCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
if a == nil {
return "", ""

View File

@@ -0,0 +1,460 @@
package executor
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// KiloExecutor handles requests to Kilo API.
type KiloExecutor struct {
cfg *config.Config
}
// NewKiloExecutor creates a new Kilo executor instance.
func NewKiloExecutor(cfg *config.Config) *KiloExecutor {
return &KiloExecutor{cfg: cfg}
}
// Identifier returns the unique identifier for this executor.
func (e *KiloExecutor) Identifier() string { return "kilo" }
// PrepareRequest prepares the HTTP request before execution.
func (e *KiloExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
accessToken, _ := kiloCredentials(auth)
if strings.TrimSpace(accessToken) == "" {
return fmt.Errorf("kilo: missing access token")
}
req.Header.Set("Authorization", "Bearer "+accessToken)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil
}
// HttpRequest executes a raw HTTP request.
func (e *KiloExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("kilo executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
// Execute performs a non-streaming request.
func (e *KiloExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
accessToken, orgID := kiloCredentials(auth)
if accessToken == "" {
return resp, fmt.Errorf("kilo: missing access token")
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
endpoint := "/api/openrouter/chat/completions"
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
url := "https://api.kilo.ai" + endpoint
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
return resp, err
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
if orgID != "" {
httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
}
httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translated,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer httpResp.Body.Close()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
body, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, body)
reporter.publish(ctx, parseOpenAIUsage(body))
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
return resp, nil
}
// ExecuteStream performs a streaming request.
func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
accessToken, orgID := kiloCredentials(auth)
if accessToken == "" {
return nil, fmt.Errorf("kilo: missing access token")
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
endpoint := "/api/openrouter/chat/completions"
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
url := "https://api.kilo.ai" + endpoint
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
if orgID != "" {
httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
}
httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
httpReq.Header.Set("Accept", "text/event-stream")
httpReq.Header.Set("Cache-Control", "no-cache")
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translated,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
httpResp.Body.Close()
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer httpResp.Body.Close()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, 52_428_800)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
reporter.ensurePublished(ctx)
}()
return &cliproxyexecutor.StreamResult{
Headers: httpResp.Header.Clone(),
Chunks: out,
}, nil
}
// Refresh validates the Kilo token.
func (e *KiloExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
if auth == nil {
return nil, fmt.Errorf("missing auth")
}
return auth, nil
}
// CountTokens returns the token count for the given request.
func (e *KiloExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, fmt.Errorf("kilo: count tokens not supported")
}
// kiloCredentials extracts access token and other info from auth.
func kiloCredentials(auth *cliproxyauth.Auth) (accessToken, orgID string) {
if auth == nil {
return "", ""
}
// Prefer kilocode specific keys, then fall back to generic keys.
// Check metadata first, then attributes.
if auth.Metadata != nil {
if token, ok := auth.Metadata["kilocodeToken"].(string); ok && token != "" {
accessToken = token
} else if token, ok := auth.Metadata["access_token"].(string); ok && token != "" {
accessToken = token
}
if org, ok := auth.Metadata["kilocodeOrganizationId"].(string); ok && org != "" {
orgID = org
} else if org, ok := auth.Metadata["organization_id"].(string); ok && org != "" {
orgID = org
}
}
if accessToken == "" && auth.Attributes != nil {
if token := auth.Attributes["kilocodeToken"]; token != "" {
accessToken = token
} else if token := auth.Attributes["access_token"]; token != "" {
accessToken = token
}
}
if orgID == "" && auth.Attributes != nil {
if org := auth.Attributes["kilocodeOrganizationId"]; org != "" {
orgID = org
} else if org := auth.Attributes["organization_id"]; org != "" {
orgID = org
}
}
return accessToken, orgID
}
// FetchKiloModels fetches models from Kilo API.
func FetchKiloModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
accessToken, orgID := kiloCredentials(auth)
if accessToken == "" {
log.Infof("kilo: no access token found, skipping dynamic model fetch (using static kilo/auto)")
return registry.GetKiloModels()
}
log.Debugf("kilo: fetching dynamic models (orgID: %s)", orgID)
httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.kilo.ai/api/openrouter/models", nil)
if err != nil {
log.Warnf("kilo: failed to create model fetch request: %v", err)
return registry.GetKiloModels()
}
req.Header.Set("Authorization", "Bearer "+accessToken)
if orgID != "" {
req.Header.Set("X-Kilocode-OrganizationID", orgID)
}
req.Header.Set("User-Agent", "cli-proxy-kilo")
resp, err := httpClient.Do(req)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Warnf("kilo: fetch models canceled: %v", err)
} else {
log.Warnf("kilo: using static models (API fetch failed: %v)", err)
}
return registry.GetKiloModels()
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Warnf("kilo: failed to read models response: %v", err)
return registry.GetKiloModels()
}
if resp.StatusCode != http.StatusOK {
log.Warnf("kilo: fetch models failed: status %d, body: %s", resp.StatusCode, string(body))
return registry.GetKiloModels()
}
result := gjson.GetBytes(body, "data")
if !result.Exists() {
// Try root if data field is missing
result = gjson.ParseBytes(body)
if !result.IsArray() {
log.Debugf("kilo: response body: %s", string(body))
log.Warn("kilo: invalid API response format (expected array or data field with array)")
return registry.GetKiloModels()
}
}
var dynamicModels []*registry.ModelInfo
now := time.Now().Unix()
count := 0
totalCount := 0
result.ForEach(func(key, value gjson.Result) bool {
totalCount++
id := value.Get("id").String()
pIdxResult := value.Get("preferredIndex")
preferredIndex := pIdxResult.Int()
// Filter models where preferredIndex > 0 (Kilo-curated models)
if preferredIndex <= 0 {
return true
}
// Check if it's free. We look for :free suffix, is_free flag, or zero pricing.
isFree := strings.HasSuffix(id, ":free") || id == "giga-potato" || value.Get("is_free").Bool()
if !isFree {
// Check pricing as fallback
promptPricing := value.Get("pricing.prompt").String()
if promptPricing == "0" || promptPricing == "0.0" {
isFree = true
}
}
if !isFree {
log.Debugf("kilo: skipping curated paid model: %s", id)
return true
}
log.Debugf("kilo: found curated model: %s (preferredIndex: %d)", id, preferredIndex)
dynamicModels = append(dynamicModels, &registry.ModelInfo{
ID: id,
DisplayName: value.Get("name").String(),
ContextLength: int(value.Get("context_length").Int()),
OwnedBy: "kilo",
Type: "kilo",
Object: "model",
Created: now,
})
count++
return true
})
log.Infof("kilo: fetched %d models from API, %d curated free (preferredIndex > 0)", totalCount, count)
if count == 0 && totalCount > 0 {
log.Warn("kilo: no curated free models found (check API response fields)")
}
staticModels := registry.GetKiloModels()
// Always include kilo/auto (first static model)
allModels := append(staticModels[:1], dynamicModels...)
return allModels
}

View File

@@ -161,12 +161,12 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming chat completion request to Kimi.
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
from := opts.SourceFormat
if from.String() == "claude" {
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
@@ -253,7 +253,6 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -285,7 +284,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens estimates token count for Kimi requests.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,423 @@
package executor
import (
"fmt"
"testing"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestBuildKiroEndpointConfigs(t *testing.T) {
tests := []struct {
name string
region string
expectedURL string
expectedOrigin string
expectedName string
}{
{
name: "Empty region - defaults to us-east-1",
region: "",
expectedURL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
{
name: "us-east-1",
region: "us-east-1",
expectedURL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
{
name: "ap-southeast-1",
region: "ap-southeast-1",
expectedURL: "https://q.ap-southeast-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
{
name: "eu-west-1",
region: "eu-west-1",
expectedURL: "https://q.eu-west-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configs := buildKiroEndpointConfigs(tt.region)
if len(configs) != 2 {
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
}
// Check primary endpoint (AmazonQ)
primary := configs[0]
if primary.URL != tt.expectedURL {
t.Errorf("primary URL = %q, want %q", primary.URL, tt.expectedURL)
}
if primary.Origin != tt.expectedOrigin {
t.Errorf("primary Origin = %q, want %q", primary.Origin, tt.expectedOrigin)
}
if primary.Name != tt.expectedName {
t.Errorf("primary Name = %q, want %q", primary.Name, tt.expectedName)
}
if primary.AmzTarget != "" {
t.Errorf("primary AmzTarget should be empty, got %q", primary.AmzTarget)
}
// Check fallback endpoint (CodeWhisperer)
fallback := configs[1]
if fallback.Name != "CodeWhisperer" {
t.Errorf("fallback Name = %q, want %q", fallback.Name, "CodeWhisperer")
}
// CodeWhisperer fallback uses the same region as Q endpoint
expectedRegion := tt.region
if expectedRegion == "" {
expectedRegion = kiroDefaultRegion
}
expectedFallbackURL := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com/generateAssistantResponse", expectedRegion)
if fallback.URL != expectedFallbackURL {
t.Errorf("fallback URL = %q, want %q", fallback.URL, expectedFallbackURL)
}
if fallback.AmzTarget == "" {
t.Error("fallback AmzTarget should NOT be empty")
}
})
}
}
func TestGetKiroEndpointConfigs_NilAuth(t *testing.T) {
configs := getKiroEndpointConfigs(nil)
if len(configs) != 2 {
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
}
// Should return default us-east-1 configs
if configs[0].Name != "AmazonQ" {
t.Errorf("first config Name = %q, want %q", configs[0].Name, "AmazonQ")
}
expectedURL := "https://q.us-east-1.amazonaws.com/generateAssistantResponse"
if configs[0].URL != expectedURL {
t.Errorf("first config URL = %q, want %q", configs[0].URL, expectedURL)
}
}
func TestGetKiroEndpointConfigs_WithRegionFromProfileArn(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{
"profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
},
}
configs := getKiroEndpointConfigs(auth)
if len(configs) != 2 {
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
}
expectedURL := "https://q.ap-southeast-1.amazonaws.com/generateAssistantResponse"
if configs[0].URL != expectedURL {
t.Errorf("primary URL = %q, want %q", configs[0].URL, expectedURL)
}
}
func TestGetKiroEndpointConfigs_WithApiRegionOverride(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{
"api_region": "eu-central-1",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
},
}
configs := getKiroEndpointConfigs(auth)
// api_region should take precedence over profile_arn
expectedURL := "https://q.eu-central-1.amazonaws.com/generateAssistantResponse"
if configs[0].URL != expectedURL {
t.Errorf("primary URL = %q, want %q", configs[0].URL, expectedURL)
}
}
func TestGetKiroEndpointConfigs_PreferredEndpoint(t *testing.T) {
tests := []struct {
name string
preference string
expectedFirstName string
}{
{
name: "Prefer codewhisperer",
preference: "codewhisperer",
expectedFirstName: "CodeWhisperer",
},
{
name: "Prefer ide (alias for codewhisperer)",
preference: "ide",
expectedFirstName: "CodeWhisperer",
},
{
name: "Prefer amazonq",
preference: "amazonq",
expectedFirstName: "AmazonQ",
},
{
name: "Prefer q (alias for amazonq)",
preference: "q",
expectedFirstName: "AmazonQ",
},
{
name: "Prefer cli (alias for amazonq)",
preference: "cli",
expectedFirstName: "AmazonQ",
},
{
name: "Unknown preference - no reordering",
preference: "unknown",
expectedFirstName: "AmazonQ",
},
{
name: "Empty preference - no reordering",
preference: "",
expectedFirstName: "AmazonQ",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{
"preferred_endpoint": tt.preference,
},
}
configs := getKiroEndpointConfigs(auth)
if configs[0].Name != tt.expectedFirstName {
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, tt.expectedFirstName)
}
})
}
}
func TestGetKiroEndpointConfigs_PreferredEndpointFromAttributes(t *testing.T) {
// Test that preferred_endpoint can also come from Attributes
auth := &cliproxyauth.Auth{
Metadata: map[string]any{},
Attributes: map[string]string{"preferred_endpoint": "codewhisperer"},
}
configs := getKiroEndpointConfigs(auth)
if configs[0].Name != "CodeWhisperer" {
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, "CodeWhisperer")
}
}
func TestGetKiroEndpointConfigs_MetadataTakesPrecedenceOverAttributes(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{"preferred_endpoint": "amazonq"},
Attributes: map[string]string{"preferred_endpoint": "codewhisperer"},
}
configs := getKiroEndpointConfigs(auth)
// Metadata should take precedence
if configs[0].Name != "AmazonQ" {
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, "AmazonQ")
}
}
func TestGetAuthValue(t *testing.T) {
tests := []struct {
name string
auth *cliproxyauth.Auth
key string
expected string
}{
{
name: "From metadata",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": "metadata_value"},
},
key: "test_key",
expected: "metadata_value",
},
{
name: "From attributes (fallback)",
auth: &cliproxyauth.Auth{
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
{
name: "Metadata takes precedence",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": "metadata_value"},
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "metadata_value",
},
{
name: "Key not found",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"other_key": "value"},
Attributes: map[string]string{"another_key": "value"},
},
key: "test_key",
expected: "",
},
{
name: "Nil metadata",
auth: &cliproxyauth.Auth{
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
{
name: "Both nil",
auth: &cliproxyauth.Auth{},
key: "test_key",
expected: "",
},
{
name: "Value is trimmed and lowercased",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": " UPPER_VALUE "},
},
key: "test_key",
expected: "upper_value",
},
{
name: "Empty string value in metadata - falls back to attributes",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": ""},
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
{
name: "Non-string value in metadata - falls back to attributes",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": 123},
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getAuthValue(tt.auth, tt.key)
if result != tt.expected {
t.Errorf("getAuthValue() = %q, want %q", result, tt.expected)
}
})
}
}
func TestGetAccountKey(t *testing.T) {
tests := []struct {
name string
auth *cliproxyauth.Auth
checkFn func(t *testing.T, result string)
}{
{
name: "From client_id",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{
"client_id": "test-client-id-123",
"refresh_token": "test-refresh-token-456",
},
},
checkFn: func(t *testing.T, result string) {
expected := kiroauth.GetAccountKey("test-client-id-123", "test-refresh-token-456")
if result != expected {
t.Errorf("expected %s, got %s", expected, result)
}
},
},
{
name: "From refresh_token only",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{
"refresh_token": "test-refresh-token-789",
},
},
checkFn: func(t *testing.T, result string) {
expected := kiroauth.GetAccountKey("", "test-refresh-token-789")
if result != expected {
t.Errorf("expected %s, got %s", expected, result)
}
},
},
{
name: "Nil auth",
auth: nil,
checkFn: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
},
},
{
name: "Nil metadata",
auth: &cliproxyauth.Auth{},
checkFn: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
},
},
{
name: "Empty metadata",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{},
},
checkFn: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getAccountKey(tt.auth)
tt.checkFn(t, result)
})
}
}
func TestEndpointAliases(t *testing.T) {
// Verify all expected aliases are defined
expectedAliases := map[string]string{
"codewhisperer": "codewhisperer",
"ide": "codewhisperer",
"amazonq": "amazonq",
"q": "amazonq",
"cli": "amazonq",
}
for alias, target := range expectedAliases {
if actual, ok := endpointAliases[alias]; !ok {
t.Errorf("missing alias %q", alias)
} else if actual != target {
t.Errorf("alias %q = %q, want %q", alias, actual, target)
}
}
// Verify no unexpected aliases
if len(endpointAliases) != len(expectedAliases) {
t.Errorf("unexpected number of aliases: got %d, want %d", len(endpointAliases), len(expectedAliases))
}
}

Some files were not shown because too many files have changed in this diff Show More