Compare commits

..

212 Commits

Author SHA1 Message Date
Luis Pater
d3100085b0 Merge pull request #392 from router-for-me/plus
v6.8.30
2026-02-26 23:16:26 +08:00
Luis Pater
f481d25133 Merge branch 'main' into plus 2026-02-26 23:16:17 +08:00
Luis Pater
8c6c90da74 fix(registry): clean up outdated model definitions in static data 2026-02-26 23:12:40 +08:00
Luis Pater
24bcfd9c03 Merge pull request #1699 from 123hi123/fix/antigravity-primary-model-fallback
fix(antigravity): keep primary model list and backfill empty auths
2026-02-26 04:28:29 +08:00
Luis Pater
816fb4c5da Merge pull request #1682 from sususu98/fix/tool-result-image-parts
fix(antigravity): place tool_result images in functionResponse.parts and unify mimeType
2026-02-25 23:14:35 +08:00
Luis Pater
c1bb77c7c9 Merge pull request #291 from howarddong711/feat/copilot-email-name
feat(copilot): fetch and persist user email and display name on login
2026-02-25 22:23:25 +08:00
Luis Pater
6bcac3a55a Merge branch 'router-for-me:main' into main 2026-02-25 22:21:31 +08:00
Howard Dong
fc346f4537 fix(copilot): add username fallback and consistent file name prefix
- Add 'github-user' fallback in WaitForAuthorization when FetchUserInfo
  returns empty Login (fixes malformed 'github-copilot-.json' filenames)
- Standardize Web API file name to 'github-copilot-<user>.json' to match
  CLI path convention (was 'github-<user>.json')

Addresses Gemini Code Assist review comments on PR #291.
2026-02-25 17:17:51 +08:00
Howard Dong
43e531a3b6 feat(copilot): fetch and persist user email and display name on login
- Expand OAuth scope to include read:user for full profile access
- Add GitHubUserInfo struct with Login, Email, Name fields
- Update FetchUserInfo to return complete user profile
- Add Email and Name fields to CopilotTokenStorage and CopilotAuthBundle
- Fix provider string bug: 'github' -> 'github-copilot' in auth_files.go
- Fix semantic bug: email field was storing username
- Update Label to prefer email over username in both CLI and Web API paths
- Add 9 unit tests covering new functionality
2026-02-25 17:09:40 +08:00
Luis Pater
d24ea4ce2a Merge pull request #1664 from ciberponk/pr/responses-compaction-compat
feat: add codex responses compatibility for compaction payloads
2026-02-25 01:21:59 +08:00
Luis Pater
2c30c981ae Merge pull request #1687 from lyd123qw2008/fix/codex-refresh-token-reused-no-retry
fix(codex): stop retrying refresh_token_reused errors
2026-02-25 01:19:30 +08:00
Luis Pater
aa1da8a858 Merge pull request #1685 from lyd123qw2008/fix/auth-auto-refresh-interval
fix(auth): respect configured auto-refresh interval
2026-02-25 01:13:47 +08:00
Luis Pater
f1e9a787d7 Merge pull request #1676 from piexian/feat/qwen-quota-handling-clean
feat(qwen): add rate limiting and quota error handling
2026-02-25 01:07:55 +08:00
Luis Pater
4eeec297de Merge pull request #288 from router-for-me/plus
v6.8.27
2026-02-25 01:04:57 +08:00
Luis Pater
77cc4ce3a0 Merge branch 'main' into plus 2026-02-25 01:04:15 +08:00
Luis Pater
37dfea1d3f Merge pull request #287 from possible055/main
fix(kiro): support OR-group field matching in truncation detector
2026-02-25 01:02:49 +08:00
Luis Pater
e6626c672a Merge pull request #269 from ClubWeGo/fix/filterOrphanedToolResults
fix: filter out orphaned tool results from history and current context
2026-02-25 01:02:11 +08:00
Luis Pater
c66cb0afd2 Merge pull request #1683 from dusty-du/codex/device-login-flow
Add additive Codex device-code login flow
2026-02-25 00:50:48 +08:00
Luis Pater
fb48eee973 Merge pull request #1680 from canxin121/fix/responses-stream-error-chunks
fix(responses): emit schema-valid SSE chunks
2026-02-25 00:49:06 +08:00
Luis Pater
bb44e5ec44 Merge pull request #1701 from router-for-me/openai
Revert "Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping"
2026-02-25 00:46:13 +08:00
apparition
c785c1a3ca fix(kiro): support OR-group field matching in truncation detector
- Change RequiredFieldsByTool value type from []string to [][]string
- Outer slice = AND (all groups required); inner slice = OR (any one satisfies)
- Fix Bash entry to accept "cmd" or "command", resolving soft-truncation loop
- Update findMissingRequiredFields logic and inline docs accordingly
2026-02-24 22:48:05 +08:00
comalot
514ae341c8 fix(antigravity): deep copy cached model metadata 2026-02-24 20:14:01 +08:00
hkfires
0659ffab75 Revert "Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping" 2026-02-24 19:47:53 +08:00
comalot
8ce07f38dd fix(antigravity): keep primary model list and backfill empty auths 2026-02-24 16:16:44 +08:00
Luis Pater
7cb398d167 Merge pull request #1663 from rensumo/main
feat: implement credential-based round-robin for gemini-cli
2026-02-24 06:02:50 +08:00
Luis Pater
c3e12c5e58 Merge pull request #1654 from alexey-yanchenko/feature/pass-file-inputs
Pass file input from /chat/completions and /responses to codex and claude
2026-02-24 05:53:11 +08:00
Luis Pater
1825fc7503 Merge pull request #1643 from alexey-yanchenko/fix/gemini-prompt-tokens
Fix usage convertation from gemini response to openai format
2026-02-24 05:46:13 +08:00
Luis Pater
48732ba05e Merge pull request #1527 from HEUDavid/feat/auth-hook
feat(auth): add post-auth hook mechanism
2026-02-24 05:33:13 +08:00
canxin121
acf483c9e6 fix(responses): reject invalid SSE data JSON
Guard the openai-response streaming path against truncated/invalid SSE data payloads by validating data: JSON before forwarding; surface a 502 terminal error instead of letting clients crash with JSON parse errors.
2026-02-24 01:42:54 +08:00
lyd123qw2008
3b3e0d1141 test(codex): log non-retryable refresh error and cover single-attempt behavior 2026-02-23 22:41:33 +08:00
lyd123qw2008
7acd428507 fix(codex): stop retrying refresh_token_reused errors 2026-02-23 22:31:30 +08:00
lyd123qw2008
450d1227bd fix(auth): respect configured auto-refresh interval 2026-02-23 22:07:50 +08:00
test
492b9c46f0 Add additive Codex device-code login flow 2026-02-23 06:30:04 -05:00
Darley
6e634fe3f9 fix: filter out orphaned tool results from history and current context 2026-02-23 14:33:59 +08:00
sususu98
4e26182d14 fix(antigravity): place tool_result images in functionResponse.parts and unify mimeType
Move base64 image data from Claude tool_result into functionResponse.parts
as inlineData instead of outer sibling parts, preventing context bloat.
Unify all inlineData field naming to camelCase mimeType across Claude,
OpenAI, and Gemini translators. Add comprehensive edge case tests and
Gemini-side regression test for functionResponse.parts preservation.
2026-02-23 13:38:21 +08:00
canxin121
eb7571936c revert: translator changes (path guard)
CI blocks PRs that modify internal/translator. Revert translator edits and keep only the /v1/responses streaming error-chunk fix; file an issue for translator conformance work.
2026-02-23 13:30:43 +08:00
canxin121
5382764d8a fix(responses): include model and usage in translated streams
Ensure response.created and response.completed chunks produced by the OpenAI/Gemini/Claude translators always include required fields (response.model and response.usage) so clients validating Responses SSE do not fail schema validation.
2026-02-23 13:22:06 +08:00
canxin121
49c8ec69d0 fix(openai): emit valid responses stream error chunks
When /v1/responses streaming fails after headers are sent, we now emit a type=error chunk instead of an HTTP-style {error:{...}} payload, preventing AI SDK chunk validation errors.
2026-02-23 12:59:50 +08:00
piexian
3b421c8181 feat(qwen): add rate limiting and quota error handling
- Add 60 requests/minute rate limiting per credential using sliding window
- Detect insufficient_quota errors and set cooldown until next day (Beijing time)
- Map quota errors (HTTP 403/429) to 429 with retryAfter for conductor integration
- Cache Beijing timezone at package level to avoid repeated syscalls
- Add redactAuthID function to protect credentials in logs
- Extract wrapQwenError helper to consolidate error handling
2026-02-23 00:38:46 +08:00
Luis Pater
21d2329947 Merge pull request #261 from router-for-me/plus
v6.8.26
2026-02-23 00:15:36 +08:00
Luis Pater
0993413bab Merge branch 'main' into plus 2026-02-23 00:15:22 +08:00
Luis Pater
713388dd7b Fixed: #1675
fix(gemini): add model definitions for Gemini 3.1 Pro High and Image
2026-02-23 00:12:57 +08:00
Luis Pater
e6c7af0fa9 Merge pull request #1522 from soilSpoon/feature/canceled
feature(proxy): Adds special handling for client cancellations in proxy error handler
2026-02-22 22:02:59 +08:00
Luis Pater
837aa6e3aa Merge branch 'router-for-me:main' into main 2026-02-22 21:52:53 +08:00
Luis Pater
d210be06c2 fix(gemini): update min Thinking value and add Gemini 3.1 Pro Preview model definition 2026-02-22 21:51:32 +08:00
fan
afc8a0f9be refactor: simplify context_management compatibility handling 2026-02-21 22:20:48 +08:00
Luis Pater
af8e9ef458 Merge branch 'router-for-me:main' into main 2026-02-21 21:09:52 +08:00
Luis Pater
cec6f993ad Merge pull request #256 from kavore/fix/oauth-copilot-claude-aliases
fix: add default copilot claude model aliases for oauth routing
2026-02-21 21:09:43 +08:00
Luis Pater
950de29f48 Merge pull request #255 from ladeng07/main
feat(registry): add GPT-4o model variants for GitHub Copilot
2026-02-21 21:09:06 +08:00
Luis Pater
d6ec33e8e1 Merge pull request #1662 from matchch/contribute/cache-user-id
feat: add cache-user-id toggle for Claude cloaking
2026-02-21 20:51:30 +08:00
Luis Pater
081cfe806e fix(gemini): correct Created timestamps for Gemini 3.1 Pro Preview model definitions 2026-02-21 20:47:47 +08:00
hkfires
c1c62a6c04 feat(gemini): add Gemini 3.1 Pro Preview model definitions 2026-02-21 20:42:29 +08:00
ciberponk
d693d7993b feat: support responses compaction payload compatibility for codex translator 2026-02-21 12:56:10 +08:00
rensumo
5936f9895c feat: implement credential-based round-robin for gemini-cli virtual auths
Changes the RoundRobinSelector to use two-level round-robin when
gemini-cli virtual auths are detected (via gemini_virtual_parent attr):
- Level 1: cycle across credential groups (parent accounts)
- Level 2: cycle within each group's project auths

Credentials start from a random offset (rand.IntN) for fair distribution.
Non-virtual auths and single-credential scenarios fall back to flat RR.

Adds 3 test cases covering multi-credential grouping, single-parent
fallback, and mixed virtual/non-virtual fallback.
2026-02-21 12:49:48 +08:00
matchch
2fdf5d2793 feat: add cache-user-id toggle for Claude cloaking
Default to generating a fresh random user_id per request instead of
reusing cached IDs. Add cache-user-id config option to opt in to the
previous caching behavior.

- Add CacheUserID field to CloakConfig
- Extract user_id cache logic to dedicated file
- Generate fresh user_id by default, cache only when enabled
- Add tests for both paths
2026-02-21 12:31:20 +08:00
kavore
b3da00d2ed fix: add default copilot claude model aliases for oauth routing 2026-02-20 21:59:21 +03:00
LMark
740277a9f2 refactor(registry): deduplicate GitHub Copilot GPT-4o model definitions 2026-02-21 02:32:06 +08:00
LMark
f91807b6b9 Add GPT-4o model variants while keeping Gemini 3.1 Pro preview 2026-02-21 01:41:01 +08:00
Luis Pater
57d18bb226 Merge branch 'router-for-me:main' into main 2026-02-20 22:42:01 +08:00
Luis Pater
10b9c6cb8a Merge pull request #252 from DragonBaiMo/fix/kiro-thinking-stream-dedup
fix(kiro): stop duplicated thinking on OpenAI and preserve Claude multi-turn thinking
2026-02-20 22:41:48 +08:00
Luis Pater
b24786f8a7 Merge pull request #250 from TonyRL/feat/copilot-gemini-3.1
feat(registry): add Gemini 3.1 Pro to GitHub Copilot provider
2026-02-20 22:40:41 +08:00
Luis Pater
7b0eb41ebc Merge pull request #1660 from Grivn/fix/claude-token-url
fix(claude): use api.anthropic.com for OAuth token exchange
2026-02-20 21:52:08 +08:00
DragonBaiMo
70949929db fix(kiro): deduplicate thinking stream emission 2026-02-20 20:34:40 +08:00
DragonBaiMo
7c9c89dace fix(kiro): keep thinking enabled across request formats 2026-02-20 20:34:40 +08:00
Grivn
ef5901c81b fix(claude): use api.anthropic.com for OAuth token exchange
console.anthropic.com is now protected by a Cloudflare managed challenge
that blocks all non-browser POST requests to /v1/oauth/token, causing
`-claude-login` to fail with a 403 error.

Switch to api.anthropic.com which hosts the same OAuth token endpoint
without the Cloudflare managed challenge.

Fixes #1659

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 20:11:27 +08:00
Luis Pater
d4829c82f7 Merge pull request #1652 from thebtf/fix/claude-translator-arguments
fix(translator): handle tool call arguments in codex→claude streaming translator
2026-02-20 19:50:20 +08:00
Luis Pater
a5f4166a9b Merge pull request #1644 from possible055/main
feat: add Gemini 3.1 Pro Preview model definition
2026-02-20 19:44:59 +08:00
Alexey Yanchenko
0cbfe7f457 Pass file input from /chat/completions and /responses to codex and claude 2026-02-20 10:25:44 +07:00
Tony
f2b1ec4f9e feat(registry): add Gemini 3.1 Pro to GitHub Copilot provider 2026-02-20 04:23:42 +08:00
Kirill Turanskiy
1cc21cc45b fix: prevent duplicate function call arguments when delta events precede done
Non-spark codex models (gpt-5.3-codex, gpt-5.2-codex) stream function call
arguments via multiple delta events followed by a done event. The done handler
unconditionally emitted the full arguments, duplicating what deltas already
streamed. This produced invalid double JSON that Claude Code couldn't parse,
causing tool calls to fail with missing parameters and infinite retry loops.

Add HasReceivedArgumentsDelta flag to track whether delta events were received.
The done handler now only emits arguments when no deltas preceded it (spark
models), while delta-based streaming continues to work for non-spark models.
2026-02-19 23:18:14 +03:00
Kirill Turanskiy
07cf616e2b fix: handle response.function_call_arguments.done in codex→claude streaming translator
Some Codex models (e.g. gpt-5.3-codex-spark) send function call arguments
in a single "done" event without preceding "delta" events. The streaming
translator only handled "delta" events, causing tool call arguments to be
lost — resulting in empty tool inputs and infinite retry loops in clients
like Claude Code.

Emit the full arguments from the "done" event as a single input_json_delta
so downstream clients receive the complete tool input.
2026-02-19 23:18:14 +03:00
Luis Pater
2b8c466e88 refactor(executor, handlers): replace channel-based streams with StreamResult for consistency
- Updated `ExecuteStream` functions in executors to use `StreamResult` instead of channels.
- Enhanced upstream header handling in OpenAI handlers.
- Improved maintainability and alignment across executors and handlers.
2026-02-19 22:07:14 +08:00
Luis Pater
ca2174ea48 Merge pull request #249 from router-for-me/plus
v6.8.22
2026-02-19 21:58:42 +08:00
Luis Pater
c09fb2a79d Merge branch 'main' into plus 2026-02-19 21:58:04 +08:00
Luis Pater
4445a165e9 test(handlers): add tests for passthrough headers behavior in WriteErrorResponse 2026-02-19 21:49:44 +08:00
Luis Pater
e92e2af71a Merge branch 'codex/pr-1626' into dev 2026-02-19 21:33:23 +08:00
Luis Pater
a6bdd9a652 feat: add passthrough headers configuration
- Introduced `passthrough-headers` option in configuration to control forwarding of upstream response headers.
- Updated handlers to respect the passthrough headers setting.
- Added tests to verify behavior when passthrough is enabled or disabled.
2026-02-19 21:31:29 +08:00
Luis Pater
349a6349b3 Merge pull request #1645 from tinyc0der/fix/antigravity-tool-result-json
fix(antigravity): prevent invalid JSON when tool_result has no content
2026-02-19 21:01:25 +08:00
TinyCoder
00822770ec fix(antigravity): prevent invalid JSON when tool_result has no content
sjson.SetRaw with an empty string produces malformed JSON (e.g. "result":}).
This happens when a Claude tool_result block has no content field, causing
functionResponseResult.Raw to be "". Guard against this by falling back to
sjson.Set with an empty string only when .Raw is empty.
2026-02-19 17:08:39 +07:00
apparition
1a0ceda0fc feat: add Gemini 3.1 Pro Preview model definition 2026-02-19 17:43:08 +08:00
Alexey Yanchenko
b9ae4ab803 Fix usage convertation from gemini response to openai format 2026-02-19 15:34:59 +07:00
Luis Pater
72add453d2 docs: add OmniRoute to README 2026-02-19 13:23:25 +08:00
Luis Pater
2789396435 fix: ensure connection-scoped headers are filtered in upstream requests
- Added `connectionScopedHeaders` utility to respect "Connection" header directives.
- Updated `FilterUpstreamHeaders` to remove connection-scoped headers dynamically.
- Refactored and tested upstream header filtering with additional validations.
- Adjusted upstream header handling during retries to replace headers safely.
2026-02-19 13:19:10 +08:00
Luis Pater
61da7bd981 Merge PR #1626 into codex/pr-1626 2026-02-19 04:49:14 +08:00
Luis Pater
ae4c502792 Merge pull request #248 from router-for-me/plus
v6.8.21
2026-02-19 04:42:44 +08:00
Luis Pater
ec6068060b Merge branch 'main' into plus 2026-02-19 04:42:35 +08:00
Luis Pater
ecb01d3dcd Merge pull request #244 from PancakeZik/feat/sonnet-4-6
feat: add Claude Sonnet 4.6 model support for Kiro provider
2026-02-19 04:31:20 +08:00
Luis Pater
22c0c00bd4 Merge branch 'main' into feat/sonnet-4-6 2026-02-19 04:31:07 +08:00
Luis Pater
9eb3e7a6c4 Merge pull request #243 from gl11tchy/feat/claude-sonnet-4-6
feat(registry): add Claude Sonnet 4.6 model definitions
2026-02-19 04:29:39 +08:00
Luis Pater
357c191510 Merge pull request #242 from ultraplan-bit/main
Improve Copilot provider based on ericc-ch/copilot-api comparison
2026-02-19 04:27:02 +08:00
Luis Pater
5db244af76 Merge pull request #240 from TonyRL/feat/copilot-sonnet-4.6
feat(registry): add Sonnet 4.6 to GitHub Copilot provider
2026-02-19 04:26:28 +08:00
Luis Pater
dc375d1b74 Merge pull request #239 from TonyRL/feat/copilot-codex-5.3
feat(registry): add GPT-5.3 Codex to GitHub Copilot provider
2026-02-19 04:25:25 +08:00
Luis Pater
9c040445af Merge pull request #1635 from thebtf/fix/openai-translator-tool-streaming
fix: handle tool call argument streaming in Codex→OpenAI translator
2026-02-19 04:22:12 +08:00
Luis Pater
fff866424e Merge pull request #1628 from thebtf/fix/masquerading-headers
fix: update Claude masquerading headers and configurable defaults
2026-02-19 04:19:59 +08:00
Luis Pater
2d12becfd6 Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping
fix: clamp reasoning_effort to valid OpenAI-format values
2026-02-19 04:15:19 +08:00
Luis Pater
252f7e0751 Merge pull request #1625 from thebtf/feat/tool-prefix-config
feat: add per-auth tool_prefix_disabled option
2026-02-19 04:07:22 +08:00
Luis Pater
b2b17528cb Merge branch 'pr-1624' into dev
# Conflicts:
#	internal/runtime/executor/claude_executor.go
#	internal/runtime/executor/claude_executor_test.go
2026-02-19 04:05:04 +08:00
Luis Pater
55f938164b Merge pull request #1618 from alexey-yanchenko/fix/completions-usage
Fix empty usage in /v1/completions
2026-02-19 03:57:11 +08:00
Luis Pater
76294f0c59 Merge pull request #1608 from thebtf/fix/tool-reference-proxy-prefix-mainline
fix: add proxy_ prefix handling for tool_reference content blocks
2026-02-19 03:50:34 +08:00
Luis Pater
2bcee78c6e feat(tui): add standalone mode and API-based log polling
- Implemented `--standalone` mode to launch an embedded server for TUI.
- Enhanced TUI client to support API-based log polling when log hooks are unavailable.
- Added authentication gate for password input and connection handling.
- Improved localization and UX for logs, authentication, and status bar rendering.
2026-02-19 03:19:18 +08:00
Luis Pater
7fe8246a9f Merge branch 'tui' into dev 2026-02-19 03:18:24 +08:00
Luis Pater
93fe58e31e feat(tui): add standalone mode and API-based log polling
- Implemented `--standalone` mode to launch an embedded server for TUI.
- Enhanced TUI client to support API-based log polling when log hooks are unavailable.
- Added authentication gate for password input and connection handling.
- Improved localization and UX for logs, authentication, and status bar rendering.
2026-02-19 03:18:08 +08:00
Luis Pater
e5b5dc870f chore(executor): remove unused Openai-Beta header from Codex executor 2026-02-19 02:19:48 +08:00
Luis Pater
a54877c023 Merge branch 'dev' 2026-02-19 02:03:41 +08:00
Luis Pater
bb86a0c0c4 feat(logging, executor): add request logging tests and WebSocket-based Codex executor
- Introduced unit tests for request logging middleware to enhance coverage.
- Added WebSocket-based Codex executor to support Responses API upgrade.
- Updated middleware logic to selectively capture request bodies for memory efficiency.
- Enhanced Codex configuration handling with new WebSocket attributes.
2026-02-19 01:57:02 +08:00
Kirill Turanskiy
5fa23c7f41 fix: handle tool call argument streaming in Codex→OpenAI translator
The OpenAI Chat Completions translator was silently dropping
response.function_call_arguments.delta and
response.function_call_arguments.done Codex SSE events, meaning
tool call arguments were never streamed incrementally to clients.

Add proper handling mirroring the proven Claude translator pattern:

- response.output_item.added: announce tool call (id, name, empty args)
- response.function_call_arguments.delta: stream argument chunks
- response.function_call_arguments.done: emit full args if no deltas
- response.output_item.done: defensive fallback for backward compat

State tracking via HasReceivedArgumentsDelta and HasToolCallAnnounced
ensures no duplicate argument emission and correct behavior for models
like codex-spark that skip delta events entirely.
2026-02-18 19:09:05 +03:00
gl11tchy
f9a09b7f23 style: sort model entries per review feedback 2026-02-18 15:06:28 +00:00
Joao
b0cde626fe feat: add Claude Sonnet 4.6 model support for Kiro provider 2026-02-18 13:51:23 +00:00
gl11tchy
e42ef9a95d feat(registry): add Claude Sonnet 4.6 model definitions
Add claude-sonnet-4-6 to:
- Claude OAuth provider (model_definitions_static_data.go)
- Antigravity model config (thinking + non-thinking entries)
- GitHub Copilot provider (model_definitions.go)

Ref: https://docs.anthropic.com/en/docs/about-claude/models
2026-02-18 13:43:22 +00:00
ultraplan-bit
abf1629ec7 Merge branch 'main' of https://github.com/ultraplan-bit/CLIProxyAPIPlus 2026-02-18 08:56:06 +08:00
Kirill Turanskiy
73dc0b10b8 fix: update Claude masquerading headers and make them configurable
Update hardcoded X-Stainless-* and User-Agent defaults to match
Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (verified via
diagnostic proxy capture 2026-02-17).

Changes:
- X-Stainless-Os/Arch: dynamic via runtime.GOOS/GOARCH
- X-Stainless-Package-Version: 0.55.1 → 0.74.0
- X-Stainless-Timeout: 60 → 600
- User-Agent: claude-cli/1.0.83 (external, cli) → claude-cli/2.1.44 (external, sdk-cli)

Add claude-header-defaults config section so values can be updated
without recompilation when Claude Code releases new versions.
2026-02-18 03:38:51 +03:00
Kirill Turanskiy
2ea95266e3 fix: clamp reasoning_effort to valid OpenAI-format values
CPA-internal thinking levels like 'xhigh' and 'minimal' are not accepted
by OpenAI-format providers (MiniMax, etc.). The OpenAI applier now maps
non-standard levels to the nearest valid reasoning_effort value before
writing to the request body:

  xhigh   → high
  minimal → low
  auto    → medium
2026-02-18 03:36:42 +03:00
Tony
922d4141c0 feat(registry): add Sonnet 4.6 to GitHub Copilot provider 2026-02-18 05:17:23 +08:00
Kirill Turanskiy
1f8f198c45 feat: passthrough upstream response headers to clients
CPA previously stripped ALL response headers from upstream AI provider
APIs, preventing clients from seeing rate-limit info, request IDs,
server-timing and other useful headers.

Changes:
- Add Headers field to Response and StreamResult structs
- Add FilterUpstreamHeaders helper (hop-by-hop + security denylist)
- Add WriteUpstreamHeaders helper (respects CPA-set headers)
- ExecuteWithAuthManager/ExecuteCountWithAuthManager now return headers
- ExecuteStreamWithAuthManager returns headers from initial connection
- All 11 provider executors populate Response.Headers
- All handler call sites write filtered upstream headers before response

Filtered headers (not forwarded):
- RFC 7230 hop-by-hop: Connection, Transfer-Encoding, Keep-Alive, etc.
- Security: Set-Cookie
- CPA-managed: Content-Length, Content-Encoding
2026-02-18 00:16:22 +03:00
Tony
c55275342c feat(registry): add GPT-5.3 Codex to GitHub Copilot provider 2026-02-18 03:04:27 +08:00
Kirill Turanskiy
9261b0c20b feat: add per-auth tool_prefix_disabled option
Allow disabling the proxy_ tool name prefix on a per-account basis.
Users who route their own Anthropic account through CPA can set
"tool_prefix_disabled": true in their OAuth auth JSON to send tool
names unchanged to Anthropic.

Default behavior is fully preserved — prefix is applied unless
explicitly disabled.

Changes:
- Add ToolPrefixDisabled() accessor to Auth (reads metadata key
  "tool_prefix_disabled" or "tool-prefix-disabled")
- Gate all 6 prefix apply/strip points with the new flag
- Add unit tests for the accessor
2026-02-17 21:48:19 +03:00
Kirill Turanskiy
7cc725496e fix: skip proxy_ prefix for built-in tools in message history
The proxy_ prefix logic correctly skips built-in tools (those with a
non-empty "type" field) in tools[] definitions but does not skip them
in messages[].content[] tool_use blocks or tool_choice. This causes
web_search in conversation history to become proxy_web_search, which
Anthropic does not recognize.

Fix: collect built-in tool names from tools[] into a set and also
maintain a hardcoded fallback set (web_search, code_execution,
text_editor, computer) for cases where the built-in tool appears in
history but not in the current request's tools[] array. Skip prefixing
in messages and tool_choice when name matches a built-in.
2026-02-17 21:42:32 +03:00
ultraplan-bit
5726a99c80 Improve Copilot provider based on ericc-ch/copilot-api comparison
- Fix X-Initiator detection: check for any assistant/tool role
  in messages instead of only the last message role, matching
  the correct agent detection for multi-turn tool conversations
- Add x-github-api-version: 2025-04-01 header for API compatibility
- Support Business/Enterprise accounts by using Endpoints.API from
  the Copilot token response instead of hardcoded base URL
- Fix Responses API vision detection: detect vision content before
  input normalization removes the messages array
- Add 8 test cases covering the above fixes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 22:11:17 +08:00
ultraplan-bit
b5756bf729 Fix Copilot 0x model incorrectly consuming premium requests
Change Openai-Intent header from "conversation-edits" to
"conversation-panel" to avoid triggering GitHub's premium
execution path, which caused included models (0x multiplier)
to be billed as premium requests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 21:17:18 +08:00
Alexey Yanchenko
709d999f9f Add usage to /v1/completions 2026-02-17 17:21:03 +07:00
Kirill Turanskiy
24c18614f0 fix: skip built-in tools in tool_reference prefix + refactor to switch
- Collect built-in tool names (those with a "type" field like
  web_search, code_execution) and skip prefixing tool_reference
  blocks that reference them, preventing name mismatch.
- Refactor if-else if chains to switch statements in all three
  prefix functions for idiomatic Go style.
2026-02-16 19:37:11 +03:00
Kirill Turanskiy
603f06a762 fix: handle tool_reference nested inside tool_result.content[]
tool_reference blocks can appear nested inside tool_result.content[]
arrays, not just at the top level of messages[].content[]. The prefix
logic now iterates into tool_result blocks with array content to find
and prefix/strip nested tool_reference.tool_name fields.
2026-02-16 19:06:24 +03:00
Kirill Turanskiy
98f0a3e3bd fix: add proxy_ prefix handling for tool_reference content blocks (#1)
applyClaudeToolPrefix, stripClaudeToolPrefixFromResponse, and
stripClaudeToolPrefixFromStreamLine now handle "tool_reference" blocks
(field "tool_name") in addition to "tool_use" blocks (field "name").

Without this fix, tool_reference blocks in conversation history retain
their original unprefixed names while tool definitions carry the proxy_
prefix, causing Anthropic API 400 errors: "Tool reference 'X' not found
in available tools."

Co-authored-by: Kirill Turanskiy <kt@novamedia.ru>
2026-02-16 19:06:24 +03:00
Luis Pater
e186ccb0d4 Merge pull request #234 from detroittommy879/feature/add-kilocode-provider
Add Kilo Code provider with dynamic model fetching
2026-02-16 23:54:29 +08:00
Luis Pater
8fc0b08b70 Merge pull request #233 from ultraplan-bit/fix/copilot-codex-responses-translation
Fix Copilot codex model Responses API translation for Claude Code
2026-02-16 23:51:42 +08:00
Luis Pater
52a257dc24 Merge pull request #237 from router-for-me/plus
v6.8.18
2026-02-16 23:50:00 +08:00
Luis Pater
a12d907f55 Merge branch 'main' into plus 2026-02-16 23:49:50 +08:00
Luis Pater
453aaf8774 chore(runtime): update Qwen executor user agent and headers for compatibility with new runtime standards 2026-02-16 23:29:47 +08:00
Supra4E8C
1b1ab1fb9b Merge pull request #1606 from router-for-me/add-qwen-3.5
feat(registry): add Qwen 3.5 Plus model definitions
2026-02-16 23:10:53 +08:00
Supra4E8C
a9d0bb72da feat(registry): add Qwen 3.5 Plus model definitions 2026-02-16 22:55:37 +08:00
DetroitTommy
d328e54e4b refactor(kilo): address code review suggestions for robustness 2026-02-15 17:26:29 -05:00
DetroitTommy
5a7932cba4 Added Kilo Code as a provider, with auth. It fetches the free models, tested them (works), for paid models someone will have to experiment so only the free ones are known to work 2026-02-15 14:54:20 -05:00
DetroitTommy
1dbeb0827a added kilocode auth, needs adjusting 2026-02-15 13:44:26 -05:00
lhpqaq
2c8821891c fix(tui): update with review 2026-02-16 00:24:25 +08:00
haopeng
0a2555b0f3 Update internal/tui/auth_tab.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-16 00:11:31 +08:00
lhpqaq
020df41efe chore(tui): update readme, fix usage 2026-02-16 00:04:04 +08:00
ultraplan-bit
f8f8cf17ce Fix Copilot codex model Responses API translation for Claude Code
- Add response.function_call_arguments.delta handler for tool call parameters
- Rewrite normalizeGitHubCopilotResponsesInput to produce structured input
  array (message/function_call/function_call_output) instead of flattened
  text, fixing infinite loop in multi-turn tool-use conversations
- Skip flattenAssistantContent for messages containing tool_use blocks,
  preventing function_call items from being destroyed
- Add reasoning/thinking stream & non-stream support
- Fix stop_reason mapping (max_tokens/stop) and cached token reporting
- Update test to match new array-based input format

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 18:04:45 +08:00
lhpqaq
f31f7f701a feat(tui): add i18n 2026-02-15 15:42:59 +08:00
Supra4E8C
b5fe78eb70 Merge pull request #1597 from router-for-me/kimi-fix
feat(registry): add support for 'kimi' channel in model definitions
2026-02-15 15:35:17 +08:00
Supra4E8C
d1f667cf8d feat(registry): add support for 'kimi' channel in model definitions 2026-02-15 15:21:33 +08:00
lhpqaq
54ad7c1b6b feat(tui): add manager tui 2026-02-15 14:52:40 +08:00
Luis Pater
d560c20c26 Merge branch 'router-for-me:main' into main 2026-02-15 14:49:13 +08:00
Luis Pater
5abeca1f9e Merge pull request #231 from ChrAlpha/main
feat(models): add Thinking support to GitHub Copilot models
2026-02-15 14:48:04 +08:00
Luis Pater
294eac3a88 Merge branch 'main' into main 2026-02-15 14:47:52 +08:00
Luis Pater
a31104020c Merge pull request #230 from ultraplan-bit/main
fix(copilot): forward Claude-format tools to Copilot Responses API
2026-02-15 14:45:27 +08:00
Luis Pater
65bec4d734 Merge pull request #229 from Buywatermelon/fix/issue-222-kiro-alias-deletion
fix: preserve explicitly deleted kiro aliases across config reload
2026-02-15 14:42:42 +08:00
Luis Pater
edb2993838 Merge pull request #228 from xilu0/fix/antigravity-fetch-models-logging
fix(antigravity): add warn-level logging to silent failure paths in FetchAntigravityModels
2026-02-15 14:42:13 +08:00
Luis Pater
c0d8e0dec7 Merge pull request #226 from Skyuno/refactor/websearch-alignment
refactor(kiro): Kiro Web Search Logic & Executor Alignment
2026-02-15 14:41:46 +08:00
ChrAlpha
795da13d5d feat(tests): add comprehensive GitHub Copilot tests for reasoning effort levels 2026-02-15 06:40:52 +00:00
Luis Pater
55789df275 chore(docker): update Go base image to 1.26-alpine 2026-02-15 14:26:44 +08:00
ChrAlpha
9e652a3540 fix(github-copilot): remove 'xhigh' level from Thinking support 2026-02-15 06:12:08 +00:00
Luis Pater
46a6782065 refactor(all): replace manual pointer assignments with new to enhance code readability and maintainability 2026-02-15 14:10:10 +08:00
Luis Pater
c359f61859 fix(auth): normalize Gemini credential file prefix for consistency 2026-02-15 13:59:33 +08:00
Luis Pater
908c8eab5b Merge pull request #1543 from sususu98/feat/gemini-cli-google-one
feat(gemini-cli): add Google One login and improve auto-discovery
2026-02-15 13:58:21 +08:00
Luis Pater
f5f2c69233 Merge pull request #1595 from alexey-yanchenko/feature/cache-usage-from-codex-to-chat-completions
Pass cache usage from codex to openai chat completions
2026-02-15 13:56:46 +08:00
Alexey Yanchenko
63d4de5eea Pass cache usage from codex to openai chat completions 2026-02-15 12:04:15 +07:00
ChrAlpha
af15083496 feat(models): add Thinking support to GitHub Copilot models
Enhance the model definitions by introducing Thinking support with various levels for each model.
2026-02-15 03:16:08 +00:00
ultraplan-bit
c4722e42b1 fix(copilot): forward Claude-format tools to Copilot Responses API
The normalizeGitHubCopilotResponsesTools filter required type="function",
which dropped Claude-format tools (no type field, uses input_schema).
Relax the filter to accept tools without a type field and map input_schema
to parameters so tools are correctly sent to the upstream API.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 21:58:15 +08:00
Dave
f9a991365f Update internal/runtime/executor/antigravity_executor.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-14 10:56:36 +08:00
y
6df16bedba fix: preserve explicitly deleted kiro aliases across config reload (#222)
The delete handler now sets the channel value to nil instead of removing
the map key, and the sanitization loop preserves nil/empty channel entries
as 'disabled' markers.  This prevents SanitizeOAuthModelAlias from
re-injecting default kiro aliases after a user explicitly deletes them
through the management API.
2026-02-14 09:40:05 +08:00
Skyuno
632a2fd2f2 refactor: align GenerateSearchIndicatorEvents return type with other event builders
Change GenerateSearchIndicatorEvents to return [][]byte instead of []sseEvent
for consistency with BuildFallbackTextEvents and other event building functions.

Benefits:
- Consistent API across all event generation functions
- Eliminates intermediate sseEvent type conversion in caller
- Simplifies usage by returning ready-to-send SSE byte slices

This addresses the code quality feedback from PR #226 review.
2026-02-13 22:04:09 +08:00
Skyuno
5626637fbd security: remove query content from web search logs to prevent PII leakage
- Remove search query from iteration logs (Info level)
- Remove query and toolUseId from analysis logs (Info level)
- Remove query from non-stream result logs (Info level)
- Remove query from tool injection logs (Info level)
- Remove query from tool_use detection logs (Debug level)

This addresses the security concern raised in PR #226 review about
potential PII exposure in search query logs.
2026-02-13 22:04:09 +08:00
Skyuno
2db89211a9 kiro: use payloadRequestedModel for response model name
Align Kiro executor with all other executors (Claude, Gemini, OpenAI,
etc.) by using payloadRequestedModel(opts, req.Model) instead of
req.Model when constructing response model names.

This ensures model aliases are correctly reflected in responses:
- Execute: BuildClaudeResponse + TranslateNonStream
- ExecuteStream: streamToChannel
- handleWebSearchStream: BuildClaudeMessageStartEvent
- handleWebSearch: via executeNonStreamFallback (automatic)

Previously Kiro was the only executor using req.Model directly,
which exposed internal routed names instead of the user's alias.
2026-02-13 22:04:09 +08:00
Skyuno
587371eb14 refactor: align web search with executor layer patterns
Consolidate web search handler, SSE event generation, stream analysis,
and MCP HTTP I/O into the executor layer. Merge the separate
kiro_websearch_handler.go back into kiro_executor.go to align with
the single-file-per-executor convention. Translator retains only pure
data types, detection, and payload transformation.

Key changes:
- Move SSE construction (search indicators, fallback text, message_start)
  from translator to executor, consistent with streamToChannel pattern
- Move MCP handler (callMcpAPI, setMcpHeaders, fetchToolDescription)
  from translator to executor alongside other HTTP I/O
- Reuse applyDynamicFingerprint for MCP UA headers (eliminate duplication)
- Centralize MCP endpoint URL via BuildMcpEndpoint in translator
- Add atomic Set/GetWebSearchDescription for cross-layer tool desc cache
- Thread context.Context through MCP HTTP calls for cancellation support
- Thread usage reporter through all web search API call paths
- Add token expiry pre-check before MCP/GAR calls
- Clean up dead code (GenerateMessageID, webSearchAuthContext fp logic,
  ContainsWebSearchTool, StripWebSearchTool)
2026-02-13 22:04:09 +08:00
xiluo
75818b1e25 fix(antigravity): add warn-level logging to silent failure paths in FetchAntigravityModels
Add log.Warnf calls to all 7 silent return nil paths so operators can
diagnose why specific antigravity accounts fail to fetch models and get
unregistered without any log trail.

Covers: token errors, request creation failures, context cancellation,
network errors (after exhausting fallback URLs), body read errors,
unexpected HTTP status codes, and missing models field in response.
2026-02-13 18:01:46 +08:00
이대희
a45c6defa7 Merge remote-tracking branch 'upstream/main' into feature/canceled 2026-02-13 15:07:32 +09:00
Luis Pater
cbe56955a9 Merge pull request #227 from router-for-me/plus
v6.8.15
2026-02-13 12:50:52 +08:00
Luis Pater
8ea6ac913d Merge branch 'main' into plus 2026-02-13 12:50:39 +08:00
Luis Pater
ae1e8a5191 chore(runtime, registry): update Codex client version and GPT-5.3 model creation date 2026-02-13 12:47:48 +08:00
Luis Pater
b3ccc55f09 Merge pull request #1574 from fbettag/feat/gpt-5.3-codex-spark
feat(registry): add gpt-5.3-codex-spark model definition
2026-02-13 12:46:08 +08:00
이대희
40bee3e8d9 Merge branch 'main' into feature/canceled 2026-02-13 13:37:55 +09:00
Franz Bettag
1ce56d7413 Update internal/registry/model_definitions_static_data.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-12 23:37:27 +01:00
Franz Bettag
41a78be3a2 feat(registry): add gpt-5.3-codex-spark model definition 2026-02-12 23:24:08 +01:00
Luis Pater
1ff5de9a31 docs(readme): add CLIProxyAPI Dashboard to project list 2026-02-13 00:40:39 +08:00
Luis Pater
46a6853046 Merge pull request #1568 from itsmylife44/add-cliproxyapi-dashboard
Add CLIProxyAPI Dashboard to 'Who is with us?' section
2026-02-13 00:37:41 +08:00
xSpaM
4b2d40bd67 Add CLIProxyAPI Dashboard to 'Who is with us?' section 2026-02-12 17:15:46 +01:00
Luis Pater
726f1a590c Merge branch 'router-for-me:main' into main 2026-02-12 22:43:44 +08:00
Luis Pater
575881cb59 feat(registry): add new model definition for MiniMax-M2.5 2026-02-12 22:43:01 +08:00
Luis Pater
d02df0141b Merge pull request #224 from Buywatermelon/fix/kiro-assistant-first-message
fix(kiro): prepend placeholder user message when conversation starts with assistant role
2026-02-12 15:11:10 +08:00
Luis Pater
e4bc9da913 Merge pull request #220 from jellyfish-p/main
fix(kiro): 修复之前提交的错误的application/cbor请求处理逻辑
2026-02-12 15:10:42 +08:00
Luis Pater
8c6be49625 Merge pull request #218 from ClubWeGo/fix/merge-assistant-tool-calls
fix: prevent merging assistant messages with tool_calls
2026-02-12 15:10:00 +08:00
Luis Pater
c727e4251f ci(github): trigger Docker image workflow on version tags matching v* 2026-02-12 15:09:16 +08:00
Luis Pater
99266be998 Merge pull request #216 from starsdream666/main
增加kiro新模型并根据其他提供商同模型配置Thinking
2026-02-12 15:08:37 +08:00
y
086d8d0d0b fix(kiro): prepend placeholder user message when conversation starts with assistant role
Kiro/AmazonQ API requires the conversation history to start with a user message.
Some clients (e.g., OpenClaw) send conversations starting with an assistant message,
which is valid for the native Claude API but causes 'Improperly formed request' (400)
on the Kiro endpoint.

This fix detects when the first message has role=assistant and prepends a minimal
placeholder user message ('.') to satisfy the Kiro API's message ordering requirement.

Upstream error: {"message":"Improperly formed request.","reason":null}
Verified: original request returns 400, fixed request returns 200.
2026-02-12 11:09:47 +08:00
jellyfish-p
627dee1dac fix(kiro): 修复之前提交的错误的application/cbor请求处理逻辑 2026-02-12 09:57:34 +08:00
이대희
93147dddeb Improves error handling for canceled requests
Adds explicit handling for context.Canceled errors in the reverse proxy error handler to return 499 status code without logging, which is more appropriate for client-side cancellations during polling.

Also adds a test case to verify this behavior.
2026-02-12 10:39:45 +09:00
이대희
c0f9b15a58 Merge remote-tracking branch 'upstream/main' into feature/canceled 2026-02-12 10:33:49 +09:00
이대희
6f2fbdcbae Update internal/api/modules/amp/proxy.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-12 10:30:05 +09:00
Darley
55c3197fb8 fix(kiro): merge adjacent assistant messages while preserving tool_calls 2026-02-12 07:30:36 +08:00
HEUDavid
65debb874f feat/auth-hook: refactor RequstInfo to preserve original HTTP semantics 2026-02-12 07:11:17 +08:00
HEUDavid
3caadac003 feat/auth-hook: add post auth hook [CR] 2026-02-12 07:11:17 +08:00
HEUDavid
6a9e3a6b84 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
269972440a feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
cce13e6ad2 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
8a565dcad8 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
d536110404 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
48e957ddff feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
94563d622c feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
Darley
5a2cf0d53c fix: prevent merging assistant messages with tool_calls
Adjacent assistant messages where any message contains tool_calls
were being merged by MergeAdjacentMessages, causing tool_calls to
be silently dropped. This led to orphaned tool results that could
not match any toolUse in history, resulting in Kiro API returning
'Improperly formed request.'

Now assistant messages with tool_calls are kept separate during
merge, preserving the tool call chain integrity.
2026-02-12 01:53:40 +08:00
starsdream666
2573358173 根据其他提供商同模型配置Thinking 2026-02-12 00:41:13 +08:00
starsdream666
09cd3cff91 增加kiro新模型:deepseek-3.2,minimax-m2.1,qwen3-coder-next,gpt-4o,gpt-4,gpt-4-turbo,gpt-3.5-turbo 2026-02-12 00:35:24 +08:00
starsdream666
ab0bf1b517 Merge branch 'router-for-me:main' into main 2026-02-11 16:20:20 +00:00
starsdream666
544238772a Merge branch 'router-for-me:main' into main 2026-02-11 10:58:06 +00:00
sususu98
f3ccd85ba1 feat(gemini-cli): add Google One login and improve auto-discovery
Add Google One personal account login to Gemini CLI OAuth flow:
- CLI --login shows mode menu (Code Assist vs Google One)
- Web management API accepts project_id=GOOGLE_ONE sentinel
- Auto-discover project via onboardUser without cloudaicompanionProject when project is unresolved

Improve robustness of auto-discovery and token handling:
- Add context-aware auto-discovery polling (30s timeout, 2s interval)
- Distinguish network errors from project-selection-required errors
- Refresh expired access tokens in readAuthFile before project lookup
- Extend project_id auto-fill to gemini auth type (was antigravity-only)

Unify credential file naming to geminicli- prefix for both CLI and web.

Add extractAccessToken unit tests (9 cases).
2026-02-11 17:53:03 +08:00
starsdream666
9c65e17a21 Merge branch 'router-for-me:main' into main 2026-02-10 14:41:20 +00:00
이대희
ce0c6aa82b Update internal/api/modules/amp/proxy.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-10 19:07:49 +09:00
이대희
3c85d2a4d7 feature(proxy): Adds special handling for client cancellations in proxy error handler
Silences logging for client cancellations during polling to reduce noise in logs.
Client-side cancellations are common during long-running operations and should not be treated as errors.
2026-02-10 18:02:08 +09:00
starsdream666
15bc99f6ea Merge branch 'router-for-me:main' into main 2026-02-10 01:45:05 +00:00
starsdream666
3ec7991e5f Merge branch 'router-for-me:main' into main 2026-02-09 14:18:04 +00:00
starsdream666
40e85a6759 Merge branch 'router-for-me:main' into main 2026-02-07 16:37:51 +00:00
starsdream666
cc116ce67d Merge branch 'router-for-me:main' into main 2026-02-06 16:11:26 +00:00
starsdream666
40efc2ba43 修改工作流 2026-02-06 03:29:31 +00:00
157 changed files with 16575 additions and 1892 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- run: git fetch --force --tags
- uses: actions/setup-go@v4
with:
go-version: '>=1.24.0'
go-version: '>=1.26.0'
cache: true
- name: Generate Build Metadata
run: |

3
.gitignore vendored
View File

@@ -3,10 +3,11 @@ cli-proxy-api
cliproxy
*.exe
# Configuration
config.yaml
.env
.mcp.json
# Generated content
bin/*
logs/*

View File

@@ -1,4 +1,4 @@
FROM golang:1.24-alpine AS builder
FROM golang:1.26-alpine AS builder
WORKDIR /app

View File

@@ -8,6 +8,7 @@ import (
"errors"
"flag"
"fmt"
"io"
"io/fs"
"net/url"
"os"
@@ -26,6 +27,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
@@ -70,8 +72,10 @@ func main() {
// Command-line flags to control the application's behavior.
var login bool
var codexLogin bool
var codexDeviceLogin bool
var claudeLogin bool
var qwenLogin bool
var kiloLogin bool
var iflowLogin bool
var iflowCookie bool
var noBrowser bool
@@ -88,14 +92,18 @@ func main() {
var vertexImport string
var configPath string
var password string
var tuiMode bool
var standalone bool
var noIncognito bool
var useIncognito bool
// Define command-line flags for different operation modes.
flag.BoolVar(&login, "login", false, "Login Google Account")
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
@@ -114,6 +122,8 @@ func main() {
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
flag.StringVar(&password, "password", "", "")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
flag.CommandLine.Usage = func() {
out := flag.CommandLine.Output()
@@ -494,11 +504,16 @@ func main() {
} else if codexLogin {
// Handle Codex login
cmd.DoCodexLogin(cfg, options)
} else if codexDeviceLogin {
// Handle Codex device-code login
cmd.DoCodexDeviceLogin(cfg, options)
} else if claudeLogin {
// Handle Claude login
cmd.DoClaudeLogin(cfg, options)
} else if qwenLogin {
cmd.DoQwenLogin(cfg, options)
} else if kiloLogin {
cmd.DoKiloLogin(cfg, options)
} else if iflowLogin {
cmd.DoIFlowLogin(cfg, options)
} else if iflowCookie {
@@ -536,15 +551,89 @@ func main() {
cmd.WaitForCloudDeploy()
return
}
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
if tuiMode {
if standalone {
// Standalone mode: start an embedded local server and connect TUI client to it.
managementasset.StartAutoUpdater(context.Background(), configFilePath)
hook := tui.NewLogHook(2000)
hook.SetFormatter(&logging.LogFormatter{})
log.AddHook(hook)
// 初始化并启动 Kiro token 后台刷新
if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg)
defer kiro.StopGlobalRefreshManager()
origStdout := os.Stdout
origStderr := os.Stderr
origLogOutput := log.StandardLogger().Out
log.SetOutput(io.Discard)
devNull, errOpenDevNull := os.Open(os.DevNull)
if errOpenDevNull == nil {
os.Stdout = devNull
os.Stderr = devNull
}
restoreIO := func() {
os.Stdout = origStdout
os.Stderr = origStderr
log.SetOutput(origLogOutput)
if devNull != nil {
_ = devNull.Close()
}
}
localMgmtPassword := fmt.Sprintf("tui-%d-%d", os.Getpid(), time.Now().UnixNano())
if password == "" {
password = localMgmtPassword
}
cancel, done := cmd.StartServiceBackground(cfg, configFilePath, password)
client := tui.NewClient(cfg.Port, password)
ready := false
backoff := 100 * time.Millisecond
for i := 0; i < 30; i++ {
if _, errGetConfig := client.GetConfig(); errGetConfig == nil {
ready = true
break
}
time.Sleep(backoff)
if backoff < time.Second {
backoff = time.Duration(float64(backoff) * 1.5)
}
}
if !ready {
restoreIO()
cancel()
<-done
fmt.Fprintf(os.Stderr, "TUI error: embedded server is not ready\n")
return
}
if errRun := tui.Run(cfg.Port, password, hook, origStdout); errRun != nil {
restoreIO()
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
} else {
restoreIO()
}
cancel()
<-done
} else {
// Default TUI mode: pure management client.
// The proxy server must already be running.
if errRun := tui.Run(cfg.Port, password, nil, os.Stdout); errRun != nil {
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
}
}
} else {
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg)
defer kiro.StopGlobalRefreshManager()
}
cmd.StartService(cfg, configFilePath, password)
}
cmd.StartService(cfg, configFilePath, password)
}
}

View File

@@ -1,6 +1,6 @@
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
host: ""
host: ''
# Server port
port: 8317
@@ -8,8 +8,8 @@ port: 8317
# TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
tls:
enable: false
cert: ""
key: ""
cert: ''
key: ''
# Management API settings
remote-management:
@@ -20,22 +20,22 @@ remote-management:
# Management key. If a plaintext value is provided here, it will be hashed on startup.
# All management requests (even from localhost) require this key.
# Leave empty to disable the Management API entirely (404 for all /v0/management routes).
secret-key: ""
secret-key: ''
# Disable the bundled management control panel asset download and HTTP route when true.
disable-control-panel: false
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
# Authentication directory (supports ~ for home directory)
auth-dir: "~/.cli-proxy-api"
auth-dir: '~/.cli-proxy-api'
# API keys for authentication
api-keys:
- "your-api-key-1"
- "your-api-key-2"
- "your-api-key-3"
- 'your-api-key-1'
- 'your-api-key-2'
- 'your-api-key-3'
# Enable debug logging
debug: false
@@ -43,7 +43,7 @@ debug: false
# Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety.
pprof:
enable: false
addr: "127.0.0.1:8316"
addr: '127.0.0.1:8316'
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
commercial-mode: false
@@ -68,11 +68,15 @@ error-logs-max-files: 10
usage-statistics-enabled: false
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
proxy-url: ""
proxy-url: ''
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
force-model-prefix: false
# When true, forward filtered upstream response headers to downstream clients.
# Default is false (disabled).
passthrough-headers: false
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
request-retry: 3
@@ -86,7 +90,7 @@ quota-exceeded:
# Routing strategy for selecting credentials when multiple match.
routing:
strategy: "round-robin" # round-robin (default), fill-first
strategy: 'round-robin' # round-robin (default), fill-first
# When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false
@@ -160,6 +164,15 @@ nonstream-keepalive-interval: 0
# sensitive-words: # optional: words to obfuscate with zero-width characters
# - "API"
# - "proxy"
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
# Default headers for Claude API requests. Update when Claude Code releases new versions.
# These are used as fallbacks when the client does not send its own headers.
# claude-header-defaults:
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
# package-version: "0.74.0"
# runtime-version: "v24.3.0"
# timeout: "600"
# Kiro (AWS CodeWhisperer) configuration
# Note: Kiro API currently only operates in us-east-1 region
@@ -171,6 +184,21 @@ nonstream-keepalive-interval: 0
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override
# Kilocode (OAuth-based code assistant)
# Note: Kilocode uses OAuth device flow authentication.
# Use the CLI command: ./server --kilo-login
# This will save credentials to the auth directory (default: ~/.cli-proxy-api/)
# oauth-model-alias:
# kilo:
# - name: "minimax/minimax-m2.5:free"
# alias: "minimax-m2.5"
# - name: "z-ai/glm-5:free"
# alias: "glm-5"
# oauth-excluded-models:
# kilo:
# - "kilo-claude-opus-4-6" # exclude specific models (exact match)
# - "*:free" # wildcard matching suffix (e.g. all free models)
# OpenAI compatibility providers
# openai-compatibility:
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.

View File

@@ -159,13 +159,13 @@ func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request,
return clipexec.Response{}, errors.New("count tokens not implemented")
}
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) {
ch := make(chan clipexec.StreamChunk, 1)
go func() {
defer close(ch)
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
}()
return ch, nil
return &clipexec.StreamResult{Chunks: ch}, nil
}
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {

View File

@@ -58,7 +58,7 @@ func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, c
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
}
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (*clipexec.StreamResult, error) {
return nil, errors.New("echo executor: ExecuteStream not implemented")
}

23
go.mod
View File

@@ -1,9 +1,13 @@
module github.com/router-for-me/CLIProxyAPI/v6
go 1.24.0
go 1.26.0
require (
github.com/andybalholm/brotli v1.0.6
github.com/atotto/clipboard v0.1.4
github.com/charmbracelet/bubbles v1.0.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/fsnotify/fsnotify v1.9.0
github.com/fxamacker/cbor/v2 v2.9.0
github.com/gin-gonic/gin v1.10.1
@@ -33,8 +37,16 @@ require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/ProtonMail/go-crypto v1.3.0 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/ansi v0.11.6 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.9.0 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
github.com/cloudflare/circl v1.6.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
@@ -42,6 +54,7 @@ require (
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-git/gcfg/v2 v2.0.2 // indirect
@@ -58,19 +71,27 @@ require (
github.com/kevinburke/ssh_config v1.4.0 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/minio/sha256-simd v1.0.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pjbgf/sha1cd v0.5.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/sergi/go-diff v1.4.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/sys v0.38.0 // indirect

45
go.sum
View File

@@ -10,10 +10,34 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
@@ -33,6 +57,8 @@ github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
@@ -101,8 +127,14 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
@@ -114,6 +146,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
@@ -124,6 +162,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
@@ -161,6 +201,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
@@ -168,12 +210,15 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -1,6 +1,7 @@
package management
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -189,9 +190,21 @@ func (h *Handler) APICall(c *gin.Context) {
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
}
// When caller indicates CBOR in request headers, convert JSON string payload to CBOR bytes.
useCBORPayload := headerContainsValue(reqHeaders, "Content-Type", "application/cbor")
var requestBody io.Reader
if body.Data != "" {
requestBody = strings.NewReader(body.Data)
if useCBORPayload {
cborPayload, errEncode := encodeJSONStringToCBOR(body.Data)
if errEncode != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json data for cbor content-type"})
return
}
requestBody = bytes.NewReader(cborPayload)
} else {
requestBody = strings.NewReader(body.Data)
}
}
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
@@ -234,10 +247,18 @@ func (h *Handler) APICall(c *gin.Context) {
return
}
// For CBOR upstream responses, decode into plain text or JSON string before returning.
responseBodyText := string(respBody)
if headerContainsValue(reqHeaders, "Accept", "application/cbor") || strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "application/cbor") {
if decodedBody, errDecode := decodeCBORBodyToTextOrJSON(respBody); errDecode == nil {
responseBodyText = decodedBody
}
}
response := apiCallResponse{
StatusCode: resp.StatusCode,
Header: resp.Header,
Body: string(respBody),
Body: responseBodyText,
}
// If this is a GitHub Copilot token endpoint response, try to enrich with quota information
@@ -747,6 +768,83 @@ func buildProxyTransport(proxyStr string) *http.Transport {
return nil
}
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
func headerContainsValue(headers map[string]string, targetKey, targetValue string) bool {
if len(headers) == 0 {
return false
}
for key, value := range headers {
if !strings.EqualFold(strings.TrimSpace(key), strings.TrimSpace(targetKey)) {
continue
}
if strings.Contains(strings.ToLower(value), strings.ToLower(strings.TrimSpace(targetValue))) {
return true
}
}
return false
}
// encodeJSONStringToCBOR converts a JSON string payload into CBOR bytes.
func encodeJSONStringToCBOR(jsonString string) ([]byte, error) {
var payload any
if errUnmarshal := json.Unmarshal([]byte(jsonString), &payload); errUnmarshal != nil {
return nil, errUnmarshal
}
return cbor.Marshal(payload)
}
// decodeCBORBodyToTextOrJSON decodes CBOR bytes to plain text (for string payloads) or JSON string.
func decodeCBORBodyToTextOrJSON(raw []byte) (string, error) {
if len(raw) == 0 {
return "", nil
}
var payload any
if errUnmarshal := cbor.Unmarshal(raw, &payload); errUnmarshal != nil {
return "", errUnmarshal
}
jsonCompatible := cborValueToJSONCompatible(payload)
switch typed := jsonCompatible.(type) {
case string:
return typed, nil
case []byte:
return string(typed), nil
default:
jsonBytes, errMarshal := json.Marshal(jsonCompatible)
if errMarshal != nil {
return "", errMarshal
}
return string(jsonBytes), nil
}
}
// cborValueToJSONCompatible recursively converts CBOR-decoded values into JSON-marshalable values.
func cborValueToJSONCompatible(value any) any {
switch typed := value.(type) {
case map[any]any:
out := make(map[string]any, len(typed))
for key, item := range typed {
out[fmt.Sprint(key)] = cborValueToJSONCompatible(item)
}
return out
case map[string]any:
out := make(map[string]any, len(typed))
for key, item := range typed {
out[key] = cborValueToJSONCompatible(item)
}
return out
case []any:
out := make([]any, len(typed))
for i, item := range typed {
out[i] = cborValueToJSONCompatible(item)
}
return out
default:
return typed
}
}
// QuotaDetail represents quota information for a specific resource type
type QuotaDetail struct {
Entitlement float64 `json:"entitlement"`

View File

@@ -29,6 +29,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
@@ -813,6 +814,87 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
}
// PatchAuthFileFields updates editable fields (prefix, proxy_url, priority) of an auth file.
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
return
}
var req struct {
Name string `json:"name"`
Prefix *string `json:"prefix"`
ProxyURL *string `json:"proxy_url"`
Priority *int `json:"priority"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
}
name := strings.TrimSpace(req.Name)
if name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
ctx := c.Request.Context()
// Find auth by name or ID
var targetAuth *coreauth.Auth
if auth, ok := h.authManager.GetByID(name); ok {
targetAuth = auth
} else {
auths := h.authManager.List()
for _, auth := range auths {
if auth.FileName == name {
targetAuth = auth
break
}
}
}
if targetAuth == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
return
}
changed := false
if req.Prefix != nil {
targetAuth.Prefix = *req.Prefix
changed = true
}
if req.ProxyURL != nil {
targetAuth.ProxyURL = *req.ProxyURL
changed = true
}
if req.Priority != nil {
if targetAuth.Metadata == nil {
targetAuth.Metadata = make(map[string]any)
}
if *req.Priority == 0 {
delete(targetAuth.Metadata, "priority")
} else {
targetAuth.Metadata["priority"] = *req.Priority
}
changed = true
}
if !changed {
c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"})
return
}
targetAuth.UpdatedAt = time.Now()
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok"})
}
func (h *Handler) disableAuth(ctx context.Context, id string) {
if h == nil || h.authManager == nil {
return
@@ -869,11 +951,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
if store == nil {
return "", fmt.Errorf("token store unavailable")
}
if h.postAuthHook != nil {
if err := h.postAuthHook(ctx, record); err != nil {
return "", fmt.Errorf("post-auth hook failed: %w", err)
}
}
return store.Save(ctx, record)
}
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Claude authentication...")
@@ -1018,6 +1106,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
@@ -1193,6 +1282,30 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
}
ts.ProjectID = strings.Join(projects, ",")
ts.Checked = true
} else if strings.EqualFold(requestedProjectID, "GOOGLE_ONE") {
ts.Auto = false
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
log.Errorf("Google One auto-discovery failed: %v", errSetup)
SetOAuthSessionError(state, "Google One auto-discovery failed")
return
}
if strings.TrimSpace(ts.ProjectID) == "" {
log.Error("Google One auto-discovery returned empty project ID")
SetOAuthSessionError(state, "Google One auto-discovery returned empty project ID")
return
}
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
if errCheck != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
return
}
ts.Checked = isChecked
if !isChecked {
log.Error("Cloud AI API is not enabled for the auto-discovered project")
SetOAuthSessionError(state, "Cloud AI API not enabled")
return
}
} else {
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
@@ -1252,6 +1365,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
func (h *Handler) RequestCodexToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Codex authentication...")
@@ -1397,6 +1511,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Antigravity authentication...")
@@ -1561,6 +1676,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
func (h *Handler) RequestQwenToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Qwen authentication...")
@@ -1616,6 +1732,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
func (h *Handler) RequestKimiToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Kimi authentication...")
@@ -1692,6 +1809,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
func (h *Handler) RequestIFlowToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing iFlow authentication...")
@@ -1811,8 +1929,6 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
// Initialize Copilot auth service
// We need to import "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" first if not present
// Assuming copilot package is imported as "copilot"
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
// Initiate device flow
@@ -1826,7 +1942,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
authURL := deviceCode.VerificationURI
userCode := deviceCode.UserCode
RegisterOAuthSession(state, "github")
RegisterOAuthSession(state, "github-copilot")
go func() {
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
@@ -1838,9 +1954,13 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
return
}
username, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
userInfo, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
if errUser != nil {
log.Warnf("Failed to fetch user info: %v", errUser)
}
username := userInfo.Login
if username == "" {
username = "github-user"
}
@@ -1849,18 +1969,26 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
TokenType: tokenData.TokenType,
Scope: tokenData.Scope,
Username: username,
Email: userInfo.Email,
Name: userInfo.Name,
Type: "github-copilot",
}
fileName := fmt.Sprintf("github-%s.json", username)
fileName := fmt.Sprintf("github-copilot-%s.json", username)
label := userInfo.Email
if label == "" {
label = username
}
record := &coreauth.Auth{
ID: fileName,
Provider: "github",
Provider: "github-copilot",
Label: label,
FileName: fileName,
Storage: tokenStorage,
Metadata: map[string]any{
"email": username,
"email": userInfo.Email,
"username": username,
"name": userInfo.Name,
},
}
@@ -1874,7 +2002,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
fmt.Println("You can now use GitHub Copilot services through this CLI")
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("github")
CompleteOAuthSessionsByProvider("github-copilot")
}()
c.JSON(200, gin.H{
@@ -2124,7 +2252,48 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
}
}
if projectID == "" {
return &projectSelectionRequiredError{}
// Auto-discovery: try onboardUser without specifying a project
// to let Google auto-provision one (matches Gemini CLI headless behavior
// and Antigravity's FetchProjectID pattern).
autoOnboardReq := map[string]any{
"tierId": tierID,
"metadata": metadata,
}
autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
defer autoCancel()
for attempt := 1; ; attempt++ {
var onboardResp map[string]any
if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
}
if done, okDone := onboardResp["done"].(bool); okDone && done {
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
switch v := resp["cloudaicompanionProject"].(type) {
case string:
projectID = strings.TrimSpace(v)
case map[string]any:
if id, okID := v["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
}
}
break
}
log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
select {
case <-autoCtx.Done():
return &projectSelectionRequiredError{}
case <-time.After(2 * time.Second):
}
}
if projectID == "" {
return &projectSelectionRequiredError{}
}
log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
}
onboardReqBody := map[string]any{
@@ -2374,6 +2543,14 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "wait"})
}
// PopulateAuthContext extracts request info and adds it to the context
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
info := &coreauth.RequestInfo{
Query: c.Request.URL.Query(),
Headers: c.Request.Header,
}
return coreauth.WithRequestInfo(ctx, info)
}
const kiroCallbackPort = 9876
func (h *Handler) RequestKiroToken(c *gin.Context) {
@@ -2668,3 +2845,88 @@ func generateKiroPKCE() (verifier, challenge string, err error) {
return verifier, challenge, nil
}
func (h *Handler) RequestKiloToken(c *gin.Context) {
ctx := context.Background()
fmt.Println("Initializing Kilo authentication...")
state := fmt.Sprintf("kil-%d", time.Now().UnixNano())
kilocodeAuth := kilo.NewKiloAuth()
resp, err := kilocodeAuth.InitiateDeviceFlow(ctx)
if err != nil {
log.Errorf("Failed to initiate device flow: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate device flow"})
return
}
RegisterOAuthSession(state, "kilo")
go func() {
fmt.Printf("Please visit %s and enter code: %s\n", resp.VerificationURL, resp.Code)
status, err := kilocodeAuth.PollForToken(ctx, resp.Code)
if err != nil {
SetOAuthSessionError(state, "Authentication failed")
fmt.Printf("Authentication failed: %v\n", err)
return
}
profile, err := kilocodeAuth.GetProfile(ctx, status.Token)
if err != nil {
log.Warnf("Failed to fetch profile: %v", err)
profile = &kilo.Profile{Email: status.UserEmail}
}
var orgID string
if len(profile.Orgs) > 0 {
orgID = profile.Orgs[0].ID
}
defaults, err := kilocodeAuth.GetDefaults(ctx, status.Token, orgID)
if err != nil {
defaults = &kilo.Defaults{}
}
ts := &kilo.KiloTokenStorage{
Token: status.Token,
OrganizationID: orgID,
Model: defaults.Model,
Email: status.UserEmail,
Type: "kilo",
}
fileName := kilo.CredentialFileName(status.UserEmail)
record := &coreauth.Auth{
ID: fileName,
Provider: "kilo",
FileName: fileName,
Storage: ts,
Metadata: map[string]any{
"email": status.UserEmail,
"organization_id": orgID,
"model": defaults.Model,
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("kilo")
}()
c.JSON(200, gin.H{
"status": "ok",
"url": resp.VerificationURL,
"state": state,
"user_code": resp.Code,
"verification_uri": resp.VerificationURL,
})
}

View File

@@ -28,8 +28,7 @@ func (h *Handler) GetConfig(c *gin.Context) {
c.JSON(200, gin.H{})
return
}
cfgCopy := *h.cfg
c.JSON(200, &cfgCopy)
c.JSON(200, new(*h.cfg))
}
type releaseInfo struct {

View File

@@ -796,10 +796,10 @@ func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
delete(h.cfg.OAuthModelAlias, channel)
if len(h.cfg.OAuthModelAlias) == 0 {
h.cfg.OAuthModelAlias = nil
}
// Set to nil instead of deleting the key so that the "explicitly disabled"
// marker survives config reload and prevents SanitizeOAuthModelAlias from
// re-injecting default aliases (fixes #222).
h.cfg.OAuthModelAlias[channel] = nil
h.persist(c)
}

View File

@@ -47,6 +47,7 @@ type Handler struct {
allowRemoteOverride bool
envSecret string
logDir string
postAuthHook coreauth.PostAuthHook
}
// NewHandler creates a new management handler instance.
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
h.logDir = dir
}
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
h.postAuthHook = hook
}
// Middleware enforces access control for management endpoints.
// All requests (local and remote) require a valid management key.
// Additionally, remote access requires allow-remote-management=true.

View File

@@ -15,10 +15,12 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
// It captures detailed information about the request and response, including headers and body,
// and uses the provided RequestLogger to record this data. When logging is disabled in the
// logger, it still captures data so that upstream errors can be persisted.
// and uses the provided RequestLogger to record this data. When full request logging is disabled,
// body capture is limited to small known-size payloads to avoid large per-request memory spikes.
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return func(c *gin.Context) {
if logger == nil {
@@ -26,7 +28,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return
}
if c.Request.Method == http.MethodGet {
if shouldSkipMethodForRequestLogging(c.Request) {
c.Next()
return
}
@@ -37,8 +39,10 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return
}
loggerEnabled := logger.IsEnabled()
// Capture request information
requestInfo, err := captureRequestInfo(c)
requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
if err != nil {
// Log error but continue processing
// In a real implementation, you might want to use a proper logger here
@@ -48,7 +52,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
// Create response writer wrapper
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
if !logger.IsEnabled() {
if !loggerEnabled {
wrapper.logOnErrorOnly = true
}
c.Writer = wrapper
@@ -64,10 +68,47 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
}
}
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
if req == nil {
return true
}
if req.Method != http.MethodGet {
return false
}
return !isResponsesWebsocketUpgrade(req)
}
func isResponsesWebsocketUpgrade(req *http.Request) bool {
if req == nil || req.URL == nil {
return false
}
if req.URL.Path != "/v1/responses" {
return false
}
return strings.EqualFold(strings.TrimSpace(req.Header.Get("Upgrade")), "websocket")
}
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
if loggerEnabled {
return true
}
if req == nil || req.Body == nil {
return false
}
contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
if strings.HasPrefix(contentType, "multipart/form-data") {
return false
}
if req.ContentLength <= 0 {
return false
}
return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes
}
// captureRequestInfo extracts relevant information from the incoming HTTP request.
// It captures the URL, method, headers, and body. The request body is read and then
// restored so that it can be processed by subsequent handlers.
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
// Capture URL with sensitive query parameters masked
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
url := c.Request.URL.Path
@@ -86,7 +127,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
// Capture request body
var body []byte
if c.Request.Body != nil {
if captureBody && c.Request.Body != nil {
// Read the body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {

View File

@@ -0,0 +1,138 @@
package middleware
import (
"io"
"net/http"
"net/url"
"strings"
"testing"
)
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
tests := []struct {
name string
req *http.Request
skip bool
}{
{
name: "nil request",
req: nil,
skip: true,
},
{
name: "post request should not skip",
req: &http.Request{
Method: http.MethodPost,
URL: &url.URL{Path: "/v1/responses"},
},
skip: false,
},
{
name: "plain get should skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/models"},
Header: http.Header{},
},
skip: true,
},
{
name: "responses websocket upgrade should not skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/responses"},
Header: http.Header{"Upgrade": []string{"websocket"}},
},
skip: false,
},
{
name: "responses get without upgrade should skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/responses"},
Header: http.Header{},
},
skip: true,
},
}
for i := range tests {
got := shouldSkipMethodForRequestLogging(tests[i].req)
if got != tests[i].skip {
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
}
}
}
func TestShouldCaptureRequestBody(t *testing.T) {
tests := []struct {
name string
loggerEnabled bool
req *http.Request
want bool
}{
{
name: "logger enabled always captures",
loggerEnabled: true,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("{}")),
ContentLength: -1,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: true,
},
{
name: "nil request",
loggerEnabled: false,
req: nil,
want: false,
},
{
name: "small known size json in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("{}")),
ContentLength: 2,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: true,
},
{
name: "large known size skipped in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("x")),
ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: false,
},
{
name: "unknown size skipped in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("x")),
ContentLength: -1,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: false,
},
{
name: "multipart skipped in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("x")),
ContentLength: 1,
Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
},
want: false,
},
}
for i := range tests {
got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req)
if got != tests[i].want {
t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want)
}
}
}

View File

@@ -14,6 +14,8 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
)
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct {
URL string // URL is the request URL.
@@ -223,8 +225,8 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
// Only fall back to request payload hints when Content-Type is not set yet.
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
bodyStr := string(w.requestInfo.Body)
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
return bytes.Contains(w.requestInfo.Body, []byte(`"stream": true`)) ||
bytes.Contains(w.requestInfo.Body, []byte(`"stream":true`))
}
return false
@@ -310,7 +312,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
return nil
}
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
}
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
@@ -361,16 +363,32 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
return time.Time{}
}
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
if c != nil {
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
switch value := bodyOverride.(type) {
case []byte:
if len(value) > 0 {
return bytes.Clone(value)
}
case string:
if strings.TrimSpace(value) != "" {
return []byte(value)
}
}
}
}
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
return w.requestInfo.Body
}
return nil
}
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
if w.requestInfo == nil {
return nil
}
var requestBody []byte
if len(w.requestInfo.Body) > 0 {
requestBody = w.requestInfo.Body
}
if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
}); ok {

View File

@@ -0,0 +1,43 @@
package middleware
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{
requestInfo: &RequestInfo{Body: []byte("original-body")},
}
body := wrapper.extractRequestBody(c)
if string(body) != "original-body" {
t.Fatalf("request body = %q, want %q", string(body), "original-body")
}
c.Set(requestBodyOverrideContextKey, []byte("override-body"))
body = wrapper.extractRequestBody(c)
if string(body) != "override-body" {
t.Fatalf("request body = %q, want %q", string(body), "override-body")
}
}
func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
c.Set(requestBodyOverrideContextKey, "override-as-string")
body := wrapper.extractRequestBody(c)
if string(body) != "override-as-string" {
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
}
}

View File

@@ -127,8 +127,7 @@ func (m *AmpModule) Register(ctx modules.Context) error {
m.modelMapper = NewModelMapper(settings.ModelMappings)
// Store initial config for partial reload comparison
settingsCopy := settings
m.lastConfig = &settingsCopy
m.lastConfig = new(settings)
// Initialize localhost restriction setting (hot-reloadable)
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)

View File

@@ -215,7 +215,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
// Don't log as error for context canceled - it's usually client closing connection
if errors.Is(err, context.Canceled) {
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
return
} else {
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
}

View File

@@ -493,6 +493,30 @@ func TestReverseProxy_ErrorHandler(t *testing.T) {
}
}
func TestReverseProxy_ErrorHandler_ContextCanceled(t *testing.T) {
// Test that context.Canceled errors return 499 without generic error response
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource(""))
if err != nil {
t.Fatal(err)
}
// Create a canceled context to trigger the cancellation path
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
req := httptest.NewRequest(http.MethodGet, "/test", nil).WithContext(ctx)
rr := httptest.NewRecorder()
// Directly invoke the ErrorHandler with context.Canceled
proxy.ErrorHandler(rr, req, context.Canceled)
// Body should be empty for canceled requests (no JSON error response)
body := rr.Body.Bytes()
if len(body) > 0 {
t.Fatalf("expected empty body for canceled context, got: %s", body)
}
}
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
// Upstream returns gzipped JSON without Content-Encoding header
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@@ -52,6 +52,7 @@ type serverOptionConfig struct {
keepAliveEnabled bool
keepAliveTimeout time.Duration
keepAliveOnTimeout func()
postAuthHook auth.PostAuthHook
}
// ServerOption customises HTTP server construction.
@@ -112,6 +113,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
}
}
// WithPostAuthHook registers a hook to be called after auth record creation.
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
return func(cfg *serverOptionConfig) {
cfg.postAuthHook = hook
}
}
// Server represents the main API server.
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
type Server struct {
@@ -263,6 +271,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
}
logDir := logging.ResolveLogDirectory(cfg)
s.mgmt.SetLogDirectory(logDir)
if optionState.postAuthHook != nil {
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
}
s.localPassword = optionState.localPassword
// Setup routes
@@ -285,8 +296,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
optionState.routerConfigurator(engine, s.handlers, cfg)
}
// Register management routes when configuration or environment secrets are available.
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret
// Register management routes when configuration or environment secrets are available,
// or when a local management password is provided (e.g. TUI mode).
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
s.managementRoutesEnabled.Store(hasManagementSecret)
if hasManagementSecret {
s.registerManagementRoutes()
@@ -329,6 +341,7 @@ func (s *Server) setupRoutes() {
v1.POST("/completions", openaiHandlers.Completions)
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
v1.POST("/responses", openaiResponsesHandlers.Responses)
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
}
@@ -642,6 +655,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
mgmt.PATCH("/auth-files/fields", s.mgmt.PatchAuthFileFields)
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
@@ -649,6 +663,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken)
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)

View File

@@ -20,7 +20,7 @@ import (
// OAuth configuration constants for Claude/Anthropic
const (
AuthURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
TokenURL = "https://api.anthropic.com/v1/oauth/token"
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
RedirectURI = "http://localhost:54545/callback"
)

View File

@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Claude token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
// Encode and write the token data as JSON
if err = json.NewEncoder(f).Encode(ts); err != nil {
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
// authorization code and PKCE verifier.
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
}
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
// a caller-provided redirect URI. This supports alternate auth flows such as device
// login while preserving the existing token parsing and storage behavior.
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
if pkceCodes == nil {
return nil, fmt.Errorf("PKCE codes are required for token exchange")
}
if strings.TrimSpace(redirectURI) == "" {
return nil, fmt.Errorf("redirect URI is required for token exchange")
}
// Prepare token exchange request
data := url.Values{
"grant_type": {"authorization_code"},
"client_id": {ClientID},
"code": {code},
"redirect_uri": {RedirectURI},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"code_verifier": {pkceCodes.CodeVerifier},
}
@@ -266,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
if err == nil {
return tokenData, nil
}
if isNonRetryableRefreshErr(err) {
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
return nil, err
}
lastErr = err
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
@@ -274,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
}
func isNonRetryableRefreshErr(err error) bool {
if err == nil {
return false
}
raw := strings.ToLower(err.Error())
return strings.Contains(raw, "refresh_token_reused")
}
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
// This is typically called after a successful token refresh to persist the new credentials.
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {

View File

@@ -0,0 +1,44 @@
package codex
import (
"context"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
var calls int32
auth := &CodexAuth{
httpClient: &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
atomic.AddInt32(&calls, 1)
return &http.Response{
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
Header: make(http.Header),
Request: req,
}, nil
}),
},
}
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
if err == nil {
t.Fatalf("expected error for non-retryable refresh failure")
}
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected 1 refresh attempt, got %d", got)
}
}

View File

@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
Type string `json:"type"`
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Codex token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -82,15 +82,21 @@ func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *Devi
}
// Fetch the GitHub username
username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
userInfo, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
if err != nil {
log.Warnf("copilot: failed to fetch user info: %v", err)
username = "unknown"
}
username := userInfo.Login
if username == "" {
username = "github-user"
}
return &CopilotAuthBundle{
TokenData: tokenData,
Username: username,
Email: userInfo.Email,
Name: userInfo.Name,
}, nil
}
@@ -150,12 +156,12 @@ func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bo
return false, "", nil
}
username, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
userInfo, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
if err != nil {
return false, "", err
}
return true, username, nil
return true, userInfo.Login, nil
}
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
@@ -165,6 +171,8 @@ func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotToke
TokenType: bundle.TokenData.TokenType,
Scope: bundle.TokenData.Scope,
Username: bundle.Username,
Email: bundle.Email,
Name: bundle.Name,
Type: "github-copilot",
}
}

View File

@@ -53,7 +53,7 @@ func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
data := url.Values{}
data.Set("client_id", copilotClientID)
data.Set("scope", "user:email")
data.Set("scope", "read:user user:email")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
if err != nil {
@@ -211,15 +211,25 @@ func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode st
}, nil
}
// FetchUserInfo retrieves the GitHub username for the authenticated user.
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
// GitHubUserInfo holds GitHub user profile information.
type GitHubUserInfo struct {
// Login is the GitHub username.
Login string
// Email is the primary email address (may be empty if not public).
Email string
// Name is the display name.
Name string
}
// FetchUserInfo retrieves the GitHub user profile for the authenticated user.
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (GitHubUserInfo, error) {
if accessToken == "" {
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
if err != nil {
return "", NewAuthenticationError(ErrUserInfoFailed, err)
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
@@ -227,7 +237,7 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
resp, err := c.httpClient.Do(req)
if err != nil {
return "", NewAuthenticationError(ErrUserInfoFailed, err)
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
@@ -237,19 +247,25 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
if !isHTTPSuccess(resp.StatusCode) {
bodyBytes, _ := io.ReadAll(resp.Body)
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
}
var userInfo struct {
var raw struct {
Login string `json:"login"`
Email string `json:"email"`
Name string `json:"name"`
}
if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
return "", NewAuthenticationError(ErrUserInfoFailed, err)
if err = json.NewDecoder(resp.Body).Decode(&raw); err != nil {
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
if userInfo.Login == "" {
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
if raw.Login == "" {
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
}
return userInfo.Login, nil
return GitHubUserInfo{
Login: raw.Login,
Email: raw.Email,
Name: raw.Name,
}, nil
}

View File

@@ -0,0 +1,213 @@
package copilot
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// roundTripFunc lets us inject a custom transport for testing.
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
// newTestClient returns an *http.Client whose requests are redirected to the given test server,
// regardless of the original URL host.
func newTestClient(srv *httptest.Server) *http.Client {
return &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
req2 := req.Clone(req.Context())
req2.URL.Scheme = "http"
req2.URL.Host = strings.TrimPrefix(srv.URL, "http://")
return srv.Client().Transport.RoundTrip(req2)
}),
}
}
// TestFetchUserInfo_FullProfile verifies that FetchUserInfo returns login, email, and name.
func TestFetchUserInfo_FullProfile(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"login": "octocat",
"email": "octocat@github.com",
"name": "The Octocat",
})
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
info, err := client.FetchUserInfo(context.Background(), "test-token")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.Login != "octocat" {
t.Errorf("Login: got %q, want %q", info.Login, "octocat")
}
if info.Email != "octocat@github.com" {
t.Errorf("Email: got %q, want %q", info.Email, "octocat@github.com")
}
if info.Name != "The Octocat" {
t.Errorf("Name: got %q, want %q", info.Name, "The Octocat")
}
}
// TestFetchUserInfo_EmptyEmail verifies graceful handling when email is absent (private account).
func TestFetchUserInfo_EmptyEmail(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
// GitHub returns null for private emails.
_, _ = w.Write([]byte(`{"login":"privateuser","email":null,"name":"Private User"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
info, err := client.FetchUserInfo(context.Background(), "test-token")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.Login != "privateuser" {
t.Errorf("Login: got %q, want %q", info.Login, "privateuser")
}
if info.Email != "" {
t.Errorf("Email: got %q, want empty string", info.Email)
}
if info.Name != "Private User" {
t.Errorf("Name: got %q, want %q", info.Name, "Private User")
}
}
// TestFetchUserInfo_EmptyToken verifies error is returned for empty access token.
func TestFetchUserInfo_EmptyToken(t *testing.T) {
client := &DeviceFlowClient{httpClient: http.DefaultClient}
_, err := client.FetchUserInfo(context.Background(), "")
if err == nil {
t.Fatal("expected error for empty token, got nil")
}
}
// TestFetchUserInfo_EmptyLogin verifies error is returned when API returns no login.
func TestFetchUserInfo_EmptyLogin(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"email":"someone@example.com","name":"No Login"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
_, err := client.FetchUserInfo(context.Background(), "test-token")
if err == nil {
t.Fatal("expected error for empty login, got nil")
}
}
// TestFetchUserInfo_HTTPError verifies error is returned on non-2xx response.
func TestFetchUserInfo_HTTPError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
_, err := client.FetchUserInfo(context.Background(), "bad-token")
if err == nil {
t.Fatal("expected error for 401 response, got nil")
}
}
// TestCopilotTokenStorage_EmailNameFields verifies Email and Name serialise correctly.
func TestCopilotTokenStorage_EmailNameFields(t *testing.T) {
ts := &CopilotTokenStorage{
AccessToken: "ghu_abc",
TokenType: "bearer",
Scope: "read:user user:email",
Username: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
Type: "github-copilot",
}
data, err := json.Marshal(ts)
if err != nil {
t.Fatalf("marshal error: %v", err)
}
var out map[string]any
if err = json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
for _, key := range []string{"access_token", "username", "email", "name", "type"} {
if _, ok := out[key]; !ok {
t.Errorf("expected key %q in JSON output, not found", key)
}
}
if out["email"] != "octocat@github.com" {
t.Errorf("email: got %v, want %q", out["email"], "octocat@github.com")
}
if out["name"] != "The Octocat" {
t.Errorf("name: got %v, want %q", out["name"], "The Octocat")
}
}
// TestCopilotTokenStorage_OmitEmptyEmailName verifies email/name are omitted when empty (omitempty).
func TestCopilotTokenStorage_OmitEmptyEmailName(t *testing.T) {
ts := &CopilotTokenStorage{
AccessToken: "ghu_abc",
Username: "octocat",
Type: "github-copilot",
}
data, err := json.Marshal(ts)
if err != nil {
t.Fatalf("marshal error: %v", err)
}
var out map[string]any
if err = json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if _, ok := out["email"]; ok {
t.Error("email key should be omitted when empty (omitempty), but was present")
}
if _, ok := out["name"]; ok {
t.Error("name key should be omitted when empty (omitempty), but was present")
}
}
// TestCopilotAuthBundle_EmailNameFields verifies bundle carries email and name through the pipeline.
func TestCopilotAuthBundle_EmailNameFields(t *testing.T) {
bundle := &CopilotAuthBundle{
TokenData: &CopilotTokenData{AccessToken: "ghu_abc"},
Username: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
}
if bundle.Email != "octocat@github.com" {
t.Errorf("bundle.Email: got %q, want %q", bundle.Email, "octocat@github.com")
}
if bundle.Name != "The Octocat" {
t.Errorf("bundle.Name: got %q, want %q", bundle.Name, "The Octocat")
}
}
// TestGitHubUserInfo_Struct verifies the exported GitHubUserInfo struct fields are accessible.
func TestGitHubUserInfo_Struct(t *testing.T) {
info := GitHubUserInfo{
Login: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
}
if info.Login == "" || info.Email == "" || info.Name == "" {
t.Error("GitHubUserInfo fields should not be empty")
}
}

View File

@@ -26,6 +26,10 @@ type CopilotTokenStorage struct {
ExpiresAt string `json:"expires_at,omitempty"`
// Username is the GitHub username associated with this token.
Username string `json:"username"`
// Email is the GitHub email address associated with this token.
Email string `json:"email,omitempty"`
// Name is the GitHub display name associated with this token.
Name string `json:"name,omitempty"`
// Type indicates the authentication provider type, always "github-copilot" for this storage.
Type string `json:"type"`
}
@@ -46,6 +50,10 @@ type CopilotAuthBundle struct {
TokenData *CopilotTokenData
// Username is the GitHub username.
Username string
// Email is the GitHub email address.
Email string
// Name is the GitHub display name.
Name string
}
// DeviceCodeResponse represents GitHub's device code response.

View File

@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
// Type indicates the authentication provider type, always "gemini" for this storage.
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
ts.Type = "gemini"
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
return fmt.Errorf("failed to create directory: %v", err)
}
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
}
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
enc := json.NewEncoder(f)
enc.SetIndent("", " ")
if err := enc.Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
Scope string `json:"scope"`
Cookie string `json:"cookie"`
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serialises the token storage to disk.
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
}
defer func() { _ = f.Close() }()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("iflow token: encode token failed: %w", err)
}
return nil

View File

@@ -0,0 +1,168 @@
// Package kilo provides authentication and token management functionality
// for Kilo AI services.
package kilo
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
)
const (
// BaseURL is the base URL for the Kilo AI API.
BaseURL = "https://api.kilo.ai/api"
)
// DeviceAuthResponse represents the response from initiating device flow.
type DeviceAuthResponse struct {
Code string `json:"code"`
VerificationURL string `json:"verificationUrl"`
ExpiresIn int `json:"expiresIn"`
}
// DeviceStatusResponse represents the response when polling for device flow status.
type DeviceStatusResponse struct {
Status string `json:"status"`
Token string `json:"token"`
UserEmail string `json:"userEmail"`
}
// Profile represents the user profile from Kilo AI.
type Profile struct {
Email string `json:"email"`
Orgs []Organization `json:"organizations"`
}
// Organization represents a Kilo AI organization.
type Organization struct {
ID string `json:"id"`
Name string `json:"name"`
}
// Defaults represents default settings for an organization or user.
type Defaults struct {
Model string `json:"model"`
}
// KiloAuth provides methods for handling the Kilo AI authentication flow.
type KiloAuth struct {
client *http.Client
}
// NewKiloAuth creates a new instance of KiloAuth.
func NewKiloAuth() *KiloAuth {
return &KiloAuth{
client: &http.Client{Timeout: 30 * time.Second},
}
}
// InitiateDeviceFlow starts the device authentication flow.
func (k *KiloAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceAuthResponse, error) {
resp, err := k.client.Post(BaseURL+"/device-auth/codes", "application/json", nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to initiate device flow: status %d", resp.StatusCode)
}
var data DeviceAuthResponse
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return nil, err
}
return &data, nil
}
// PollForToken polls for the device flow completion.
func (k *KiloAuth) PollForToken(ctx context.Context, code string) (*DeviceStatusResponse, error) {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
resp, err := k.client.Get(BaseURL + "/device-auth/codes/" + code)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var data DeviceStatusResponse
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return nil, err
}
switch data.Status {
case "approved":
return &data, nil
case "denied", "expired":
return nil, fmt.Errorf("device flow %s", data.Status)
case "pending":
continue
default:
return nil, fmt.Errorf("unknown status: %s", data.Status)
}
}
}
}
// GetProfile fetches the user's profile.
func (k *KiloAuth) GetProfile(ctx context.Context, token string) (*Profile, error) {
req, err := http.NewRequestWithContext(ctx, "GET", BaseURL+"/profile", nil)
if err != nil {
return nil, fmt.Errorf("failed to create get profile request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := k.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get profile: status %d", resp.StatusCode)
}
var profile Profile
if err := json.NewDecoder(resp.Body).Decode(&profile); err != nil {
return nil, err
}
return &profile, nil
}
// GetDefaults fetches default settings for an organization.
func (k *KiloAuth) GetDefaults(ctx context.Context, token, orgID string) (*Defaults, error) {
url := BaseURL + "/defaults"
if orgID != "" {
url = BaseURL + "/organizations/" + orgID + "/defaults"
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create get defaults request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := k.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get defaults: status %d", resp.StatusCode)
}
var defaults Defaults
if err := json.NewDecoder(resp.Body).Decode(&defaults); err != nil {
return nil, err
}
return &defaults, nil
}

View File

@@ -0,0 +1,60 @@
// Package kilo provides authentication and token management functionality
// for Kilo AI services.
package kilo
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
log "github.com/sirupsen/logrus"
)
// KiloTokenStorage stores token information for Kilo AI authentication.
type KiloTokenStorage struct {
// Token is the Kilo access token.
Token string `json:"kilocodeToken"`
// OrganizationID is the Kilo organization ID.
OrganizationID string `json:"kilocodeOrganizationId"`
// Model is the default model to use.
Model string `json:"kilocodeModel"`
// Email is the email address of the authenticated user.
Email string `json:"email"`
// Type indicates the authentication provider type, always "kilo" for this storage.
Type string `json:"type"`
}
// SaveTokenToFile serializes the Kilo token storage to a JSON file.
func (ts *KiloTokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
ts.Type = "kilo"
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
return fmt.Errorf("failed to create directory: %v", err)
}
f, err := os.Create(authFilePath)
if err != nil {
return fmt.Errorf("failed to create token file: %w", err)
}
defer func() {
if errClose := f.Close(); errClose != nil {
log.Errorf("failed to close file: %v", errClose)
}
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil
}
// CredentialFileName returns the filename used to persist Kilo credentials.
func CredentialFileName(email string) string {
return fmt.Sprintf("kilo-%s.json", email)
}

View File

@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
Expired string `json:"expired,omitempty"`
// Type indicates the authentication provider type, always "kimi" for this storage.
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// KimiTokenData holds the raw OAuth token response from Kimi.
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
encoder := json.NewEncoder(f)
encoder.SetIndent("", " ")
if err = encoder.Encode(ts); err != nil {
if err = encoder.Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
Type string `json:"type"`
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -40,8 +40,7 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
if err != nil {
var authErr *claude.AuthenticationError
if errors.As(err, &authErr) {
if authErr, ok := errors.AsType[*claude.AuthenticationError](err); ok {
log.Error(claude.GetUserFriendlyMessage(authErr))
if authErr.Type == claude.ErrPortInUse.Type {
os.Exit(claude.ErrPortInUse.Code)

View File

@@ -22,6 +22,7 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewKimiAuthenticator(),
sdkAuth.NewKiroAuthenticator(),
sdkAuth.NewGitHubCopilotAuthenticator(),
sdkAuth.NewKiloAuthenticator(),
)
return manager
}

View File

@@ -32,8 +32,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)
if err != nil {
var emailErr *sdkAuth.EmailRequiredError
if errors.As(err, &emailErr) {
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
log.Error(emailErr.Error())
return
}

View File

@@ -0,0 +1,54 @@
package cmd
import (
"context"
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
)
// DoKiloLogin handles the Kilo device flow using the shared authentication manager.
// It initiates the device-based authentication process for Kilo AI services and saves
// the authentication tokens to the configured auth directory.
//
// Parameters:
// - cfg: The application configuration
// - options: Login options including browser behavior and prompts
func DoKiloLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
manager := newAuthManager()
promptFn := options.Prompt
if promptFn == nil {
promptFn = func(prompt string) (string, error) {
fmt.Print(prompt)
var value string
fmt.Scanln(&value)
return strings.TrimSpace(value), nil
}
}
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "kilo", cfg, authOpts)
if err != nil {
fmt.Printf("Kilo authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("Kilo authentication successful!")
}

View File

@@ -100,49 +100,74 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
log.Info("Authentication successful.")
projects, errProjects := fetchGCPProjects(ctx, httpClient)
if errProjects != nil {
log.Errorf("Failed to get project list: %v", errProjects)
return
var activatedProjects []string
useGoogleOne := false
if trimmedProjectID == "" && promptFn != nil {
fmt.Println("\nSelect login mode:")
fmt.Println(" 1. Code Assist (GCP project, manual selection)")
fmt.Println(" 2. Google One (personal account, auto-discover project)")
choice, errPrompt := promptFn("Enter choice [1/2] (default: 1): ")
if errPrompt == nil && strings.TrimSpace(choice) == "2" {
useGoogleOne = true
}
}
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
if errSelection != nil {
log.Errorf("Invalid project selection: %v", errSelection)
return
}
if len(projectSelections) == 0 {
log.Error("No project selected; aborting login.")
return
}
activatedProjects := make([]string, 0, len(projectSelections))
seenProjects := make(map[string]bool)
for _, candidateID := range projectSelections {
log.Infof("Activating project %s", candidateID)
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
var projectErr *projectSelectionRequiredError
if errors.As(errSetup, &projectErr) {
log.Error("Failed to start user onboarding: A project ID is required.")
showProjectSelectionHelp(storage.Email, projects)
return
}
log.Errorf("Failed to complete user setup: %v", errSetup)
if useGoogleOne {
log.Info("Google One mode: auto-discovering project...")
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, ""); errSetup != nil {
log.Errorf("Google One auto-discovery failed: %v", errSetup)
return
}
finalID := strings.TrimSpace(storage.ProjectID)
if finalID == "" {
finalID = candidateID
autoProject := strings.TrimSpace(storage.ProjectID)
if autoProject == "" {
log.Error("Google One auto-discovery returned empty project ID")
return
}
log.Infof("Auto-discovered project: %s", autoProject)
activatedProjects = []string{autoProject}
} else {
projects, errProjects := fetchGCPProjects(ctx, httpClient)
if errProjects != nil {
log.Errorf("Failed to get project list: %v", errProjects)
return
}
// Skip duplicates
if seenProjects[finalID] {
log.Infof("Project %s already activated, skipping", finalID)
continue
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
if errSelection != nil {
log.Errorf("Invalid project selection: %v", errSelection)
return
}
if len(projectSelections) == 0 {
log.Error("No project selected; aborting login.")
return
}
seenProjects := make(map[string]bool)
for _, candidateID := range projectSelections {
log.Infof("Activating project %s", candidateID)
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
if _, ok := errors.AsType[*projectSelectionRequiredError](errSetup); ok {
log.Error("Failed to start user onboarding: A project ID is required.")
showProjectSelectionHelp(storage.Email, projects)
return
}
log.Errorf("Failed to complete user setup: %v", errSetup)
return
}
finalID := strings.TrimSpace(storage.ProjectID)
if finalID == "" {
finalID = candidateID
}
if seenProjects[finalID] {
log.Infof("Project %s already activated, skipping", finalID)
continue
}
seenProjects[finalID] = true
activatedProjects = append(activatedProjects, finalID)
}
seenProjects[finalID] = true
activatedProjects = append(activatedProjects, finalID)
}
storage.Auto = false
@@ -235,7 +260,48 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
}
}
if projectID == "" {
return &projectSelectionRequiredError{}
// Auto-discovery: try onboardUser without specifying a project
// to let Google auto-provision one (matches Gemini CLI headless behavior
// and Antigravity's FetchProjectID pattern).
autoOnboardReq := map[string]any{
"tierId": tierID,
"metadata": metadata,
}
autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
defer autoCancel()
for attempt := 1; ; attempt++ {
var onboardResp map[string]any
if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
}
if done, okDone := onboardResp["done"].(bool); okDone && done {
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
switch v := resp["cloudaicompanionProject"].(type) {
case string:
projectID = strings.TrimSpace(v)
case map[string]any:
if id, okID := v["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
}
}
break
}
log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
select {
case <-autoCtx.Done():
return &projectSelectionRequiredError{}
case <-time.After(2 * time.Second):
}
}
if projectID == "" {
return &projectSelectionRequiredError{}
}
log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
}
onboardReqBody := map[string]any{
@@ -617,7 +683,7 @@ func updateAuthRecord(record *cliproxyauth.Auth, storage *gemini.GeminiTokenStor
return
}
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, false)
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, true)
if record.Metadata == nil {
record.Metadata = make(map[string]any)

View File

@@ -0,0 +1,60 @@
package cmd
import (
"context"
"errors"
"fmt"
"os"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
const (
codexLoginModeMetadataKey = "codex_login_mode"
codexLoginModeDevice = "device"
)
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
// existing codex-login OAuth callback flow intact.
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{
codexLoginModeMetadataKey: codexLoginModeDevice,
},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
if err != nil {
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
log.Error(codex.GetUserFriendlyMessage(authErr))
if authErr.Type == codex.ErrPortInUse.Type {
os.Exit(codex.ErrPortInUse.Code)
}
return
}
fmt.Printf("Codex device authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("Codex device authentication successful!")
}

View File

@@ -54,8 +54,7 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
if err != nil {
var authErr *codex.AuthenticationError
if errors.As(err, &authErr) {
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
log.Error(codex.GetUserFriendlyMessage(authErr))
if authErr.Type == codex.ErrPortInUse.Type {
os.Exit(codex.ErrPortInUse.Code)

View File

@@ -44,8 +44,7 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
if err != nil {
var emailErr *sdkAuth.EmailRequiredError
if errors.As(err, &emailErr) {
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
log.Error(emailErr.Error())
return
}

View File

@@ -55,6 +55,34 @@ func StartService(cfg *config.Config, configPath string, localPassword string) {
}
}
// StartServiceBackground starts the proxy service in a background goroutine
// and returns a cancel function for shutdown and a done channel.
func StartServiceBackground(cfg *config.Config, configPath string, localPassword string) (cancel func(), done <-chan struct{}) {
builder := cliproxy.NewBuilder().
WithConfig(cfg).
WithConfigPath(configPath).
WithLocalManagementPassword(localPassword)
ctx, cancelFn := context.WithCancel(context.Background())
doneCh := make(chan struct{})
service, err := builder.Build()
if err != nil {
log.Errorf("failed to build proxy service: %v", err)
close(doneCh)
return cancelFn, doneCh
}
go func() {
defer close(doneCh)
if err := service.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
log.Errorf("proxy service exited with error: %v", err)
}
}()
return cancelFn, doneCh
}
// WaitForCloudDeploy waits indefinitely for shutdown signals in cloud deploy mode
// when no configuration file is available.
func WaitForCloudDeploy() {

View File

@@ -97,6 +97,10 @@ type Config struct {
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
// ClaudeHeaderDefaults configures default header values for Claude API requests.
// These are used as fallbacks when the client does not send its own headers.
ClaudeHeaderDefaults ClaudeHeaderDefaults `yaml:"claude-header-defaults" json:"claude-header-defaults"`
// OpenAICompatibility defines OpenAI API compatibility configurations for external providers.
OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"`
@@ -130,6 +134,15 @@ type Config struct {
legacyMigrationPending bool `yaml:"-" json:"-"`
}
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
// when the client does not send them. Update these when Claude Code releases a new version.
type ClaudeHeaderDefaults struct {
UserAgent string `yaml:"user-agent" json:"user-agent"`
PackageVersion string `yaml:"package-version" json:"package-version"`
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
Timeout string `yaml:"timeout" json:"timeout"`
}
// TLSConfig holds HTTPS server settings.
type TLSConfig struct {
// Enable toggles HTTPS server mode.
@@ -301,6 +314,10 @@ type CloakConfig struct {
// SensitiveWords is a list of words to obfuscate with zero-width characters.
// This can help bypass certain content filters.
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
// CacheUserID controls whether Claude user_id values are cached per API key.
// When false, a fresh random user_id is generated for every request.
CacheUserID *bool `yaml:"cache-user-id,omitempty" json:"cache-user-id,omitempty"`
}
// ClaudeKey represents the configuration for a Claude API key,
@@ -368,6 +385,9 @@ type CodexKey struct {
// If empty, the default Codex API URL will be used.
BaseURL string `yaml:"base-url" json:"base-url"`
// Websockets enables the Responses API websocket transport for this credential.
Websockets bool `yaml:"websockets,omitempty" json:"websockets,omitempty"`
// ProxyURL overrides the global proxy setting for this API key if provided.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
@@ -743,22 +763,24 @@ func (cfg *Config) SanitizeOAuthModelAlias() {
return
}
// Inject default Kiro aliases if no user-configured kiro aliases exist
// Inject channel defaults when the channel is absent in user config.
// Presence is checked case-insensitively and includes explicit nil/empty markers.
if cfg.OAuthModelAlias == nil {
cfg.OAuthModelAlias = make(map[string][]OAuthModelAlias)
}
if _, hasKiro := cfg.OAuthModelAlias["kiro"]; !hasKiro {
// Check case-insensitive too
found := false
hasChannel := func(channel string) bool {
for k := range cfg.OAuthModelAlias {
if strings.EqualFold(strings.TrimSpace(k), "kiro") {
found = true
break
if strings.EqualFold(strings.TrimSpace(k), channel) {
return true
}
}
if !found {
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
}
return false
}
if !hasChannel("kiro") {
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
}
if !hasChannel("github-copilot") {
cfg.OAuthModelAlias["github-copilot"] = defaultGitHubCopilotAliases()
}
if len(cfg.OAuthModelAlias) == 0 {
@@ -767,7 +789,13 @@ func (cfg *Config) SanitizeOAuthModelAlias() {
out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias))
for rawChannel, aliases := range cfg.OAuthModelAlias {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(aliases) == 0 {
if channel == "" {
continue
}
// Preserve channels that were explicitly set to empty/nil they act
// as "disabled" markers so default injection won't re-add them (#222).
if len(aliases) == 0 {
out[channel] = nil
continue
}
seenAlias := make(map[string]struct{}, len(aliases))

View File

@@ -42,6 +42,21 @@ func defaultKiroAliases() []OAuthModelAlias {
}
}
// defaultGitHubCopilotAliases returns default oauth-model-alias entries that
// expose Claude hyphen-style IDs for GitHub Copilot Claude models.
// This keeps compatibility with clients (e.g. Claude Code) that use
// Anthropic-style model IDs like "claude-opus-4-6".
func defaultGitHubCopilotAliases() []OAuthModelAlias {
return []OAuthModelAlias{
{Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true},
{Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true},
{Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true},
{Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true},
{Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true},
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
}
}
// defaultAntigravityAliases returns the default oauth-model-alias configuration
// for the antigravity channel when neither field exists.
func defaultAntigravityAliases() []OAuthModelAlias {

View File

@@ -107,6 +107,44 @@ func TestSanitizeOAuthModelAlias_InjectsDefaultKiroAliases(t *testing.T) {
}
}
func TestSanitizeOAuthModelAlias_InjectsDefaultGitHubCopilotAliases(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"codex": {
{Name: "gpt-5", Alias: "g5"},
},
},
}
cfg.SanitizeOAuthModelAlias()
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
if len(copilotAliases) == 0 {
t.Fatal("expected default github-copilot aliases to be injected")
}
aliasSet := make(map[string]bool, len(copilotAliases))
for _, a := range copilotAliases {
aliasSet[a.Alias] = true
if !a.Fork {
t.Fatalf("expected all default github-copilot aliases to have fork=true, got fork=false for %q", a.Alias)
}
}
expectedAliases := []string{
"claude-haiku-4-5",
"claude-opus-4-1",
"claude-opus-4-5",
"claude-opus-4-6",
"claude-sonnet-4-5",
"claude-sonnet-4-6",
}
for _, expected := range expectedAliases {
if !aliasSet[expected] {
t.Fatalf("expected default github-copilot alias %q to be present", expected)
}
}
}
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
// When user has configured kiro aliases, defaults should NOT be injected
cfg := &Config{
@@ -128,6 +166,88 @@ func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
}
}
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserGitHubCopilotAliases(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"github-copilot": {
{Name: "claude-opus-4.6", Alias: "my-opus", Fork: true},
},
},
}
cfg.SanitizeOAuthModelAlias()
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
if len(copilotAliases) != 1 {
t.Fatalf("expected 1 user-configured github-copilot alias, got %d", len(copilotAliases))
}
if copilotAliases[0].Alias != "my-opus" {
t.Fatalf("expected user alias to be preserved, got %q", copilotAliases[0].Alias)
}
}
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) {
// When user explicitly deletes kiro aliases (key exists with nil value),
// defaults should NOT be re-injected on subsequent sanitize calls (#222).
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"kiro": nil, // explicitly deleted
"codex": {{Name: "gpt-5", Alias: "g5"}},
},
}
cfg.SanitizeOAuthModelAlias()
kiroAliases := cfg.OAuthModelAlias["kiro"]
if len(kiroAliases) != 0 {
t.Fatalf("expected kiro aliases to remain empty after explicit deletion, got %d aliases", len(kiroAliases))
}
// The key itself must still be present to prevent re-injection on next reload
if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
t.Fatal("expected kiro key to be preserved as nil marker after sanitization")
}
// Other channels should be unaffected
if len(cfg.OAuthModelAlias["codex"]) != 1 {
t.Fatal("expected codex aliases to be preserved")
}
}
func TestSanitizeOAuthModelAlias_GitHubCopilotDoesNotReinjectAfterExplicitDeletion(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"github-copilot": nil, // explicitly deleted
},
}
cfg.SanitizeOAuthModelAlias()
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
if len(copilotAliases) != 0 {
t.Fatalf("expected github-copilot aliases to remain empty after explicit deletion, got %d aliases", len(copilotAliases))
}
if _, exists := cfg.OAuthModelAlias["github-copilot"]; !exists {
t.Fatal("expected github-copilot key to be preserved as nil marker after sanitization")
}
}
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) {
// Same as above but with empty slice instead of nil (PUT with empty body).
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"kiro": {}, // explicitly set to empty
},
}
cfg.SanitizeOAuthModelAlias()
if len(cfg.OAuthModelAlias["kiro"]) != 0 {
t.Fatalf("expected kiro aliases to remain empty, got %d aliases", len(cfg.OAuthModelAlias["kiro"]))
}
if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
t.Fatal("expected kiro key to be preserved")
}
}
func TestSanitizeOAuthModelAlias_InjectsDefaultKiroWhenEmpty(t *testing.T) {
// When OAuthModelAlias is nil, kiro defaults should still be injected
cfg := &Config{}

View File

@@ -20,6 +20,10 @@ type SDKConfig struct {
// APIKeys is a list of keys for authenticating clients to this proxy server.
APIKeys []string `yaml:"api-keys" json:"api-keys"`
// PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients.
// Default is false (disabled).
PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"`
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`

View File

@@ -27,4 +27,7 @@ const (
// Kiro represents the AWS CodeWhisperer (Kiro) provider identifier.
Kiro = "kiro"
// Kilo represents the Kilo AI provider identifier.
Kilo = "kilo"
)

View File

@@ -1,6 +1,7 @@
package misc
import (
"encoding/json"
"fmt"
"path/filepath"
"strings"
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
func LogCredentialSeparator() {
log.Debug(credentialSeparator)
}
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
var data map[string]any
// Fast path: if source is already a map, just copy it to avoid mutation of original
if srcMap, ok := source.(map[string]any); ok {
data = make(map[string]any, len(srcMap)+len(metadata))
for k, v := range srcMap {
data[k] = v
}
} else {
// Slow path: marshal to JSON and back to map to respect JSON tags
temp, err := json.Marshal(source)
if err != nil {
return nil, fmt.Errorf("failed to marshal source: %w", err)
}
if err := json.Unmarshal(temp, &data); err != nil {
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
}
}
// Merge extra metadata
if metadata != nil {
if data == nil {
data = make(map[string]any)
}
for k, v := range metadata {
data[k] = v
}
}
return data, nil
}

View File

@@ -0,0 +1,21 @@
// Package registry provides model definitions for various AI service providers.
package registry
// GetKiloModels returns the Kilo model definitions
func GetKiloModels() []*ModelInfo {
return []*ModelInfo{
// --- Base Models ---
{
ID: "kilo/auto",
Object: "model",
Created: 1732752000,
OwnedBy: "kilo",
Type: "kilo",
DisplayName: "Kilo Auto",
Description: "Automatic model selection by Kilo",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
}
}

View File

@@ -19,7 +19,9 @@ import (
// - codex
// - qwen
// - iflow
// - kimi
// - kiro
// - kilo
// - github-copilot
// - kiro
// - amazonq
@@ -43,10 +45,14 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
return GetQwenModels()
case "iflow":
return GetIFlowModels()
case "kimi":
return GetKimiModels()
case "github-copilot":
return GetGitHubCopilotModels()
case "kiro":
return GetKiroModels()
case "kilo":
return GetKiloModels()
case "amazonq":
return GetAmazonQModels()
case "antigravity":
@@ -93,8 +99,10 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
GetOpenAIModels(),
GetQwenModels(),
GetIFlowModels(),
GetKimiModels(),
GetGitHubCopilotModels(),
GetKiroModels(),
GetKiloModels(),
GetAmazonQModels(),
}
for _, models := range allModels {
@@ -121,7 +129,19 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
func GetGitHubCopilotModels() []*ModelInfo {
now := int64(1732752000) // 2024-11-27
return []*ModelInfo{
gpt4oEntries := []struct {
ID string
DisplayName string
Description string
}{
{ID: "gpt-4o-2024-11-20", DisplayName: "GPT-4o (2024-11-20)", Description: "OpenAI GPT-4o 2024-11-20 via GitHub Copilot"},
{ID: "gpt-4o-2024-08-06", DisplayName: "GPT-4o (2024-08-06)", Description: "OpenAI GPT-4o 2024-08-06 via GitHub Copilot"},
{ID: "gpt-4o-2024-05-13", DisplayName: "GPT-4o (2024-05-13)", Description: "OpenAI GPT-4o 2024-05-13 via GitHub Copilot"},
{ID: "gpt-4o", DisplayName: "GPT-4o", Description: "OpenAI GPT-4o via GitHub Copilot"},
{ID: "gpt-4-o-preview", DisplayName: "GPT-4-o Preview", Description: "OpenAI GPT-4-o Preview via GitHub Copilot"},
}
models := []*ModelInfo{
{
ID: "gpt-4.1",
Object: "model",
@@ -133,6 +153,23 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 128000,
MaxCompletionTokens: 16384,
},
}
for _, entry := range gpt4oEntries {
models = append(models, &ModelInfo{
ID: entry.ID,
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: entry.DisplayName,
Description: entry.Description,
ContextLength: 128000,
MaxCompletionTokens: 16384,
})
}
return append(models, []*ModelInfo{
{
ID: "gpt-5",
Object: "model",
@@ -144,6 +181,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "gpt-5-mini",
@@ -156,6 +194,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 128000,
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "gpt-5-codex",
@@ -168,6 +207,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "gpt-5.1",
@@ -180,6 +220,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
},
{
ID: "gpt-5.1-codex",
@@ -192,6 +233,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
},
{
ID: "gpt-5.1-codex-mini",
@@ -204,6 +246,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 128000,
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
},
{
ID: "gpt-5.1-codex-max",
@@ -216,6 +259,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.2",
@@ -228,6 +272,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.2-codex",
@@ -240,6 +285,20 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.3-codex",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.3 Codex",
Description: "OpenAI GPT-5.3 Codex via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "claude-haiku-4.5",
@@ -313,6 +372,18 @@ func GetGitHubCopilotModels() []*ModelInfo {
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-sonnet-4.6",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Claude Sonnet 4.6",
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "gemini-2.5-pro",
Object: "model",
@@ -335,6 +406,17 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 1048576,
MaxCompletionTokens: 65536,
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Gemini 3.1 Pro (Preview)",
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -369,7 +451,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
}
}...)
}
// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
@@ -400,6 +482,18 @@ func GetKiroModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-sonnet-4-6",
Object: "model",
Created: 1739836800, // 2025-02-18
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4.6",
Description: "Claude Sonnet 4.6 via Kiro (1.3x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-opus-4-5",
Object: "model",
@@ -448,6 +542,87 @@ func GetKiroModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
// --- 第三方模型 (通过 Kiro 接入) ---
{
ID: "kiro-deepseek-3-2",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro DeepSeek 3.2",
Description: "DeepSeek 3.2 via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-minimax-m2-1",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro MiniMax M2.1",
Description: "MiniMax M2.1 via Kiro",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-qwen3-coder-next",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Qwen3 Coder Next",
Description: "Qwen3 Coder Next via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-gpt-4o",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro GPT-4o",
Description: "OpenAI GPT-4o via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 16384,
},
{
ID: "kiro-gpt-4",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro GPT-4",
Description: "OpenAI GPT-4 via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 8192,
},
{
ID: "kiro-gpt-4-turbo",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro GPT-4 Turbo",
Description: "OpenAI GPT-4 Turbo via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 16384,
},
{
ID: "kiro-gpt-3-5-turbo",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro GPT-3.5 Turbo",
Description: "OpenAI GPT-3.5 Turbo via Kiro",
ContextLength: 16384,
MaxCompletionTokens: 4096,
},
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
{
ID: "kiro-claude-opus-4-6-agentic",
@@ -461,6 +636,18 @@ func GetKiroModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-sonnet-4-6-agentic",
Object: "model",
Created: 1739836800, // 2025-02-18
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4.6 (Agentic)",
Description: "Claude Sonnet 4.6 optimized for coding agents (chunked writes)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-opus-4-5-agentic",
Object: "model",

View File

@@ -28,6 +28,17 @@ func GetClaudeModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-sonnet-4-6",
Object: "model",
Created: 1771372800, // 2026-02-17
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.6 Sonnet",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-opus-4-6",
Object: "model",
@@ -40,6 +51,18 @@ func GetClaudeModels() []*ModelInfo {
MaxCompletionTokens: 128000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-sonnet-4-6",
Object: "model",
Created: 1771286400, // 2026-02-17
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.6 Sonnet",
Description: "Best combination of speed and intelligence",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-opus-4-5-20251101",
Object: "model",
@@ -173,6 +196,21 @@ func GetGeminiModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -283,6 +321,21 @@ func GetGeminiVertexModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-pro-image-preview",
Object: "model",
@@ -425,6 +478,21 @@ func GetGeminiCLIModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -506,6 +574,21 @@ func GetAIStudioModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -742,6 +825,20 @@ func GetOpenAIModels() []*ModelInfo {
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.3-codex-spark",
Object: "model",
Created: 1770912000,
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.3",
DisplayName: "GPT 5.3 Codex Spark",
Description: "Ultra-fast coding model.",
ContextLength: 128000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
},
}
}
@@ -774,6 +871,19 @@ func GetQwenModels() []*ModelInfo {
MaxCompletionTokens: 2048,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
{
ID: "coder-model",
Object: "model",
Created: 1771171200,
OwnedBy: "qwen",
Type: "qwen",
Version: "3.5",
DisplayName: "Qwen 3.5 Plus",
Description: "efficient hybrid model with leading coding performance",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
{
ID: "vision-model",
Object: "model",
@@ -806,19 +916,12 @@ func GetIFlowModels() []*ModelInfo {
Created int64
Thinking *ThinkingSupport
}{
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "glm-5", DisplayName: "GLM-5", Description: "Zhipu GLM 5 general model", Created: 1770768000, Thinking: iFlowThinkingSupport},
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
@@ -827,10 +930,7 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
}
models := make([]*ModelInfo, 0, len(entries))
for _, entry := range entries {
@@ -864,11 +964,14 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"gpt-oss-120b-medium": {},
"tab_flash_lite_preview": {},
}

View File

@@ -601,8 +601,7 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
defer r.mutex.Unlock()
if registration, exists := r.models[modelID]; exists {
now := time.Now()
registration.QuotaExceededClients[clientID] = &now
registration.QuotaExceededClients[clientID] = new(time.Now())
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
}
}

View File

@@ -164,12 +164,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, &param)
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming request to the AI Studio API.
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -254,7 +254,6 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(first wsrelay.StreamEvent) {
defer close(out)
var param any
@@ -318,7 +317,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
}
}(firstEvent)
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: firstEvent.Headers.Clone(), Chunks: out}, nil
}
// CountTokens counts tokens for the given request using the AI Studio API.

View File

@@ -54,8 +54,78 @@ const (
var (
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
// from any antigravity auth. Empty fetches never overwrite this cache.
antigravityPrimaryModelsCache struct {
mu sync.RWMutex
models []*registry.ModelInfo
}
)
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
if len(models) == 0 {
return nil
}
out := make([]*registry.ModelInfo, 0, len(models))
for _, model := range models {
if model == nil || strings.TrimSpace(model.ID) == "" {
continue
}
out = append(out, cloneAntigravityModelInfo(model))
}
if len(out) == 0 {
return nil
}
return out
}
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
if model == nil {
return nil
}
clone := *model
if len(model.SupportedGenerationMethods) > 0 {
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
}
if len(model.SupportedParameters) > 0 {
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
}
if model.Thinking != nil {
thinkingClone := *model.Thinking
if len(model.Thinking.Levels) > 0 {
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
}
clone.Thinking = &thinkingClone
}
return &clone
}
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
cloned := cloneAntigravityModels(models)
if len(cloned) == 0 {
return false
}
antigravityPrimaryModelsCache.mu.Lock()
antigravityPrimaryModelsCache.models = cloned
antigravityPrimaryModelsCache.mu.Unlock()
return true
}
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
antigravityPrimaryModelsCache.mu.RLock()
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
antigravityPrimaryModelsCache.mu.RUnlock()
return cloned
}
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
models := loadAntigravityPrimaryModels()
if len(models) > 0 {
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
}
return models
}
// AntigravityExecutor proxies requests to the antigravity upstream.
type AntigravityExecutor struct {
cfg *config.Config
@@ -232,7 +302,7 @@ attemptLoop:
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
}
@@ -436,7 +506,7 @@ attemptLoop:
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
@@ -645,7 +715,7 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
}
// ExecuteStream performs a streaming request to the Antigravity API.
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -775,7 +845,6 @@ attemptLoop:
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(resp *http.Response) {
defer close(out)
defer func() {
@@ -820,7 +889,7 @@ attemptLoop:
reporter.ensurePublished(ctx)
}
}(httpResp)
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
switch {
@@ -968,7 +1037,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
}
lastStatus = httpResp.StatusCode
@@ -1008,8 +1077,8 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
exec := &AntigravityExecutor{cfg: cfg}
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
if errToken != nil || token == "" {
return nil
}
return fallbackAntigravityPrimaryModels()
}
if updatedAuth != nil {
auth = updatedAuth
}
@@ -1021,7 +1090,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
modelsURL := baseURL + antigravityModelsPath
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
if errReq != nil {
return nil
return fallbackAntigravityPrimaryModels()
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+token)
@@ -1033,13 +1102,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return nil
return fallbackAntigravityPrimaryModels()
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return nil
return fallbackAntigravityPrimaryModels()
}
bodyBytes, errRead := io.ReadAll(httpResp.Body)
@@ -1051,19 +1120,27 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return nil
return fallbackAntigravityPrimaryModels()
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return nil
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
return nil
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
now := time.Now().Unix()
@@ -1108,9 +1185,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
}
models = append(models, modelInfo)
}
if len(models) == 0 {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
return fallbackAntigravityPrimaryModels()
}
storeAntigravityPrimaryModels(models)
return models
}
return nil
return fallbackAntigravityPrimaryModels()
}
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {

View File

@@ -0,0 +1,90 @@
package executor
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
func resetAntigravityPrimaryModelsCacheForTest() {
antigravityPrimaryModelsCache.mu.Lock()
antigravityPrimaryModelsCache.models = nil
antigravityPrimaryModelsCache.mu.Unlock()
}
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
resetAntigravityPrimaryModelsCacheForTest()
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
seed := []*registry.ModelInfo{
{ID: "claude-sonnet-4-5"},
{ID: "gemini-2.5-pro"},
}
if updated := storeAntigravityPrimaryModels(seed); !updated {
t.Fatal("expected non-empty model list to update primary cache")
}
if updated := storeAntigravityPrimaryModels(nil); updated {
t.Fatal("expected nil model list not to overwrite primary cache")
}
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
t.Fatal("expected empty model list not to overwrite primary cache")
}
got := loadAntigravityPrimaryModels()
if len(got) != 2 {
t.Fatalf("expected cached model count 2, got %d", len(got))
}
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
}
}
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
resetAntigravityPrimaryModelsCacheForTest()
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
ID: "gpt-5",
DisplayName: "GPT-5",
SupportedGenerationMethods: []string{"generateContent"},
SupportedParameters: []string{"temperature"},
Thinking: &registry.ThinkingSupport{
Levels: []string{"high"},
},
}}); !updated {
t.Fatal("expected model cache update")
}
got := loadAntigravityPrimaryModels()
if len(got) != 1 {
t.Fatalf("expected one cached model, got %d", len(got))
}
got[0].ID = "mutated-id"
if len(got[0].SupportedGenerationMethods) > 0 {
got[0].SupportedGenerationMethods[0] = "mutated-method"
}
if len(got[0].SupportedParameters) > 0 {
got[0].SupportedParameters[0] = "mutated-parameter"
}
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
got[0].Thinking.Levels[0] = "mutated-level"
}
again := loadAntigravityPrimaryModels()
if len(again) != 1 {
t.Fatalf("expected one cached model after mutation, got %d", len(again))
}
if again[0].ID != "gpt-5" {
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
}
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
}
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
}
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
}
}

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net/http"
"runtime"
"strings"
"time"
@@ -116,7 +117,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
@@ -134,7 +135,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -143,7 +144,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
if err != nil {
return resp, err
}
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -208,7 +209,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} else {
reporter.publish(ctx, parseClaudeUsage(data))
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
}
var param any
@@ -222,11 +223,11 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
data,
&param,
)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -257,7 +258,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
@@ -275,7 +276,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -284,7 +285,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if err != nil {
return nil, err
}
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -329,7 +330,6 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -348,7 +348,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
// Forward the line as-is to preserve SSE format
@@ -375,7 +375,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
chunks := sdktranslator.TranslateStream(
@@ -398,7 +398,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
@@ -423,7 +423,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
body = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -432,7 +432,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
if err != nil {
return cliproxyexecutor.Response{}, err
}
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -487,7 +487,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "input_tokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
}
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
@@ -638,7 +638,49 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
return body, nil
}
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
// mapStainlessOS maps runtime.GOOS to Stainless SDK OS names.
func mapStainlessOS() string {
switch runtime.GOOS {
case "darwin":
return "MacOS"
case "windows":
return "Windows"
case "linux":
return "Linux"
case "freebsd":
return "FreeBSD"
default:
return "Other::" + runtime.GOOS
}
}
// mapStainlessArch maps runtime.GOARCH to Stainless SDK architecture names.
func mapStainlessArch() string {
switch runtime.GOARCH {
case "amd64":
return "x64"
case "arm64":
return "arm64"
case "386":
return "x86"
default:
return "other::" + runtime.GOARCH
}
}
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
hdrDefault := func(cfgVal, fallback string) string {
if cfgVal != "" {
return cfgVal
}
return fallback
}
var hd config.ClaudeHeaderDefaults
if cfg != nil {
hd = cfg.ClaudeHeaderDefaults
}
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
if isAnthropicBase && useAPIKey {
@@ -685,16 +727,17 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
// Values below match Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (captured 2026-02-17).
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", "v24.3.0")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", "0.55.1")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", "arm64")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", "MacOS")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", "60")
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "claude-cli/1.0.83 (external, cli)")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.44 (external, sdk-cli)"))
r.Header.Set("Connection", "keep-alive")
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
if stream {
@@ -702,6 +745,8 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
} else {
r.Header.Set("Accept", "application/json")
}
// Keep OS/Arch mapping dynamic (not configurable).
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
@@ -753,11 +798,21 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
return body
}
// Collect built-in tool names (those with a non-empty "type" field) so we can
// skip them consistently in both tools and message history.
builtinTools := map[string]bool{}
for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} {
builtinTools[name] = true
}
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
tools.ForEach(func(index, tool gjson.Result) bool {
// Skip built-in tools (web_search, code_execution, etc.) which have
// a "type" field and require their name to remain unchanged.
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
if n := tool.Get("name").String(); n != "" {
builtinTools[n] = true
}
return true
}
name := tool.Get("name").String()
@@ -772,7 +827,7 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
name := gjson.GetBytes(body, "tool_choice.name").String()
if name != "" && !strings.HasPrefix(name, prefix) {
if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] {
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
}
}
@@ -784,15 +839,38 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
return true
}
content.ForEach(func(contentIndex, part gjson.Result) bool {
if part.Get("type").String() != "tool_use" {
return true
partType := part.Get("type").String()
switch partType {
case "tool_use":
name := part.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) || builtinTools[name] {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+name)
case "tool_reference":
toolName := part.Get("tool_name").String()
if toolName == "" || strings.HasPrefix(toolName, prefix) || builtinTools[toolName] {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+toolName)
case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[]
nestedContent := part.Get("content")
if nestedContent.Exists() && nestedContent.IsArray() {
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
if nestedPart.Get("type").String() == "tool_reference" {
nestedToolName := nestedPart.Get("tool_name").String()
if nestedToolName != "" && !strings.HasPrefix(nestedToolName, prefix) && !builtinTools[nestedToolName] {
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, prefix+nestedToolName)
}
}
return true
})
}
}
name := part.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+name)
return true
})
return true
@@ -811,15 +889,38 @@ func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
return body
}
content.ForEach(func(index, part gjson.Result) bool {
if part.Get("type").String() != "tool_use" {
return true
partType := part.Get("type").String()
switch partType {
case "tool_use":
name := part.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
case "tool_reference":
toolName := part.Get("tool_name").String()
if !strings.HasPrefix(toolName, prefix) {
return true
}
path := fmt.Sprintf("content.%d.tool_name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(toolName, prefix))
case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[]
nestedContent := part.Get("content")
if nestedContent.Exists() && nestedContent.IsArray() {
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
if nestedPart.Get("type").String() == "tool_reference" {
nestedToolName := nestedPart.Get("tool_name").String()
if strings.HasPrefix(nestedToolName, prefix) {
nestedPath := fmt.Sprintf("content.%d.content.%d.tool_name", index.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, strings.TrimPrefix(nestedToolName, prefix))
}
}
return true
})
}
}
name := part.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
return true
})
return body
@@ -834,15 +935,34 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
return line
}
contentBlock := gjson.GetBytes(payload, "content_block")
if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" {
if !contentBlock.Exists() {
return line
}
name := contentBlock.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return line
}
updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
if err != nil {
blockType := contentBlock.Get("type").String()
var updated []byte
var err error
switch blockType {
case "tool_use":
name := contentBlock.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return line
}
updated, err = sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
if err != nil {
return line
}
case "tool_reference":
toolName := contentBlock.Get("tool_name").String()
if !strings.HasPrefix(toolName, prefix) {
return line
}
updated, err = sjson.SetBytes(payload, "content_block.tool_name", strings.TrimPrefix(toolName, prefix))
if err != nil {
return line
}
default:
return line
}
@@ -862,10 +982,10 @@ func getClientUserAgent(ctx context.Context) string {
}
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
// Returns (cloakMode, strictMode, sensitiveWords).
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
// Returns (cloakMode, strictMode, sensitiveWords, cacheUserID).
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bool) {
if auth == nil || auth.Attributes == nil {
return "auto", false, nil
return "auto", false, nil, false
}
cloakMode := auth.Attributes["cloak_mode"]
@@ -883,7 +1003,9 @@ func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
}
}
return cloakMode, strictMode, sensitiveWords
cacheUserID := strings.EqualFold(strings.TrimSpace(auth.Attributes["cloak_cache_user_id"]), "true")
return cloakMode, strictMode, sensitiveWords, cacheUserID
}
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
@@ -916,16 +1038,24 @@ func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *c
}
// injectFakeUserID generates and injects a fake user ID into the request metadata.
func injectFakeUserID(payload []byte) []byte {
// When useCache is false, a new user ID is generated for every call.
func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
generateID := func() string {
if useCache {
return cachedUserID(apiKey)
}
return generateFakeUserID()
}
metadata := gjson.GetBytes(payload, "metadata")
if !metadata.Exists() {
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
return payload
}
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
if existingUserID == "" || !isValidUserID(existingUserID) {
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
}
return payload
}
@@ -962,7 +1092,7 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
// applyCloaking applies cloaking transformations to the payload based on config and client.
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string) []byte {
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string, apiKey string) []byte {
clientUserAgent := getClientUserAgent(ctx)
// Get cloak config from ClaudeKey configuration
@@ -972,16 +1102,20 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
var cloakMode string
var strictMode bool
var sensitiveWords []string
var cacheUserID bool
if cloakCfg != nil {
cloakMode = cloakCfg.Mode
strictMode = cloakCfg.StrictMode
sensitiveWords = cloakCfg.SensitiveWords
if cloakCfg.CacheUserID != nil {
cacheUserID = *cloakCfg.CacheUserID
}
}
// Fallback to auth attributes if no config found
if cloakMode == "" {
attrMode, attrStrict, attrWords := getCloakConfigFromAuth(auth)
attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
cloakMode = attrMode
if !strictMode {
strictMode = attrStrict
@@ -989,6 +1123,12 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
if len(sensitiveWords) == 0 {
sensitiveWords = attrWords
}
if cloakCfg == nil || cloakCfg.CacheUserID == nil {
cacheUserID = attrCache
}
} else if cloakCfg == nil || cloakCfg.CacheUserID == nil {
_, _, _, attrCache := getCloakConfigFromAuth(auth)
cacheUserID = attrCache
}
// Determine if cloaking should be applied
@@ -1002,7 +1142,7 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
}
// Inject fake user ID
payload = injectFakeUserID(payload)
payload = injectFakeUserID(payload, apiKey, cacheUserID)
// Apply sensitive word obfuscation
if len(sensitiveWords) > 0 {

View File

@@ -2,9 +2,18 @@ package executor
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
func TestApplyClaudeToolPrefix(t *testing.T) {
@@ -25,6 +34,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
}
}
func TestApplyClaudeToolPrefix_WithToolReference(t *testing.T) {
input := []byte(`{"tools":[{"name":"alpha"}],"messages":[{"role":"user","content":[{"type":"tool_reference","tool_name":"beta"},{"type":"tool_reference","tool_name":"proxy_gamma"}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
if got := gjson.GetBytes(out, "messages.0.content.0.tool_name").String(); got != "proxy_beta" {
t.Fatalf("messages.0.content.0.tool_name = %q, want %q", got, "proxy_beta")
}
if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != "proxy_gamma" {
t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, "proxy_gamma")
}
}
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
@@ -37,6 +58,97 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
}
}
func TestApplyClaudeToolPrefix_BuiltinToolSkipped(t *testing.T) {
body := []byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search", "max_uses": 5},
{"name": "Read"}
],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}},
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}}
]}
]
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
t.Fatalf("tools.0.name = %q, want %q", got, "web_search")
}
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
}
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Read" {
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Read")
}
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Read" {
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Read")
}
}
func TestApplyClaudeToolPrefix_KnownBuiltinInHistoryOnly(t *testing.T) {
body := []byte(`{
"tools": [
{"name": "Read"}
],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}}
]}
]
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "web_search" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "web_search")
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
}
}
func TestApplyClaudeToolPrefix_CustomToolsPrefixed(t *testing.T) {
body := []byte(`{
"tools": [{"name": "Read"}, {"name": "Write"}],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}},
{"type": "tool_use", "name": "Write", "id": "w1", "input": {}}
]}
]
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
}
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Write" {
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Write")
}
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Read" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Read")
}
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Write" {
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Write")
}
}
func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
body := []byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search"},
{"name": "Read"}
],
"tool_choice": {"type": "tool", "name": "web_search"}
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "web_search" {
t.Fatalf("tool_choice.name = %q, want %q", got, "web_search")
}
}
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
@@ -49,6 +161,18 @@ func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
}
}
func TestStripClaudeToolPrefixFromResponse_WithToolReference(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_reference","tool_name":"proxy_alpha"},{"type":"tool_reference","tool_name":"bravo"}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
if got := gjson.GetBytes(out, "content.0.tool_name").String(); got != "alpha" {
t.Fatalf("content.0.tool_name = %q, want %q", got, "alpha")
}
if got := gjson.GetBytes(out, "content.1.tool_name").String(); got != "bravo" {
t.Fatalf("content.1.tool_name = %q, want %q", got, "bravo")
}
}
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
@@ -61,3 +185,166 @@ func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
}
}
func TestStripClaudeToolPrefixFromStreamLine_WithToolReference(t *testing.T) {
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_reference","tool_name":"proxy_beta"},"index":0}`)
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
payload := bytes.TrimSpace(out)
if bytes.HasPrefix(payload, []byte("data:")) {
payload = bytes.TrimSpace(payload[len("data:"):])
}
if got := gjson.GetBytes(payload, "content_block.tool_name").String(); got != "beta" {
t.Fatalf("content_block.tool_name = %q, want %q", got, "beta")
}
}
func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"mcp__nia__manage_resource"}]}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
if got != "proxy_mcp__nia__manage_resource" {
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "proxy_mcp__nia__manage_resource")
}
}
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
resetUserIDCache()
var userIDs []string
var requestModels []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
userID := gjson.GetBytes(body, "metadata.user_id").String()
model := gjson.GetBytes(body, "model").String()
userIDs = append(userIDs, userID)
requestModels = append(requestModels, model)
t.Logf("HTTP Server received request: model=%s, user_id=%s, url=%s", model, userID, r.URL.String())
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
}))
defer server.Close()
t.Logf("End-to-end test: Fake HTTP server started at %s", server.URL)
cacheEnabled := true
executor := NewClaudeExecutor(&config.Config{
ClaudeKey: []config.ClaudeKey{
{
APIKey: "key-123",
BaseURL: server.URL,
Cloak: &config.CloakConfig{
CacheUserID: &cacheEnabled,
},
},
},
})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
models := []string{"claude-3-5-sonnet", "claude-3-5-haiku"}
for _, model := range models {
t.Logf("Sending request for model: %s", model)
modelPayload, _ := sjson.SetBytes(payload, "model", model)
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: model,
Payload: modelPayload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
}); err != nil {
t.Fatalf("Execute(%s) error: %v", model, err)
}
}
if len(userIDs) != 2 {
t.Fatalf("expected 2 requests, got %d", len(userIDs))
}
if userIDs[0] == "" || userIDs[1] == "" {
t.Fatal("expected user_id to be populated")
}
t.Logf("user_id[0] (model=%s): %s", requestModels[0], userIDs[0])
t.Logf("user_id[1] (model=%s): %s", requestModels[1], userIDs[1])
if userIDs[0] != userIDs[1] {
t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1])
}
if !isValidUserID(userIDs[0]) {
t.Fatalf("user_id %q is not valid", userIDs[0])
}
t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0])
}
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
resetUserIDCache()
var userIDs []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
userIDs = append(userIDs, gjson.GetBytes(body, "metadata.user_id").String())
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
for i := 0; i < 2; i++ {
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
}); err != nil {
t.Fatalf("Execute call %d error: %v", i, err)
}
}
if len(userIDs) != 2 {
t.Fatalf("expected 2 requests, got %d", len(userIDs))
}
if userIDs[0] == "" || userIDs[1] == "" {
t.Fatal("expected user_id to be populated")
}
if userIDs[0] == userIDs[1] {
t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0])
}
if !isValidUserID(userIDs[0]) || !isValidUserID(userIDs[1]) {
t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1])
}
}
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
got := gjson.GetBytes(out, "content.0.content.0.tool_name").String()
if got != "mcp__nia__manage_resource" {
t.Fatalf("nested tool_reference tool_name = %q, want %q", got, "mcp__nia__manage_resource")
}
}
func TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent(t *testing.T) {
// tool_result.content can be a string - should not be processed
input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"plain string result"}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
got := gjson.GetBytes(out, "messages.0.content.0.content").String()
if got != "plain string result" {
t.Fatalf("string content should remain unchanged = %q", got)
}
}
func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"}],"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_reference","tool_name":"web_search"}]}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
if got != "web_search" {
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
}
}

View File

@@ -28,8 +28,8 @@ import (
)
const (
codexClientVersion = "0.98.0"
codexUserAgent = "codex_cli_rs/0.98.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
codexClientVersion = "0.101.0"
codexUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
)
var dataTag = []byte("data:")
@@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
@@ -273,11 +273,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
}
@@ -362,7 +362,6 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -397,7 +396,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
@@ -643,7 +642,6 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
}
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
misc.EnsureHeader(r.Header, ginHeaders, "Openai-Beta", "responses=experimental")
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)

File diff suppressed because it is too large Load Diff

View File

@@ -225,7 +225,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
reporter.publish(ctx, parseGeminiCLIUsage(data))
var param any
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -256,7 +256,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
}
// ExecuteStream performs a streaming request to the Gemini CLI API.
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -382,7 +382,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(resp *http.Response, reqBody []byte, attemptModel string) {
defer close(out)
defer func() {
@@ -441,7 +440,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}
}(httpResp, append([]byte(nil), payload...), attemptModel)
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
if len(lastBody) > 0 {
@@ -546,7 +545,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
}
lastStatus = resp.StatusCode
lastBody = append([]byte(nil), data...)
@@ -899,8 +898,7 @@ func parseRetryDelay(errorBody []byte) (*time.Duration, error) {
if matches := re.FindStringSubmatch(message); len(matches) > 1 {
seconds, err := strconv.Atoi(matches[1])
if err == nil {
duration := time.Duration(seconds) * time.Second
return &duration, nil
return new(time.Duration(seconds) * time.Second), nil
}
}
}

View File

@@ -205,12 +205,12 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming request to the Gemini API.
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -298,7 +298,6 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -335,7 +334,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens counts tokens for the given request using the Gemini API.
@@ -416,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
}
// Refresh refreshes the authentication credentials (no-op for Gemini API key).

View File

@@ -253,7 +253,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
}
// ExecuteStream performs a streaming request to the Vertex AI API.
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
to := sdktranslator.FromString("gemini")
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -524,12 +524,12 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -618,7 +618,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -650,11 +649,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -743,7 +742,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -775,7 +773,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// countTokensWithServiceAccount counts tokens using service account credentials.
@@ -859,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
}
// countTokensWithAPIKey handles token counting using API key credentials.
@@ -943,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
}
// vertexCreds extracts project, location and raw service account JSON from auth metadata.

View File

@@ -39,7 +39,8 @@ const (
copilotEditorVersion = "vscode/1.107.0"
copilotPluginVersion = "copilot-chat/0.35.0"
copilotIntegrationID = "vscode-chat"
copilotOpenAIIntent = "conversation-edits"
copilotOpenAIIntent = "conversation-panel"
copilotGitHubAPIVer = "2025-04-01"
)
// GitHubCopilotExecutor handles requests to the GitHub Copilot API.
@@ -51,8 +52,9 @@ type GitHubCopilotExecutor struct {
// cachedAPIToken stores a cached Copilot API token with its expiry.
type cachedAPIToken struct {
token string
expiresAt time.Time
token string
apiEndpoint string
expiresAt time.Time
}
// NewGitHubCopilotExecutor constructs a new executor instance.
@@ -75,7 +77,7 @@ func (e *GitHubCopilotExecutor) PrepareRequest(req *http.Request, auth *cliproxy
if ctx == nil {
ctx = context.Background()
}
apiToken, errToken := e.ensureAPIToken(ctx, auth)
apiToken, _, errToken := e.ensureAPIToken(ctx, auth)
if errToken != nil {
return errToken
}
@@ -101,7 +103,7 @@ func (e *GitHubCopilotExecutor) HttpRequest(ctx context.Context, auth *cliproxya
// Execute handles non-streaming requests to GitHub Copilot.
func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
apiToken, errToken := e.ensureAPIToken(ctx, auth)
apiToken, baseURL, errToken := e.ensureAPIToken(ctx, auth)
if errToken != nil {
return resp, errToken
}
@@ -110,7 +112,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
useResponses := useGitHubCopilotResponsesEndpoint(from)
useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model)
to := sdktranslator.FromString("openai")
if useResponses {
to = sdktranslator.FromString("openai-response")
@@ -123,6 +125,25 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = e.normalizeModel(req.Model, body)
body = flattenAssistantContent(body)
// Detect vision content before input normalization removes messages
hasVision := detectVisionContent(body)
thinkingProvider := "openai"
if useResponses {
thinkingProvider = "codex"
}
body, err = thinking.ApplyThinking(body, req.Model, from.String(), thinkingProvider, e.Identifier())
if err != nil {
return resp, err
}
if useResponses {
body = normalizeGitHubCopilotResponsesInput(body)
body = normalizeGitHubCopilotResponsesTools(body)
} else {
body = normalizeGitHubCopilotChatTools(body)
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "stream", false)
@@ -131,7 +152,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
if useResponses {
path = githubCopilotResponsesPath
}
url := githubCopilotBaseURL + path
url := baseURL + path
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return resp, err
@@ -139,7 +160,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
e.applyHeaders(httpReq, apiToken, body)
// Add Copilot-Vision-Request header if the request contains vision content
if detectVisionContent(body) {
if hasVision {
httpReq.Header.Set("Copilot-Vision-Request", "true")
}
@@ -199,15 +220,20 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
}
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
converted := ""
if useResponses && from.String() == "claude" {
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
} else {
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
}
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
reporter.ensurePublished(ctx)
return resp, nil
}
// ExecuteStream handles streaming requests to GitHub Copilot.
func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
apiToken, errToken := e.ensureAPIToken(ctx, auth)
func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
apiToken, baseURL, errToken := e.ensureAPIToken(ctx, auth)
if errToken != nil {
return nil, errToken
}
@@ -216,7 +242,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
useResponses := useGitHubCopilotResponsesEndpoint(from)
useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model)
to := sdktranslator.FromString("openai")
if useResponses {
to = sdktranslator.FromString("openai-response")
@@ -229,6 +255,25 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = e.normalizeModel(req.Model, body)
body = flattenAssistantContent(body)
// Detect vision content before input normalization removes messages
hasVision := detectVisionContent(body)
thinkingProvider := "openai"
if useResponses {
thinkingProvider = "codex"
}
body, err = thinking.ApplyThinking(body, req.Model, from.String(), thinkingProvider, e.Identifier())
if err != nil {
return nil, err
}
if useResponses {
body = normalizeGitHubCopilotResponsesInput(body)
body = normalizeGitHubCopilotResponsesTools(body)
} else {
body = normalizeGitHubCopilotChatTools(body)
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "stream", true)
@@ -241,7 +286,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
if useResponses {
path = githubCopilotResponsesPath
}
url := githubCopilotBaseURL + path
url := baseURL + path
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return nil, err
@@ -249,7 +294,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
e.applyHeaders(httpReq, apiToken, body)
// Add Copilot-Vision-Request header if the request contains vision content
if detectVisionContent(body) {
if hasVision {
httpReq.Header.Set("Copilot-Vision-Request", "true")
}
@@ -296,7 +341,6 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
@@ -329,7 +373,12 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
}
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
var chunks []string
if useResponses && from.String() == "claude" {
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), &param)
} else {
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
}
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
@@ -344,7 +393,10 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{
Headers: httpResp.Header.Clone(),
Chunks: out,
}, nil
}
// CountTokens is not supported for GitHub Copilot.
@@ -376,22 +428,22 @@ func (e *GitHubCopilotExecutor) Refresh(ctx context.Context, auth *cliproxyauth.
}
// ensureAPIToken gets or refreshes the Copilot API token.
func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, error) {
func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, string, error) {
if auth == nil {
return "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
return "", "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
}
// Get the GitHub access token
accessToken := metaStringValue(auth.Metadata, "access_token")
if accessToken == "" {
return "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"}
return "", "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"}
}
// Check for cached API token using thread-safe access
e.mu.RLock()
if cached, ok := e.cache[accessToken]; ok && cached.expiresAt.After(time.Now().Add(tokenExpiryBuffer)) {
e.mu.RUnlock()
return cached.token, nil
return cached.token, cached.apiEndpoint, nil
}
e.mu.RUnlock()
@@ -399,7 +451,13 @@ func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *clipro
copilotAuth := copilotauth.NewCopilotAuth(e.cfg)
apiToken, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken)
if err != nil {
return "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)}
return "", "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)}
}
// Use endpoint from token response, fall back to default
apiEndpoint := githubCopilotBaseURL
if apiToken.Endpoints.API != "" {
apiEndpoint = strings.TrimRight(apiToken.Endpoints.API, "/")
}
// Cache the token with thread-safe access
@@ -409,12 +467,13 @@ func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *clipro
}
e.mu.Lock()
e.cache[accessToken] = &cachedAPIToken{
token: apiToken.Token,
expiresAt: expiresAt,
token: apiToken.Token,
apiEndpoint: apiEndpoint,
expiresAt: expiresAt,
}
e.mu.Unlock()
return apiToken.Token, nil
return apiToken.Token, apiEndpoint, nil
}
// applyHeaders sets the required headers for GitHub Copilot API requests.
@@ -427,16 +486,17 @@ func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, b
r.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
r.Header.Set("Openai-Intent", copilotOpenAIIntent)
r.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
r.Header.Set("X-Github-Api-Version", copilotGitHubAPIVer)
r.Header.Set("X-Request-Id", uuid.NewString())
initiator := "user"
if len(body) > 0 {
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
arr := messages.Array()
if len(arr) > 0 {
lastRole := arr[len(arr)-1].Get("role").String()
if lastRole != "" && lastRole != "user" {
for _, msg := range messages.Array() {
role := msg.Get("role").String()
if role == "assistant" || role == "tool" {
initiator = "agent"
break
}
}
}
@@ -483,8 +543,12 @@ func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte
return body
}
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format) bool {
return sourceFormat.String() == "openai-response"
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool {
if sourceFormat.String() == "openai-response" {
return true
}
baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName)
return strings.Contains(baseModel, "codex")
}
// flattenAssistantContent converts assistant message content from array format
@@ -504,6 +568,17 @@ func flattenAssistantContent(body []byte) []byte {
if !content.Exists() || !content.IsArray() {
continue
}
// Skip flattening if the content contains non-text blocks (tool_use, thinking, etc.)
hasNonText := false
for _, part := range content.Array() {
if t := part.Get("type").String(); t != "" && t != "text" {
hasNonText = true
break
}
}
if hasNonText {
continue
}
var textParts []string
for _, part := range content.Array() {
if part.Get("type").String() == "text" {
@@ -519,6 +594,644 @@ func flattenAssistantContent(body []byte) []byte {
return result
}
func normalizeGitHubCopilotChatTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() {
filtered := "[]"
if tools.IsArray() {
for _, tool := range tools.Array() {
if tool.Get("type").String() != "function" {
continue
}
filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw)
}
}
body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered))
}
toolChoice := gjson.GetBytes(body, "tool_choice")
if !toolChoice.Exists() {
return body
}
if toolChoice.Type == gjson.String {
switch toolChoice.String() {
case "auto", "none", "required":
return body
}
}
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
return body
}
func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
input := gjson.GetBytes(body, "input")
if input.Exists() {
// If input is already a string or array, keep it as-is.
if input.Type == gjson.String || input.IsArray() {
return body
}
// Non-string/non-array input: stringify as fallback.
body, _ = sjson.SetBytes(body, "input", input.Raw)
return body
}
// Convert Claude messages format to OpenAI Responses API input array.
// This preserves the conversation structure (roles, tool calls, tool results)
// which is critical for multi-turn tool-use conversations.
inputArr := "[]"
// System messages → developer role
if system := gjson.GetBytes(body, "system"); system.Exists() {
var systemParts []string
if system.IsArray() {
for _, part := range system.Array() {
if txt := part.Get("text").String(); txt != "" {
systemParts = append(systemParts, txt)
}
}
} else if system.Type == gjson.String {
systemParts = append(systemParts, system.String())
}
if len(systemParts) > 0 {
msg := `{"type":"message","role":"developer","content":[]}`
for _, txt := range systemParts {
part := `{"type":"input_text","text":""}`
part, _ = sjson.Set(part, "text", txt)
msg, _ = sjson.SetRaw(msg, "content.-1", part)
}
inputArr, _ = sjson.SetRaw(inputArr, "-1", msg)
}
}
// Messages → structured input items
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
for _, msg := range messages.Array() {
role := msg.Get("role").String()
content := msg.Get("content")
if !content.Exists() {
continue
}
// Simple string content
if content.Type == gjson.String {
textType := "input_text"
if role == "assistant" {
textType = "output_text"
}
item := `{"type":"message","role":"","content":[]}`
item, _ = sjson.Set(item, "role", role)
part := fmt.Sprintf(`{"type":"%s","text":""}`, textType)
part, _ = sjson.Set(part, "text", content.String())
item, _ = sjson.SetRaw(item, "content.-1", part)
inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
continue
}
if !content.IsArray() {
continue
}
// Array content: split into message parts vs tool items
var msgParts []string
for _, c := range content.Array() {
cType := c.Get("type").String()
switch cType {
case "text":
textType := "input_text"
if role == "assistant" {
textType = "output_text"
}
part := fmt.Sprintf(`{"type":"%s","text":""}`, textType)
part, _ = sjson.Set(part, "text", c.Get("text").String())
msgParts = append(msgParts, part)
case "image":
source := c.Get("source")
if source.Exists() {
data := source.Get("data").String()
if data == "" {
data = source.Get("base64").String()
}
mediaType := source.Get("media_type").String()
if mediaType == "" {
mediaType = source.Get("mime_type").String()
}
if mediaType == "" {
mediaType = "application/octet-stream"
}
if data != "" {
part := `{"type":"input_image","image_url":""}`
part, _ = sjson.Set(part, "image_url", fmt.Sprintf("data:%s;base64,%s", mediaType, data))
msgParts = append(msgParts, part)
}
}
case "tool_use":
// Flush any accumulated message parts first
if len(msgParts) > 0 {
item := `{"type":"message","role":"","content":[]}`
item, _ = sjson.Set(item, "role", role)
for _, p := range msgParts {
item, _ = sjson.SetRaw(item, "content.-1", p)
}
inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
msgParts = nil
}
fc := `{"type":"function_call","call_id":"","name":"","arguments":""}`
fc, _ = sjson.Set(fc, "call_id", c.Get("id").String())
fc, _ = sjson.Set(fc, "name", c.Get("name").String())
if inputRaw := c.Get("input"); inputRaw.Exists() {
fc, _ = sjson.Set(fc, "arguments", inputRaw.Raw)
}
inputArr, _ = sjson.SetRaw(inputArr, "-1", fc)
case "tool_result":
// Flush any accumulated message parts first
if len(msgParts) > 0 {
item := `{"type":"message","role":"","content":[]}`
item, _ = sjson.Set(item, "role", role)
for _, p := range msgParts {
item, _ = sjson.SetRaw(item, "content.-1", p)
}
inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
msgParts = nil
}
fco := `{"type":"function_call_output","call_id":"","output":""}`
fco, _ = sjson.Set(fco, "call_id", c.Get("tool_use_id").String())
// Extract output text
resultContent := c.Get("content")
if resultContent.Type == gjson.String {
fco, _ = sjson.Set(fco, "output", resultContent.String())
} else if resultContent.IsArray() {
var resultParts []string
for _, rc := range resultContent.Array() {
if txt := rc.Get("text").String(); txt != "" {
resultParts = append(resultParts, txt)
}
}
fco, _ = sjson.Set(fco, "output", strings.Join(resultParts, "\n"))
} else if resultContent.Exists() {
fco, _ = sjson.Set(fco, "output", resultContent.String())
}
inputArr, _ = sjson.SetRaw(inputArr, "-1", fco)
case "thinking":
// Skip thinking blocks - not part of the API input
}
}
// Flush remaining message parts
if len(msgParts) > 0 {
item := `{"type":"message","role":"","content":[]}`
item, _ = sjson.Set(item, "role", role)
for _, p := range msgParts {
item, _ = sjson.SetRaw(item, "content.-1", p)
}
inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
}
}
}
body, _ = sjson.SetRawBytes(body, "input", []byte(inputArr))
// Remove messages/system since we've converted them to input
body, _ = sjson.DeleteBytes(body, "messages")
body, _ = sjson.DeleteBytes(body, "system")
return body
}
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() {
filtered := "[]"
if tools.IsArray() {
for _, tool := range tools.Array() {
toolType := tool.Get("type").String()
// Accept OpenAI format (type="function") and Claude format
// (no type field, but has top-level name + input_schema).
if toolType != "" && toolType != "function" {
continue
}
name := tool.Get("name").String()
if name == "" {
name = tool.Get("function.name").String()
}
if name == "" {
continue
}
normalized := `{"type":"function","name":""}`
normalized, _ = sjson.Set(normalized, "name", name)
if desc := tool.Get("description").String(); desc != "" {
normalized, _ = sjson.Set(normalized, "description", desc)
} else if desc = tool.Get("function.description").String(); desc != "" {
normalized, _ = sjson.Set(normalized, "description", desc)
}
if params := tool.Get("parameters"); params.Exists() {
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
} else if params = tool.Get("function.parameters"); params.Exists() {
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
} else if params = tool.Get("input_schema"); params.Exists() {
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
}
filtered, _ = sjson.SetRaw(filtered, "-1", normalized)
}
}
body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered))
}
toolChoice := gjson.GetBytes(body, "tool_choice")
if !toolChoice.Exists() {
return body
}
if toolChoice.Type == gjson.String {
switch toolChoice.String() {
case "auto", "none", "required":
return body
default:
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
return body
}
}
if toolChoice.Type == gjson.JSON {
choiceType := toolChoice.Get("type").String()
if choiceType == "function" {
name := toolChoice.Get("name").String()
if name == "" {
name = toolChoice.Get("function.name").String()
}
if name != "" {
normalized := `{"type":"function","name":""}`
normalized, _ = sjson.Set(normalized, "name", name)
body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(normalized))
return body
}
}
}
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
return body
}
func collectTextFromNode(node gjson.Result) string {
if !node.Exists() {
return ""
}
if node.Type == gjson.String {
return node.String()
}
if node.IsArray() {
var parts []string
for _, item := range node.Array() {
if item.Type == gjson.String {
if text := item.String(); text != "" {
parts = append(parts, text)
}
continue
}
if text := item.Get("text").String(); text != "" {
parts = append(parts, text)
continue
}
if nested := collectTextFromNode(item.Get("content")); nested != "" {
parts = append(parts, nested)
}
}
return strings.Join(parts, "\n")
}
if node.Type == gjson.JSON {
if text := node.Get("text").String(); text != "" {
return text
}
if nested := collectTextFromNode(node.Get("content")); nested != "" {
return nested
}
return node.Raw
}
return node.String()
}
type githubCopilotResponsesStreamToolState struct {
Index int
ID string
Name string
}
type githubCopilotResponsesStreamState struct {
MessageStarted bool
MessageStopSent bool
TextBlockStarted bool
TextBlockIndex int
NextContentIndex int
HasToolUse bool
ReasoningActive bool
ReasoningIndex int
OutputIndexToTool map[int]*githubCopilotResponsesStreamToolState
ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
}
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
root := gjson.ParseBytes(data)
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("id").String())
out, _ = sjson.Set(out, "model", root.Get("model").String())
hasToolUse := false
if output := root.Get("output"); output.Exists() && output.IsArray() {
for _, item := range output.Array() {
switch item.Get("type").String() {
case "reasoning":
var thinkingText string
if summary := item.Get("summary"); summary.Exists() && summary.IsArray() {
var parts []string
for _, part := range summary.Array() {
if txt := part.Get("text").String(); txt != "" {
parts = append(parts, txt)
}
}
thinkingText = strings.Join(parts, "")
}
if thinkingText == "" {
if content := item.Get("content"); content.Exists() && content.IsArray() {
var parts []string
for _, part := range content.Array() {
if txt := part.Get("text").String(); txt != "" {
parts = append(parts, txt)
}
}
thinkingText = strings.Join(parts, "")
}
}
if thinkingText != "" {
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", thinkingText)
out, _ = sjson.SetRaw(out, "content.-1", block)
}
case "message":
if content := item.Get("content"); content.Exists() && content.IsArray() {
for _, part := range content.Array() {
if part.Get("type").String() != "output_text" {
continue
}
text := part.Get("text").String()
if text == "" {
continue
}
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", text)
out, _ = sjson.SetRaw(out, "content.-1", block)
}
}
case "function_call":
hasToolUse = true
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolID := item.Get("call_id").String()
if toolID == "" {
toolID = item.Get("id").String()
}
toolUse, _ = sjson.Set(toolUse, "id", toolID)
toolUse, _ = sjson.Set(toolUse, "name", item.Get("name").String())
if args := item.Get("arguments").String(); args != "" && gjson.Valid(args) {
argObj := gjson.Parse(args)
if argObj.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", argObj.Raw)
}
}
out, _ = sjson.SetRaw(out, "content.-1", toolUse)
}
}
}
inputTokens := root.Get("usage.input_tokens").Int()
outputTokens := root.Get("usage.output_tokens").Int()
cachedTokens := root.Get("usage.input_tokens_details.cached_tokens").Int()
if cachedTokens > 0 && inputTokens >= cachedTokens {
inputTokens -= cachedTokens
}
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
if cachedTokens > 0 {
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
}
if hasToolUse {
out, _ = sjson.Set(out, "stop_reason", "tool_use")
} else if sr := root.Get("stop_reason").String(); sr == "max_tokens" || sr == "stop" {
out, _ = sjson.Set(out, "stop_reason", sr)
} else {
out, _ = sjson.Set(out, "stop_reason", "end_turn")
}
return out
}
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string {
if *param == nil {
*param = &githubCopilotResponsesStreamState{
TextBlockIndex: -1,
OutputIndexToTool: make(map[int]*githubCopilotResponsesStreamToolState),
ItemIDToTool: make(map[string]*githubCopilotResponsesStreamToolState),
}
}
state := (*param).(*githubCopilotResponsesStreamState)
if !bytes.HasPrefix(line, dataTag) {
return nil
}
payload := bytes.TrimSpace(line[5:])
if bytes.Equal(payload, []byte("[DONE]")) {
return nil
}
if !gjson.ValidBytes(payload) {
return nil
}
event := gjson.GetBytes(payload, "type").String()
results := make([]string, 0, 4)
ensureMessageStart := func() {
if state.MessageStarted {
return
}
messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String())
messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String())
results = append(results, "event: message_start\ndata: "+messageStart+"\n\n")
state.MessageStarted = true
}
startTextBlockIfNeeded := func() {
if state.TextBlockStarted {
return
}
if state.TextBlockIndex < 0 {
state.TextBlockIndex = state.NextContentIndex
state.NextContentIndex++
}
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex)
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
state.TextBlockStarted = true
}
stopTextBlockIfNeeded := func() {
if !state.TextBlockStarted {
return
}
contentBlockStop := `{"type":"content_block_stop","index":0}`
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
state.TextBlockStarted = false
state.TextBlockIndex = -1
}
resolveTool := func(itemID string, outputIndex int) *githubCopilotResponsesStreamToolState {
if itemID != "" {
if tool, ok := state.ItemIDToTool[itemID]; ok {
return tool
}
}
if tool, ok := state.OutputIndexToTool[outputIndex]; ok {
if itemID != "" {
state.ItemIDToTool[itemID] = tool
}
return tool
}
return nil
}
switch event {
case "response.created":
ensureMessageStart()
case "response.output_text.delta":
ensureMessageStart()
startTextBlockIfNeeded()
delta := gjson.GetBytes(payload, "delta").String()
if delta != "" {
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex)
contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n")
}
case "response.reasoning_summary_part.added":
ensureMessageStart()
state.ReasoningActive = true
state.ReasoningIndex = state.NextContentIndex
state.NextContentIndex++
thinkingStart := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
thinkingStart, _ = sjson.Set(thinkingStart, "index", state.ReasoningIndex)
results = append(results, "event: content_block_start\ndata: "+thinkingStart+"\n\n")
case "response.reasoning_summary_text.delta":
if state.ReasoningActive {
delta := gjson.GetBytes(payload, "delta").String()
if delta != "" {
thinkingDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
thinkingDelta, _ = sjson.Set(thinkingDelta, "index", state.ReasoningIndex)
thinkingDelta, _ = sjson.Set(thinkingDelta, "delta.thinking", delta)
results = append(results, "event: content_block_delta\ndata: "+thinkingDelta+"\n\n")
}
}
case "response.reasoning_summary_part.done":
if state.ReasoningActive {
thinkingStop := `{"type":"content_block_stop","index":0}`
thinkingStop, _ = sjson.Set(thinkingStop, "index", state.ReasoningIndex)
results = append(results, "event: content_block_stop\ndata: "+thinkingStop+"\n\n")
state.ReasoningActive = false
}
case "response.output_item.added":
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
break
}
ensureMessageStart()
stopTextBlockIfNeeded()
state.HasToolUse = true
tool := &githubCopilotResponsesStreamToolState{
Index: state.NextContentIndex,
ID: gjson.GetBytes(payload, "item.call_id").String(),
Name: gjson.GetBytes(payload, "item.name").String(),
}
if tool.ID == "" {
tool.ID = gjson.GetBytes(payload, "item.id").String()
}
state.NextContentIndex++
outputIndex := int(gjson.GetBytes(payload, "output_index").Int())
state.OutputIndexToTool[outputIndex] = tool
if itemID := gjson.GetBytes(payload, "item.id").String(); itemID != "" {
state.ItemIDToTool[itemID] = tool
}
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index)
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID)
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name)
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
case "response.output_item.delta":
item := gjson.GetBytes(payload, "item")
if item.Get("type").String() != "function_call" {
break
}
tool := resolveTool(item.Get("id").String(), int(gjson.GetBytes(payload, "output_index").Int()))
if tool == nil {
break
}
partial := gjson.GetBytes(payload, "delta").String()
if partial == "" {
partial = item.Get("arguments").String()
}
if partial == "" {
break
}
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
case "response.function_call_arguments.delta":
// Copilot sends tool call arguments via this event type (not response.output_item.delta).
// Data format: {"delta":"...", "item_id":"...", "output_index":N, ...}
itemID := gjson.GetBytes(payload, "item_id").String()
outputIndex := int(gjson.GetBytes(payload, "output_index").Int())
tool := resolveTool(itemID, outputIndex)
if tool == nil {
break
}
partial := gjson.GetBytes(payload, "delta").String()
if partial == "" {
break
}
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
case "response.output_item.done":
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
break
}
tool := resolveTool(gjson.GetBytes(payload, "item.id").String(), int(gjson.GetBytes(payload, "output_index").Int()))
if tool == nil {
break
}
contentBlockStop := `{"type":"content_block_stop","index":0}`
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
case "response.completed":
ensureMessageStart()
stopTextBlockIfNeeded()
if !state.MessageStopSent {
stopReason := "end_turn"
if state.HasToolUse {
stopReason = "tool_use"
} else if sr := gjson.GetBytes(payload, "response.stop_reason").String(); sr == "max_tokens" || sr == "stop" {
stopReason = sr
}
inputTokens := gjson.GetBytes(payload, "response.usage.input_tokens").Int()
outputTokens := gjson.GetBytes(payload, "response.usage.output_tokens").Int()
cachedTokens := gjson.GetBytes(payload, "response.usage.input_tokens_details.cached_tokens").Int()
if cachedTokens > 0 && inputTokens >= cachedTokens {
inputTokens -= cachedTokens
}
messageDelta := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
messageDelta, _ = sjson.Set(messageDelta, "delta.stop_reason", stopReason)
messageDelta, _ = sjson.Set(messageDelta, "usage.input_tokens", inputTokens)
messageDelta, _ = sjson.Set(messageDelta, "usage.output_tokens", outputTokens)
if cachedTokens > 0 {
messageDelta, _ = sjson.Set(messageDelta, "usage.cache_read_input_tokens", cachedTokens)
}
results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n")
results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
state.MessageStopSent = true
}
}
return results
}
// isHTTPSuccess checks if the status code indicates success (2xx).
func isHTTPSuccess(statusCode int) bool {
return statusCode >= 200 && statusCode < 300

View File

@@ -1,8 +1,11 @@
package executor
import (
"net/http"
"strings"
"testing"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
@@ -52,3 +55,279 @@ func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) {
})
}
}
func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) {
t.Parallel()
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") {
t.Fatal("expected openai-response source to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
t.Parallel()
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") {
t.Fatal("expected codex model to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
t.Parallel()
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
t.Fatal("expected default openai source with non-codex model to use /chat/completions")
}
}
func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`)
got := normalizeGitHubCopilotChatTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 1 {
t.Fatalf("tools len = %d, want 1", len(tools))
}
if tools[0].Get("type").String() != "function" {
t.Fatalf("tool type = %q, want function", tools[0].Get("type").String())
}
}
func TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`)
got := normalizeGitHubCopilotChatTools(body)
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
}
}
func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) {
t.Parallel()
body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`)
got := normalizeGitHubCopilotResponsesInput(body)
in := gjson.GetBytes(got, "input")
if !in.IsArray() {
t.Fatalf("input type = %v, want array", in.Type)
}
raw := in.Raw
if !strings.Contains(raw, "sys text") || !strings.Contains(raw, "user text") || !strings.Contains(raw, "assistant text") {
t.Fatalf("input = %s, want structured array with all texts", raw)
}
if gjson.GetBytes(got, "messages").Exists() {
t.Fatal("messages should be removed after conversion")
}
if gjson.GetBytes(got, "system").Exists() {
t.Fatal("system should be removed after conversion")
}
}
func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) {
t.Parallel()
body := []byte(`{"input":{"foo":"bar"}}`)
got := normalizeGitHubCopilotResponsesInput(body)
in := gjson.GetBytes(got, "input")
if in.Type != gjson.String {
t.Fatalf("input type = %v, want string", in.Type)
}
if !strings.Contains(in.String(), "foo") {
t.Fatalf("input = %q, want stringified object", in.String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
got := normalizeGitHubCopilotResponsesTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 1 {
t.Fatalf("tools len = %d, want 1", len(tools))
}
if tools[0].Get("name").String() != "sum" {
t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String())
}
if !tools[0].Get("parameters").Exists() {
t.Fatal("expected parameters to be preserved")
}
}
func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`)
got := normalizeGitHubCopilotResponsesTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 2 {
t.Fatalf("tools len = %d, want 2", len(tools))
}
if tools[0].Get("type").String() != "function" {
t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String())
}
if tools[0].Get("name").String() != "Bash" {
t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String())
}
if tools[0].Get("description").String() != "Run commands" {
t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String())
}
if !tools[0].Get("parameters").Exists() {
t.Fatal("expected parameters to be set from input_schema")
}
if tools[0].Get("parameters.properties.command").Exists() != true {
t.Fatal("expected parameters.properties.command to exist")
}
if tools[1].Get("name").String() != "Read" {
t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) {
t.Parallel()
body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`)
got := normalizeGitHubCopilotResponsesTools(body)
if gjson.GetBytes(got, "tool_choice.type").String() != "function" {
t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String())
}
if gjson.GetBytes(got, "tool_choice.name").String() != "sum" {
t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
t.Parallel()
body := []byte(`{"tool_choice":{"type":"function"}}`)
got := normalizeGitHubCopilotResponsesTools(body)
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
}
}
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) {
t.Parallel()
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
if gjson.Get(out, "type").String() != "message" {
t.Fatalf("type = %q, want message", gjson.Get(out, "type").String())
}
if gjson.Get(out, "content.0.type").String() != "text" {
t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String())
}
if gjson.Get(out, "content.0.text").String() != "hello" {
t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String())
}
}
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) {
t.Parallel()
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
if gjson.Get(out, "content.0.type").String() != "tool_use" {
t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String())
}
if gjson.Get(out, "content.0.name").String() != "sum" {
t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String())
}
if gjson.Get(out, "stop_reason").String() != "tool_use" {
t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String())
}
}
func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) {
t.Parallel()
var param any
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), &param)
if len(created) == 0 || !strings.Contains(created[0], "message_start") {
t.Fatalf("created events = %#v, want message_start", created)
}
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), &param)
joinedDelta := strings.Join(delta, "")
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
}
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), &param)
joinedCompleted := strings.Join(completed, "")
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
}
}
// --- Tests for X-Initiator detection logic (Problem L) ---
func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
body := []byte(`{"messages":[{"role":"system","content":"sys"},{"role":"user","content":"hello"}]}`)
e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "user" {
t.Fatalf("X-Initiator = %q, want user", got)
}
}
func TestApplyHeaders_XInitiator_AgentWithAssistantAndUserToolResult(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
// Claude Code typical flow: last message is user (tool result), but has assistant in history
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "agent" {
t.Fatalf("X-Initiator = %q, want agent (assistant exists in messages)", got)
}
}
func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "agent" {
t.Fatalf("X-Initiator = %q, want agent (tool role exists)", got)
}
}
// --- Tests for x-github-api-version header (Problem M) ---
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
e.applyHeaders(req, "token", nil)
if got := req.Header.Get("X-Github-Api-Version"); got != "2025-04-01" {
t.Fatalf("X-Github-Api-Version = %q, want 2025-04-01", got)
}
}
// --- Tests for vision detection (Problem P) ---
func TestDetectVisionContent_WithImageURL(t *testing.T) {
t.Parallel()
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
if !detectVisionContent(body) {
t.Fatal("expected vision content to be detected")
}
}
func TestDetectVisionContent_WithImageType(t *testing.T) {
t.Parallel()
body := []byte(`{"messages":[{"role":"user","content":[{"type":"image","source":{"data":"abc","media_type":"image/png"}}]}]}`)
if !detectVisionContent(body) {
t.Fatal("expected image type to be detected")
}
}
func TestDetectVisionContent_NoVision(t *testing.T) {
t.Parallel()
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
if detectVisionContent(body) {
t.Fatal("expected no vision content")
}
}
func TestDetectVisionContent_NoMessages(t *testing.T) {
t.Parallel()
// After Responses API normalization, messages is removed — detection should return false
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`)
if detectVisionContent(body) {
t.Fatal("expected no vision content when messages field is absent")
}
}

View File

@@ -169,12 +169,12 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming chat completion request.
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -262,7 +262,6 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -294,7 +293,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
reporter.ensurePublished(ctx)
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {

View File

@@ -0,0 +1,460 @@
package executor
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// KiloExecutor handles requests to Kilo API.
type KiloExecutor struct {
cfg *config.Config
}
// NewKiloExecutor creates a new Kilo executor instance.
func NewKiloExecutor(cfg *config.Config) *KiloExecutor {
return &KiloExecutor{cfg: cfg}
}
// Identifier returns the unique identifier for this executor.
func (e *KiloExecutor) Identifier() string { return "kilo" }
// PrepareRequest prepares the HTTP request before execution.
func (e *KiloExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
accessToken, _ := kiloCredentials(auth)
if strings.TrimSpace(accessToken) == "" {
return fmt.Errorf("kilo: missing access token")
}
req.Header.Set("Authorization", "Bearer "+accessToken)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil
}
// HttpRequest executes a raw HTTP request.
func (e *KiloExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("kilo executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
// Execute performs a non-streaming request.
func (e *KiloExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
accessToken, orgID := kiloCredentials(auth)
if accessToken == "" {
return resp, fmt.Errorf("kilo: missing access token")
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
endpoint := "/api/openrouter/chat/completions"
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
url := "https://api.kilo.ai" + endpoint
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
return resp, err
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
if orgID != "" {
httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
}
httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translated,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer httpResp.Body.Close()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
body, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, body)
reporter.publish(ctx, parseOpenAIUsage(body))
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
return resp, nil
}
// ExecuteStream performs a streaming request.
func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
accessToken, orgID := kiloCredentials(auth)
if accessToken == "" {
return nil, fmt.Errorf("kilo: missing access token")
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
endpoint := "/api/openrouter/chat/completions"
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
url := "https://api.kilo.ai" + endpoint
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
if orgID != "" {
httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
}
httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
httpReq.Header.Set("Accept", "text/event-stream")
httpReq.Header.Set("Cache-Control", "no-cache")
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translated,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
httpResp.Body.Close()
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer httpResp.Body.Close()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, 52_428_800)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
reporter.ensurePublished(ctx)
}()
return &cliproxyexecutor.StreamResult{
Headers: httpResp.Header.Clone(),
Chunks: out,
}, nil
}
// Refresh validates the Kilo token.
func (e *KiloExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
if auth == nil {
return nil, fmt.Errorf("missing auth")
}
return auth, nil
}
// CountTokens returns the token count for the given request.
func (e *KiloExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, fmt.Errorf("kilo: count tokens not supported")
}
// kiloCredentials extracts access token and other info from auth.
func kiloCredentials(auth *cliproxyauth.Auth) (accessToken, orgID string) {
if auth == nil {
return "", ""
}
// Prefer kilocode specific keys, then fall back to generic keys.
// Check metadata first, then attributes.
if auth.Metadata != nil {
if token, ok := auth.Metadata["kilocodeToken"].(string); ok && token != "" {
accessToken = token
} else if token, ok := auth.Metadata["access_token"].(string); ok && token != "" {
accessToken = token
}
if org, ok := auth.Metadata["kilocodeOrganizationId"].(string); ok && org != "" {
orgID = org
} else if org, ok := auth.Metadata["organization_id"].(string); ok && org != "" {
orgID = org
}
}
if accessToken == "" && auth.Attributes != nil {
if token := auth.Attributes["kilocodeToken"]; token != "" {
accessToken = token
} else if token := auth.Attributes["access_token"]; token != "" {
accessToken = token
}
}
if orgID == "" && auth.Attributes != nil {
if org := auth.Attributes["kilocodeOrganizationId"]; org != "" {
orgID = org
} else if org := auth.Attributes["organization_id"]; org != "" {
orgID = org
}
}
return accessToken, orgID
}
// FetchKiloModels fetches models from Kilo API.
func FetchKiloModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
accessToken, orgID := kiloCredentials(auth)
if accessToken == "" {
log.Infof("kilo: no access token found, skipping dynamic model fetch (using static kilo/auto)")
return registry.GetKiloModels()
}
log.Debugf("kilo: fetching dynamic models (orgID: %s)", orgID)
httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.kilo.ai/api/openrouter/models", nil)
if err != nil {
log.Warnf("kilo: failed to create model fetch request: %v", err)
return registry.GetKiloModels()
}
req.Header.Set("Authorization", "Bearer "+accessToken)
if orgID != "" {
req.Header.Set("X-Kilocode-OrganizationID", orgID)
}
req.Header.Set("User-Agent", "cli-proxy-kilo")
resp, err := httpClient.Do(req)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Warnf("kilo: fetch models canceled: %v", err)
} else {
log.Warnf("kilo: using static models (API fetch failed: %v)", err)
}
return registry.GetKiloModels()
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Warnf("kilo: failed to read models response: %v", err)
return registry.GetKiloModels()
}
if resp.StatusCode != http.StatusOK {
log.Warnf("kilo: fetch models failed: status %d, body: %s", resp.StatusCode, string(body))
return registry.GetKiloModels()
}
result := gjson.GetBytes(body, "data")
if !result.Exists() {
// Try root if data field is missing
result = gjson.ParseBytes(body)
if !result.IsArray() {
log.Debugf("kilo: response body: %s", string(body))
log.Warn("kilo: invalid API response format (expected array or data field with array)")
return registry.GetKiloModels()
}
}
var dynamicModels []*registry.ModelInfo
now := time.Now().Unix()
count := 0
totalCount := 0
result.ForEach(func(key, value gjson.Result) bool {
totalCount++
id := value.Get("id").String()
pIdxResult := value.Get("preferredIndex")
preferredIndex := pIdxResult.Int()
// Filter models where preferredIndex > 0 (Kilo-curated models)
if preferredIndex <= 0 {
return true
}
// Check if it's free. We look for :free suffix, is_free flag, or zero pricing.
isFree := strings.HasSuffix(id, ":free") || id == "giga-potato" || value.Get("is_free").Bool()
if !isFree {
// Check pricing as fallback
promptPricing := value.Get("pricing.prompt").String()
if promptPricing == "0" || promptPricing == "0.0" {
isFree = true
}
}
if !isFree {
log.Debugf("kilo: skipping curated paid model: %s", id)
return true
}
log.Debugf("kilo: found curated model: %s (preferredIndex: %d)", id, preferredIndex)
dynamicModels = append(dynamicModels, &registry.ModelInfo{
ID: id,
DisplayName: value.Get("name").String(),
ContextLength: int(value.Get("context_length").Int()),
OwnedBy: "kilo",
Type: "kilo",
Object: "model",
Created: now,
})
count++
return true
})
log.Infof("kilo: fetched %d models from API, %d curated free (preferredIndex > 0)", totalCount, count)
if count == 0 && totalCount > 0 {
log.Warn("kilo: no curated free models found (check API response fields)")
}
staticModels := registry.GetKiloModels()
// Always include kilo/auto (first static model)
allModels := append(staticModels[:1], dynamicModels...)
return allModels
}

View File

@@ -161,12 +161,12 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming chat completion request to Kimi.
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
from := opts.SourceFormat
if from.String() == "claude" {
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
@@ -253,7 +253,6 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -285,7 +284,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens estimates token count for Kimi requests.

File diff suppressed because it is too large Load Diff

View File

@@ -172,11 +172,11 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
// Translate response back to source format when needed
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -258,7 +258,6 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -298,7 +297,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
// Ensure we record the request if no usage chunk was ever seen
reporter.ensurePublished(ctx)
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {

View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"strings"
"sync"
"time"
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
@@ -22,11 +23,151 @@ import (
)
const (
qwenUserAgent = "google-api-nodejs-client/9.15.1"
qwenXGoogAPIClient = "gl-node/22.17.0"
qwenClientMetadataValue = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
qwenRateLimitWindow = time.Minute // sliding window duration
)
// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls.
var qwenBeijingLoc = func() *time.Location {
loc, err := time.LoadLocation("Asia/Shanghai")
if err != nil || loc == nil {
log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err)
return time.FixedZone("CST", 8*3600)
}
return loc
}()
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
var qwenQuotaCodes = map[string]struct{}{
"insufficient_quota": {},
"quota_exceeded": {},
}
// qwenRateLimiter tracks request timestamps per credential for rate limiting.
// Qwen has a limit of 60 requests per minute per account.
var qwenRateLimiter = struct {
sync.Mutex
requests map[string][]time.Time // authID -> request timestamps
}{
requests: make(map[string][]time.Time),
}
// redactAuthID returns a redacted version of the auth ID for safe logging.
// Keeps a small prefix/suffix to allow correlation across events.
func redactAuthID(id string) string {
if id == "" {
return ""
}
if len(id) <= 8 {
return id
}
return id[:4] + "..." + id[len(id)-4:]
}
// checkQwenRateLimit checks if the credential has exceeded the rate limit.
// Returns nil if allowed, or a statusErr with retryAfter if rate limited.
func checkQwenRateLimit(authID string) error {
if authID == "" {
// Empty authID should not bypass rate limiting in production
// Use debug level to avoid log spam for certain auth flows
log.Debug("qwen rate limit check: empty authID, skipping rate limit")
return nil
}
now := time.Now()
windowStart := now.Add(-qwenRateLimitWindow)
qwenRateLimiter.Lock()
defer qwenRateLimiter.Unlock()
// Get and filter timestamps within the window
timestamps := qwenRateLimiter.requests[authID]
var validTimestamps []time.Time
for _, ts := range timestamps {
if ts.After(windowStart) {
validTimestamps = append(validTimestamps, ts)
}
}
// Always prune expired entries to prevent memory leak
// Delete empty entries, otherwise update with pruned slice
if len(validTimestamps) == 0 {
delete(qwenRateLimiter.requests, authID)
}
// Check if rate limit exceeded
if len(validTimestamps) >= qwenRateLimitPerMin {
// Calculate when the oldest request will expire
oldestInWindow := validTimestamps[0]
retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now)
if retryAfter < time.Second {
retryAfter = time.Second
}
retryAfterSec := int(retryAfter.Seconds())
return statusErr{
code: http.StatusTooManyRequests,
msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec),
retryAfter: &retryAfter,
}
}
// Record this request and update the map with pruned timestamps
validTimestamps = append(validTimestamps, now)
qwenRateLimiter.requests[authID] = validTimestamps
return nil
}
// isQwenQuotaError checks if the error response indicates a quota exceeded error.
// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted.
func isQwenQuotaError(body []byte) bool {
code := strings.ToLower(gjson.GetBytes(body, "error.code").String())
errType := strings.ToLower(gjson.GetBytes(body, "error.type").String())
// Primary check: exact match on error.code or error.type (most reliable)
if _, ok := qwenQuotaCodes[code]; ok {
return true
}
if _, ok := qwenQuotaCodes[errType]; ok {
return true
}
// Fallback: check message only if code/type don't match (less reliable)
msg := strings.ToLower(gjson.GetBytes(body, "error.message").String())
if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") ||
strings.Contains(msg, "free allocated quota exceeded") {
return true
}
return false
}
// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429.
// Returns the appropriate status code and retryAfter duration for statusErr.
// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives.
func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) {
errCode = httpCode
// Only check quota errors for expected status codes to avoid false positives
// Qwen returns 403 for quota errors, 429 for rate limits
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
cooldown := timeUntilNextDay()
retryAfter = &cooldown
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
}
return errCode, retryAfter
}
// timeUntilNextDay returns duration until midnight Beijing time (UTC+8).
// Qwen's daily quota resets at 00:00 Beijing time.
func timeUntilNextDay() time.Duration {
now := time.Now()
nowLocal := now.In(qwenBeijingLoc)
tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc)
return tomorrow.Sub(now)
}
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
// If access token is unavailable, it falls back to legacy via ClientAdapter.
type QwenExecutor struct {
@@ -69,6 +210,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
if opts.Alt == "responses/compact" {
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
// Check rate limit before proceeding
var authID string
if auth != nil {
authID = auth.ID
}
if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return resp, err
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, baseURL := qwenCreds(auth)
@@ -104,9 +256,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err
}
applyQwenHeaders(httpReq, token, false)
var authID, authLabel, authType, authValue string
var authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
@@ -137,8 +288,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
@@ -152,14 +305,25 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
// Check rate limit before proceeding
var authID string
if auth != nil {
authID = auth.ID
}
if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return nil, err
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, baseURL := qwenCreds(auth)
@@ -202,9 +366,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, err
}
applyQwenHeaders(httpReq, token, true)
var authID, authLabel, authType, authValue string
var authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
@@ -230,15 +393,16 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("qwen executor: close response body error: %v", errClose)
}
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -270,7 +434,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
@@ -344,8 +508,18 @@ func applyQwenHeaders(r *http.Request, token string, stream bool) {
r.Header.Set("Content-Type", "application/json")
r.Header.Set("Authorization", "Bearer "+token)
r.Header.Set("User-Agent", qwenUserAgent)
r.Header.Set("X-Goog-Api-Client", qwenXGoogAPIClient)
r.Header.Set("Client-Metadata", qwenClientMetadataValue)
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
r.Header.Set("Sec-Fetch-Mode", "cors")
r.Header.Set("X-Stainless-Lang", "js")
r.Header.Set("X-Stainless-Arch", "arm64")
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
r.Header.Set("X-Dashscope-Cachecontrol", "enable")
r.Header.Set("X-Stainless-Retry-Count", "0")
r.Header.Set("X-Stainless-Os", "MacOS")
r.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
r.Header.Set("X-Stainless-Runtime", "node")
if stream {
r.Header.Set("Accept", "text/event-stream")
return

View File

@@ -0,0 +1,89 @@
package executor
import (
"crypto/sha256"
"encoding/hex"
"sync"
"time"
)
type userIDCacheEntry struct {
value string
expire time.Time
}
var (
userIDCache = make(map[string]userIDCacheEntry)
userIDCacheMu sync.RWMutex
userIDCacheCleanupOnce sync.Once
)
const (
userIDTTL = time.Hour
userIDCacheCleanupPeriod = 15 * time.Minute
)
func startUserIDCacheCleanup() {
go func() {
ticker := time.NewTicker(userIDCacheCleanupPeriod)
defer ticker.Stop()
for range ticker.C {
purgeExpiredUserIDs()
}
}()
}
func purgeExpiredUserIDs() {
now := time.Now()
userIDCacheMu.Lock()
for key, entry := range userIDCache {
if !entry.expire.After(now) {
delete(userIDCache, key)
}
}
userIDCacheMu.Unlock()
}
func userIDCacheKey(apiKey string) string {
sum := sha256.Sum256([]byte(apiKey))
return hex.EncodeToString(sum[:])
}
func cachedUserID(apiKey string) string {
if apiKey == "" {
return generateFakeUserID()
}
userIDCacheCleanupOnce.Do(startUserIDCacheCleanup)
key := userIDCacheKey(apiKey)
now := time.Now()
userIDCacheMu.RLock()
entry, ok := userIDCache[key]
valid := ok && entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value)
userIDCacheMu.RUnlock()
if valid {
userIDCacheMu.Lock()
entry = userIDCache[key]
if entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value) {
entry.expire = now.Add(userIDTTL)
userIDCache[key] = entry
userIDCacheMu.Unlock()
return entry.value
}
userIDCacheMu.Unlock()
}
newID := generateFakeUserID()
userIDCacheMu.Lock()
entry, ok = userIDCache[key]
if !ok || entry.value == "" || !entry.expire.After(now) || !isValidUserID(entry.value) {
entry.value = newID
}
entry.expire = now.Add(userIDTTL)
userIDCache[key] = entry
userIDCacheMu.Unlock()
return entry.value
}

View File

@@ -0,0 +1,86 @@
package executor
import (
"testing"
"time"
)
func resetUserIDCache() {
userIDCacheMu.Lock()
userIDCache = make(map[string]userIDCacheEntry)
userIDCacheMu.Unlock()
}
func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
resetUserIDCache()
first := cachedUserID("api-key-1")
second := cachedUserID("api-key-1")
if first == "" {
t.Fatal("expected generated user_id to be non-empty")
}
if first != second {
t.Fatalf("expected cached user_id to be reused, got %q and %q", first, second)
}
}
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
resetUserIDCache()
expiredID := cachedUserID("api-key-expired")
cacheKey := userIDCacheKey("api-key-expired")
userIDCacheMu.Lock()
userIDCache[cacheKey] = userIDCacheEntry{
value: expiredID,
expire: time.Now().Add(-time.Minute),
}
userIDCacheMu.Unlock()
newID := cachedUserID("api-key-expired")
if newID == expiredID {
t.Fatalf("expected expired user_id to be replaced, got %q", newID)
}
if newID == "" {
t.Fatal("expected regenerated user_id to be non-empty")
}
}
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
resetUserIDCache()
first := cachedUserID("api-key-1")
second := cachedUserID("api-key-2")
if first == second {
t.Fatalf("expected different API keys to have different user_ids, got %q", first)
}
}
func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
resetUserIDCache()
key := "api-key-renew"
id := cachedUserID(key)
cacheKey := userIDCacheKey(key)
soon := time.Now()
userIDCacheMu.Lock()
userIDCache[cacheKey] = userIDCacheEntry{
value: id,
expire: soon.Add(2 * time.Second),
}
userIDCacheMu.Unlock()
if refreshed := cachedUserID(key); refreshed != id {
t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
}
userIDCacheMu.RLock()
entry := userIDCache[cacheKey]
userIDCacheMu.RUnlock()
if entry.expire.Sub(soon) < 30*time.Minute {
t.Fatalf("expected TTL to renew, got %v remaining", entry.expire.Sub(soon))
}
}

View File

@@ -223,16 +223,71 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
} else if functionResponseResult.IsArray() {
frResults := functionResponseResult.Array()
if len(frResults) == 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", frResults[0].Raw)
nonImageCount := 0
lastNonImageRaw := ""
filteredJSON := "[]"
imagePartsJSON := "[]"
for _, fr := range frResults {
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
inlineDataJSON := `{}`
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
}
if data := fr.Get("source.data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
}
imagePartJSON := `{}`
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
continue
}
nonImageCount++
lastNonImageRaw = fr.Raw
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
}
if nonImageCount == 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
} else if nonImageCount > 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
} else {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
}
// Place image data inside functionResponse.parts as inlineData
// instead of as sibling parts in the outer content, to avoid
// base64 data bloating the text context.
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
}
} else if functionResponseResult.IsObject() {
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
inlineDataJSON := `{}`
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
}
if data := functionResponseResult.Get("source.data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
}
imagePartJSON := `{}`
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON := "[]"
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
} else {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
}
} else if functionResponseResult.Raw != "" {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
} else {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
// Content field is missing entirely — .Raw is empty which
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
}
partJSON := `{}`
@@ -244,7 +299,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if sourceResult.Get("type").String() == "base64" {
inlineDataJSON := `{}`
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mime_type", mimeType)
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
}
if data := sourceResult.Get("data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)

View File

@@ -413,8 +413,8 @@ func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) {
if !inlineData.Exists() {
t.Error("inlineData should exist")
}
if inlineData.Get("mime_type").String() != "image/png" {
t.Error("mime_type mismatch")
if inlineData.Get("mimeType").String() != "image/png" {
t.Error("mimeType mismatch")
}
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
t.Error("data mismatch")
@@ -661,6 +661,508 @@ func TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint(t *testing.T) {
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultNoContent(t *testing.T) {
// Bug repro: tool_result with no content field produces invalid JSON
inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "MyTool-123-456",
"name": "MyTool",
"input": {"key": "value"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "MyTool-123-456"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Errorf("Result is not valid JSON:\n%s", outputStr)
}
// Verify the functionResponse has a valid result value
fr := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse.response.result")
if !fr.Exists() {
t.Error("functionResponse.response.result should exist")
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
// Bug repro: tool_result with null content produces invalid JSON
inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "MyTool-123-456",
"name": "MyTool",
"input": {"key": "value"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "MyTool-123-456",
"content": null
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Errorf("Result is not valid JSON:\n%s", outputStr)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultWithImage(t *testing.T) {
// tool_result with array content containing text + image should place
// image data inside functionResponse.parts as inlineData, not as a
// sibling part in the outer content (to avoid base64 context bloat).
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Read-123-456",
"content": [
{
"type": "text",
"text": "File content here"
},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
// Image should be inside functionResponse.parts, not as outer sibling part
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// Text content should be in response.result
resultText := funcResp.Get("response.result.text").String()
if resultText != "File content here" {
t.Errorf("Expected response.result.text = 'File content here', got '%s'", resultText)
}
// Image should be in functionResponse.parts[0].inlineData
inlineData := funcResp.Get("parts.0.inlineData")
if !inlineData.Exists() {
t.Fatal("functionResponse.parts[0].inlineData should exist")
}
if inlineData.Get("mimeType").String() != "image/png" {
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineData.Get("mimeType").String())
}
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
t.Error("data mismatch")
}
// Image should NOT be in outer parts (only functionResponse part should exist)
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
t.Errorf("Expected only 1 outer part (functionResponse), got %d", len(outerParts.Array()))
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultWithSingleImage(t *testing.T) {
// tool_result with single image object as content should place
// image data inside functionResponse.parts, not as outer sibling part.
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Read-789-012",
"content": {
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "/9j/4AAQSkZJRgABAQ=="
}
}
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// response.result should be empty (image only)
if funcResp.Get("response.result").String() != "" {
t.Errorf("Expected empty response.result for image-only content, got '%s'", funcResp.Get("response.result").String())
}
// Image should be in functionResponse.parts[0].inlineData
inlineData := funcResp.Get("parts.0.inlineData")
if !inlineData.Exists() {
t.Fatal("functionResponse.parts[0].inlineData should exist")
}
if inlineData.Get("mimeType").String() != "image/jpeg" {
t.Errorf("Expected mimeType 'image/jpeg', got '%s'", inlineData.Get("mimeType").String())
}
// Image should NOT be in outer parts
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
t.Errorf("Expected only 1 outer part, got %d", len(outerParts.Array()))
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultWithMultipleImagesAndTexts(t *testing.T) {
// tool_result with array content: 2 text items + 2 images
// All images go into functionResponse.parts, texts into response.result array
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Multi-001",
"content": [
{"type": "text", "text": "First text"},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/png", "data": "AAAA"}
},
{"type": "text", "text": "Second text"},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/jpeg", "data": "BBBB"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// Multiple text items => response.result is an array
resultArr := funcResp.Get("response.result")
if !resultArr.IsArray() {
t.Fatalf("Expected response.result to be an array, got: %s", resultArr.Raw)
}
results := resultArr.Array()
if len(results) != 2 {
t.Fatalf("Expected 2 result items, got %d", len(results))
}
// Both images should be in functionResponse.parts
imgParts := funcResp.Get("parts").Array()
if len(imgParts) != 2 {
t.Fatalf("Expected 2 image parts in functionResponse.parts, got %d", len(imgParts))
}
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
t.Errorf("Expected first image mimeType 'image/png', got '%s'", imgParts[0].Get("inlineData.mimeType").String())
}
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
t.Errorf("Expected first image data 'AAAA', got '%s'", imgParts[0].Get("inlineData.data").String())
}
if imgParts[1].Get("inlineData.mimeType").String() != "image/jpeg" {
t.Errorf("Expected second image mimeType 'image/jpeg', got '%s'", imgParts[1].Get("inlineData.mimeType").String())
}
if imgParts[1].Get("inlineData.data").String() != "BBBB" {
t.Errorf("Expected second image data 'BBBB', got '%s'", imgParts[1].Get("inlineData.data").String())
}
// Only 1 outer part (the functionResponse itself)
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(outerParts) != 1 {
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultWithOnlyMultipleImages(t *testing.T) {
// tool_result with only images (no text) — response.result should be empty string
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "ImgOnly-001",
"content": [
{
"type": "image",
"source": {"type": "base64", "media_type": "image/png", "data": "PNG1"}
},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/gif", "data": "GIF1"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// No text => response.result should be empty string
if funcResp.Get("response.result").String() != "" {
t.Errorf("Expected empty response.result, got '%s'", funcResp.Get("response.result").String())
}
// Both images in functionResponse.parts
imgParts := funcResp.Get("parts").Array()
if len(imgParts) != 2 {
t.Fatalf("Expected 2 image parts, got %d", len(imgParts))
}
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
t.Error("first image mimeType mismatch")
}
if imgParts[1].Get("inlineData.mimeType").String() != "image/gif" {
t.Error("second image mimeType mismatch")
}
// Only 1 outer part
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(outerParts) != 1 {
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultImageNotBase64(t *testing.T) {
// image with source.type != "base64" should be treated as non-image (falls through)
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "NotB64-001",
"content": [
{"type": "text", "text": "some output"},
{
"type": "image",
"source": {"type": "url", "url": "https://example.com/img.png"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// Non-base64 image is treated as non-image, so it goes into the filtered results
// along with the text item. Since there are 2 non-image items, result is array.
resultArr := funcResp.Get("response.result")
if !resultArr.IsArray() {
t.Fatalf("Expected response.result to be an array (2 non-image items), got: %s", resultArr.Raw)
}
results := resultArr.Array()
if len(results) != 2 {
t.Fatalf("Expected 2 result items, got %d", len(results))
}
// No functionResponse.parts (no base64 images collected)
if funcResp.Get("parts").Exists() {
t.Error("functionResponse.parts should NOT exist when no base64 images")
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingData(t *testing.T) {
// image with source.type=base64 but missing data field
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "NoData-001",
"content": [
{"type": "text", "text": "output"},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/png"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// The image is still classified as base64 image (type check passes),
// but data field is missing => inlineData has mimeType but no data
imgParts := funcResp.Get("parts").Array()
if len(imgParts) != 1 {
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
}
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
t.Error("mimeType should still be set")
}
if imgParts[0].Get("inlineData.data").Exists() {
t.Error("data should not exist when source.data is missing")
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingMediaType(t *testing.T) {
// image with source.type=base64 but missing media_type field
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "NoMime-001",
"content": [
{"type": "text", "text": "output"},
{
"type": "image",
"source": {"type": "base64", "data": "AAAA"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// The image is still classified as base64 image,
// but media_type is missing => inlineData has data but no mimeType
imgParts := funcResp.Get("parts").Array()
if len(imgParts) != 1 {
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
}
if imgParts[0].Get("inlineData.mimeType").Exists() {
t.Error("mimeType should not exist when media_type is missing")
}
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
t.Error("data should still be set")
}
}
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
// When tools + thinking but no system instruction, should create one with hint
inputJSON := []byte(`{

View File

@@ -93,3 +93,81 @@ func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
}
}
}
func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
// When functionResponse contains a "parts" field with inlineData (from Claude
// translator's image embedding), fixCLIToolResponse should preserve it as-is.
// parseFunctionResponseRaw returns response.Raw for valid JSON objects,
// so extra fields like "parts" survive the pipeline.
input := `{
"model": "claude-opus-4-6-thinking",
"request": {
"contents": [
{
"role": "model",
"parts": [
{
"functionCall": {"name": "screenshot", "args": {}}
}
]
},
{
"role": "function",
"parts": [
{
"functionResponse": {
"id": "tool-001",
"name": "screenshot",
"response": {"result": "Screenshot taken"},
"parts": [
{"inlineData": {"mimeType": "image/png", "data": "iVBOR"}}
]
}
}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
// Find the function response content (role=function)
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
// The functionResponse should be preserved with its parts field
funcResp := funcContent.Get("parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist in output")
}
// Verify the parts field with inlineData is preserved
inlineParts := funcResp.Get("parts").Array()
if len(inlineParts) != 1 {
t.Fatalf("Expected 1 inlineData part in functionResponse.parts, got %d", len(inlineParts))
}
if inlineParts[0].Get("inlineData.mimeType").String() != "image/png" {
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineParts[0].Get("inlineData.mimeType").String())
}
if inlineParts[0].Get("inlineData.data").String() != "iVBOR" {
t.Errorf("Expected data 'iVBOR', got '%s'", inlineParts[0].Get("inlineData.data").String())
}
// Verify response.result is also preserved
if funcResp.Get("response.result").String() != "Screenshot taken" {
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
}
}

View File

@@ -187,7 +187,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
@@ -201,7 +201,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
ext = sp[len(sp)-1]
}
if mimeType, ok := misc.MimeTypes[ext]; ok {
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mimeType)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
p++
} else {
@@ -235,7 +235,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++

View File

@@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -199,6 +199,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
}
}
case "file":
fileData := part.Get("file.file_data").String()
if strings.HasPrefix(fileData, "data:") {
semicolonIdx := strings.Index(fileData, ";")
commaIdx := strings.Index(fileData, ",")
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
data := fileData[commaIdx+1:]
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
docPart, _ = sjson.Set(docPart, "source.data", data)
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
}
}
}
return true
})

View File

@@ -155,6 +155,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
var textAggregate strings.Builder
var partsJSON []string
hasImage := false
hasFile := false
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
ptype := part.Get("type").String()
@@ -207,6 +208,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
hasImage = true
}
}
case "input_file":
fileData := part.Get("file_data").String()
if fileData != "" {
mediaType := "application/octet-stream"
data := fileData
if strings.HasPrefix(fileData, "data:") {
trimmed := strings.TrimPrefix(fileData, "data:")
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
if len(mediaAndData) == 2 {
if mediaAndData[0] != "" {
mediaType = mediaAndData[0]
}
data = mediaAndData[1]
}
}
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.Set(contentPart, "source.data", data)
partsJSON = append(partsJSON, contentPart)
if role == "" {
role = "user"
}
hasFile = true
}
}
return true
})
@@ -228,7 +253,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if len(partsJSON) > 0 {
msg := `{"role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
if len(partsJSON) == 1 && !hasImage {
if len(partsJSON) == 1 && !hasImage && !hasFile {
// Preserve legacy behavior for single text content
msg, _ = sjson.Delete(msg, "content")
textPart := gjson.Parse(partsJSON[0])

View File

@@ -22,8 +22,9 @@ var (
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
type ConvertCodexResponseToClaudeParams struct {
HasToolCall bool
BlockIndex int
HasToolCall bool
BlockIndex int
HasReceivedArgumentsDelta bool
}
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
@@ -137,6 +138,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
@@ -171,12 +173,29 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output += fmt.Sprintf("data: %s\n\n", template)
}
} else if typeStr == "response.function_call_arguments.delta" {
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.function_call_arguments.done" {
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
// in a single "done" event without preceding "delta" events.
// Emit the full arguments as a single input_json_delta so the
// downstream Claude client receives the complete tool input.
// When delta events were already received, skip to avoid duplicating arguments.
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
if args := rootResult.Get("arguments").String(); args != "" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", args)
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
}
}
}
return []string{output}

View File

@@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
msg, _ = sjson.SetRaw(msg, "content.-1", part)
}
case "file":
// Files are not specified in examples; skip for now
if role == "user" {
fileData := it.Get("file.file_data").String()
filename := it.Get("file.filename").String()
if fileData != "" {
part := `{}`
part, _ = sjson.Set(part, "type", "input_file")
part, _ = sjson.Set(part, "file_data", fileData)
if filename != "" {
part, _ = sjson.Set(part, "filename", filename)
}
msg, _ = sjson.SetRaw(msg, "content.-1", part)
}
}
}
}
}

View File

@@ -20,10 +20,12 @@ var (
// ConvertCliToOpenAIParams holds parameters for response conversion.
type ConvertCliToOpenAIParams struct {
ResponseID string
CreatedAt int64
Model string
FunctionCallIndex int
ResponseID string
CreatedAt int64
Model string
FunctionCallIndex int
HasReceivedArgumentsDelta bool
HasToolCallAnnounced bool
}
// ConvertCodexResponseToOpenAI translates a single chunk of a streaming response from the
@@ -43,10 +45,12 @@ type ConvertCliToOpenAIParams struct {
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
if *param == nil {
*param = &ConvertCliToOpenAIParams{
Model: modelName,
CreatedAt: 0,
ResponseID: "",
FunctionCallIndex: -1,
Model: modelName,
CreatedAt: 0,
ResponseID: "",
FunctionCallIndex: -1,
HasReceivedArgumentsDelta: false,
HasToolCallAnnounced: false,
}
}
@@ -90,6 +94,9 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
}
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
}
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
}
@@ -115,35 +122,93 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
}
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
} else if dataType == "response.output_item.done" {
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
} else if dataType == "response.output_item.added" {
itemResult := rootResult.Get("item")
if itemResult.Exists() {
if itemResult.Get("type").String() != "function_call" {
return []string{}
}
// set the index
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened
name := itemResult.Get("name").String()
// Build reverse map on demand from original request tools
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
return []string{}
}
// Increment index for this new function call item.
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened.
name := itemResult.Get("name").String()
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", "")
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.function_call_arguments.delta" {
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true
deltaValue := rootResult.Get("delta").String()
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", deltaValue)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.function_call_arguments.done" {
if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta {
// Arguments were already streamed via delta events; nothing to emit.
return []string{}
}
// Fallback: no delta events were received, emit the full arguments as a single chunk.
fullArgs := rootResult.Get("arguments").String()
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fullArgs)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.output_item.done" {
itemResult := rootResult.Get("item")
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
return []string{}
}
if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced {
// Tool call was already announced via output_item.added; skip emission.
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false
return []string{}
}
// Fallback path: model skipped output_item.added, so emit complete tool call now.
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened.
name := itemResult.Get("name").String()
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else {
return []string{}
}
@@ -205,6 +270,9 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
}
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
}
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
}

View File

@@ -26,6 +26,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
// Delete the user field as it is not supported by the Codex upstream.
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
@@ -36,6 +38,23 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
return rawJSON
}
// applyResponsesCompactionCompatibility handles OpenAI Responses context_management.compaction
// for Codex upstream compatibility.
//
// Codex /responses currently rejects context_management with:
// {"detail":"Unsupported parameter: context_management"}.
//
// Compatibility strategy:
// 1) Remove context_management before forwarding to Codex upstream.
func applyResponsesCompactionCompatibility(rawJSON []byte) []byte {
if !gjson.GetBytes(rawJSON, "context_management").Exists() {
return rawJSON
}
rawJSON, _ = sjson.DeleteBytes(rawJSON, "context_management")
return rawJSON
}
// convertSystemRoleToDeveloper traverses the input array and converts any message items
// with role "system" to role "developer". This is necessary because Codex API does not
// accept "system" role in the input array.

View File

@@ -280,3 +280,41 @@ func TestUserFieldDeletion(t *testing.T) {
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
}
}
func TestContextManagementCompactionCompatibility(t *testing.T) {
inputJSON := []byte(`{
"model": "gpt-5.2",
"context_management": [
{
"type": "compaction",
"compact_threshold": 12000
}
],
"input": [{"role":"user","content":"hello"}]
}`)
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
outputStr := string(output)
if gjson.Get(outputStr, "context_management").Exists() {
t.Fatalf("context_management should be removed for Codex compatibility")
}
if gjson.Get(outputStr, "truncation").Exists() {
t.Fatalf("truncation should be removed for Codex compatibility")
}
}
func TestTruncationRemovedForCodexCompatibility(t *testing.T) {
inputJSON := []byte(`{
"model": "gpt-5.2",
"truncation": "disabled",
"input": [{"role":"user","content":"hello"}]
}`)
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
outputStr := string(output)
if gjson.Get(outputStr, "truncation").Exists() {
t.Fatalf("truncation should be removed for Codex compatibility")
}
}

View File

@@ -100,7 +100,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -100,9 +100,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
@@ -297,7 +297,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -531,8 +531,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
// usage mapping
if um := root.Get("usageMetadata"); um.Exists() {
// input tokens = prompt + thoughts
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
// input tokens = prompt only (thoughts go to output)
input := um.Get("promptTokenCount").Int()
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
// cached token details: align with OpenAI "cached_tokens" semantics.
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
@@ -737,8 +737,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
// usage mapping
if um := root.Get("usageMetadata"); um.Exists() {
// input tokens = prompt + thoughts
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
// input tokens = prompt only (thoughts go to output)
input := um.Get("promptTokenCount").Int()
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
// cached token details: align with OpenAI "cached_tokens" semantics.
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())

View File

@@ -243,13 +243,11 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA
// Process messages and build history
history, currentUserMsg, currentToolResults := processMessages(messages, modelID, origin)
// Build content with system prompt (only on first turn to avoid re-injection)
// Build content with system prompt.
// Keep thinking tags on subsequent turns so multi-turn Claude sessions
// continue to emit reasoning events.
if currentUserMsg != nil {
effectiveSystemPrompt := systemPrompt
if len(history) > 0 {
effectiveSystemPrompt = "" // Don't re-inject on subsequent turns
}
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, effectiveSystemPrompt, currentToolResults)
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults)
// Deduplicate currentToolResults
currentToolResults = deduplicateToolResults(currentToolResults)
@@ -475,6 +473,15 @@ func IsThinkingEnabledWithHeaders(body []byte, headers http.Header) bool {
}
}
// Check model name directly for thinking hints.
// This enables thinking variants even when clients don't send explicit thinking fields.
model := strings.TrimSpace(gjson.GetBytes(body, "model").String())
modelLower := strings.ToLower(model)
if strings.Contains(modelLower, "thinking") || strings.Contains(modelLower, "-reason") {
log.Debugf("kiro: thinking mode enabled via model name hint: %s", model)
return true
}
log.Debugf("kiro: IsThinkingEnabled returning false (no thinking mode detected)")
return false
}
@@ -586,6 +593,17 @@ func processMessages(messages gjson.Result, modelID, origin string) ([]KiroHisto
// Merge adjacent messages with the same role
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
// FIX: Kiro API requires history to start with a user message.
// Some clients (e.g., OpenClaw) send conversations starting with an assistant message,
// which is valid for the Claude API but causes "Improperly formed request" on Kiro.
// Prepend a placeholder user message so the history alternation is correct.
if len(messagesArray) > 0 && messagesArray[0].Get("role").String() == "assistant" {
placeholder := `{"role":"user","content":"."}`
messagesArray = append([]gjson.Result{gjson.Parse(placeholder)}, messagesArray...)
log.Infof("kiro: messages started with assistant role, prepended placeholder user message for Kiro API compatibility")
}
for i, msg := range messagesArray {
role := msg.Get("role").String()
isLastMessage := i == len(messagesArray)-1

View File

@@ -183,4 +183,124 @@ func PendingTagSuffix(buffer, tag string) int {
}
}
return 0
}
}
// GenerateSearchIndicatorEvents generates ONLY the search indicator SSE events
// (server_tool_use + web_search_tool_result) without text summary or message termination.
// These events trigger Claude Code's search indicator UI.
// The caller is responsible for sending message_start before and message_delta/stop after.
func GenerateSearchIndicatorEvents(
query string,
toolUseID string,
searchResults *WebSearchResults,
startIndex int,
) [][]byte {
events := make([][]byte, 0, 5)
// 1. content_block_start (server_tool_use)
event1 := map[string]interface{}{
"type": "content_block_start",
"index": startIndex,
"content_block": map[string]interface{}{
"id": toolUseID,
"type": "server_tool_use",
"name": "web_search",
"input": map[string]interface{}{},
},
}
data1, _ := json.Marshal(event1)
events = append(events, []byte("event: content_block_start\ndata: "+string(data1)+"\n\n"))
// 2. content_block_delta (input_json_delta)
inputJSON, _ := json.Marshal(map[string]string{"query": query})
event2 := map[string]interface{}{
"type": "content_block_delta",
"index": startIndex,
"delta": map[string]interface{}{
"type": "input_json_delta",
"partial_json": string(inputJSON),
},
}
data2, _ := json.Marshal(event2)
events = append(events, []byte("event: content_block_delta\ndata: "+string(data2)+"\n\n"))
// 3. content_block_stop (server_tool_use)
event3 := map[string]interface{}{
"type": "content_block_stop",
"index": startIndex,
}
data3, _ := json.Marshal(event3)
events = append(events, []byte("event: content_block_stop\ndata: "+string(data3)+"\n\n"))
// 4. content_block_start (web_search_tool_result)
searchContent := make([]map[string]interface{}, 0)
if searchResults != nil {
for _, r := range searchResults.Results {
snippet := ""
if r.Snippet != nil {
snippet = *r.Snippet
}
searchContent = append(searchContent, map[string]interface{}{
"type": "web_search_result",
"title": r.Title,
"url": r.URL,
"encrypted_content": snippet,
"page_age": nil,
})
}
}
event4 := map[string]interface{}{
"type": "content_block_start",
"index": startIndex + 1,
"content_block": map[string]interface{}{
"type": "web_search_tool_result",
"tool_use_id": toolUseID,
"content": searchContent,
},
}
data4, _ := json.Marshal(event4)
events = append(events, []byte("event: content_block_start\ndata: "+string(data4)+"\n\n"))
// 5. content_block_stop (web_search_tool_result)
event5 := map[string]interface{}{
"type": "content_block_stop",
"index": startIndex + 1,
}
data5, _ := json.Marshal(event5)
events = append(events, []byte("event: content_block_stop\ndata: "+string(data5)+"\n\n"))
return events
}
// BuildFallbackTextEvents generates SSE events for a fallback text response
// when the Kiro API fails during the search loop. Uses BuildClaude*Event()
// functions to align with streamToChannel patterns.
// Returns raw SSE byte slices ready to be sent to the client channel.
func BuildFallbackTextEvents(contentBlockIndex int, query string, results *WebSearchResults) [][]byte {
summary := FormatSearchContextPrompt(query, results)
outputTokens := len(summary) / 4
if len(summary) > 0 && outputTokens == 0 {
outputTokens = 1
}
var events [][]byte
// content_block_start (text)
events = append(events, BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", ""))
// content_block_delta (text_delta)
events = append(events, BuildClaudeStreamEvent(summary, contentBlockIndex))
// content_block_stop
events = append(events, BuildClaudeContentBlockStopEvent(contentBlockIndex))
// message_delta with end_turn
events = append(events, BuildClaudeMessageDeltaEvent("end_turn", usage.Detail{
OutputTokens: int64(outputTokens),
}))
// message_stop
events = append(events, BuildClaudeMessageStopOnlyEvent())
return events
}

View File

@@ -0,0 +1,350 @@
package claude
import (
"encoding/json"
"strings"
log "github.com/sirupsen/logrus"
)
// sseEvent represents a Server-Sent Event
type sseEvent struct {
Event string
Data interface{}
}
// ToSSEString converts the event to SSE wire format
func (e *sseEvent) ToSSEString() string {
dataBytes, _ := json.Marshal(e.Data)
return "event: " + e.Event + "\ndata: " + string(dataBytes) + "\n\n"
}
// AdjustStreamIndices adjusts content block indices in SSE event data by adding an offset.
// It also suppresses duplicate message_start events (returns shouldForward=false).
// This is used to combine search indicator events (indices 0,1) with Kiro model response events.
//
// The data parameter is a single SSE "data:" line payload (JSON).
// Returns: adjusted data, shouldForward (false = skip this event).
func AdjustStreamIndices(data []byte, offset int) ([]byte, bool) {
if len(data) == 0 {
return data, true
}
// Quick check: parse the JSON
var event map[string]interface{}
if err := json.Unmarshal(data, &event); err != nil {
// Not valid JSON, pass through
return data, true
}
eventType, _ := event["type"].(string)
// Suppress duplicate message_start events
if eventType == "message_start" {
return data, false
}
// Adjust index for content_block events
switch eventType {
case "content_block_start", "content_block_delta", "content_block_stop":
if idx, ok := event["index"].(float64); ok {
event["index"] = int(idx) + offset
adjusted, err := json.Marshal(event)
if err != nil {
return data, true
}
return adjusted, true
}
}
// Pass through all other events unchanged (message_delta, message_stop, ping, etc.)
return data, true
}
// AdjustSSEChunk processes a raw SSE chunk (potentially containing multiple "event:/data:" pairs)
// and adjusts content block indices. Suppresses duplicate message_start events.
// Returns the adjusted chunk and whether it should be forwarded.
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
chunkStr := string(chunk)
// Fast path: if no "data:" prefix, pass through
if !strings.Contains(chunkStr, "data: ") {
return chunk, true
}
var result strings.Builder
hasContent := false
lines := strings.Split(chunkStr, "\n")
for i := 0; i < len(lines); i++ {
line := lines[i]
if strings.HasPrefix(line, "data: ") {
dataPayload := strings.TrimPrefix(line, "data: ")
dataPayload = strings.TrimSpace(dataPayload)
if dataPayload == "[DONE]" {
result.WriteString(line + "\n")
hasContent = true
continue
}
adjusted, shouldForward := AdjustStreamIndices([]byte(dataPayload), offset)
if !shouldForward {
// Skip this event and its preceding "event:" line
// Also skip the trailing empty line
continue
}
result.WriteString("data: " + string(adjusted) + "\n")
hasContent = true
} else if strings.HasPrefix(line, "event: ") {
// Check if the next data line will be suppressed
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
dataPayload := strings.TrimPrefix(lines[i+1], "data: ")
dataPayload = strings.TrimSpace(dataPayload)
var event map[string]interface{}
if err := json.Unmarshal([]byte(dataPayload), &event); err == nil {
if eventType, ok := event["type"].(string); ok && eventType == "message_start" {
// Skip both the event: and data: lines
i++ // skip the data: line too
continue
}
}
}
result.WriteString(line + "\n")
hasContent = true
} else {
result.WriteString(line + "\n")
if strings.TrimSpace(line) != "" {
hasContent = true
}
}
}
if !hasContent {
return nil, false
}
return []byte(result.String()), true
}
// BufferedStreamResult contains the analysis of buffered SSE chunks from a Kiro API response.
type BufferedStreamResult struct {
// StopReason is the detected stop_reason from the stream (e.g., "end_turn", "tool_use")
StopReason string
// WebSearchQuery is the extracted query if the model requested another web_search
WebSearchQuery string
// WebSearchToolUseId is the tool_use ID from the model's response (needed for toolResults)
WebSearchToolUseId string
// HasWebSearchToolUse indicates whether the model requested web_search
HasWebSearchToolUse bool
// WebSearchToolUseIndex is the content_block index of the web_search tool_use
WebSearchToolUseIndex int
}
// AnalyzeBufferedStream scans buffered SSE chunks to detect stop_reason and web_search tool_use.
// This is used in the search loop to determine if the model wants another search round.
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
// Track tool use state across chunks
var currentToolName string
var currentToolIndex int = -1
var toolInputBuilder strings.Builder
for _, chunk := range chunks {
chunkStr := string(chunk)
lines := strings.Split(chunkStr, "\n")
for _, line := range lines {
if !strings.HasPrefix(line, "data: ") {
continue
}
dataPayload := strings.TrimPrefix(line, "data: ")
dataPayload = strings.TrimSpace(dataPayload)
if dataPayload == "[DONE]" || dataPayload == "" {
continue
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
continue
}
eventType, _ := event["type"].(string)
switch eventType {
case "message_delta":
// Extract stop_reason from message_delta
if delta, ok := event["delta"].(map[string]interface{}); ok {
if sr, ok := delta["stop_reason"].(string); ok && sr != "" {
result.StopReason = sr
}
}
case "content_block_start":
// Detect tool_use content blocks
if cb, ok := event["content_block"].(map[string]interface{}); ok {
if cbType, ok := cb["type"].(string); ok && cbType == "tool_use" {
if name, ok := cb["name"].(string); ok {
currentToolName = strings.ToLower(name)
if idx, ok := event["index"].(float64); ok {
currentToolIndex = int(idx)
}
// Capture tool use ID for toolResults handshake
if id, ok := cb["id"].(string); ok {
result.WebSearchToolUseId = id
}
toolInputBuilder.Reset()
}
}
}
case "content_block_delta":
// Accumulate tool input JSON
if currentToolName != "" {
if delta, ok := event["delta"].(map[string]interface{}); ok {
if deltaType, ok := delta["type"].(string); ok && deltaType == "input_json_delta" {
if partial, ok := delta["partial_json"].(string); ok {
toolInputBuilder.WriteString(partial)
}
}
}
}
case "content_block_stop":
// Finalize tool use detection
if currentToolName == "web_search" || currentToolName == "websearch" || currentToolName == "remote_web_search" {
result.HasWebSearchToolUse = true
result.WebSearchToolUseIndex = currentToolIndex
// Extract query from accumulated input JSON
inputJSON := toolInputBuilder.String()
var input map[string]string
if err := json.Unmarshal([]byte(inputJSON), &input); err == nil {
if q, ok := input["query"]; ok {
result.WebSearchQuery = q
}
}
log.Debugf("kiro/websearch: detected web_search tool_use")
}
currentToolName = ""
currentToolIndex = -1
toolInputBuilder.Reset()
}
}
}
return result
}
// FilterChunksForClient processes buffered SSE chunks and removes web_search tool_use
// content blocks. This prevents the client from seeing "Tool use" prompts for web_search
// when the proxy is handling the search loop internally.
// Also suppresses message_start and message_delta/message_stop events since those
// are managed by the outer handleWebSearchStream.
func FilterChunksForClient(chunks [][]byte, wsToolIndex int, indexOffset int) [][]byte {
var filtered [][]byte
for _, chunk := range chunks {
chunkStr := string(chunk)
lines := strings.Split(chunkStr, "\n")
var resultBuilder strings.Builder
hasContent := false
for i := 0; i < len(lines); i++ {
line := lines[i]
if strings.HasPrefix(line, "data: ") {
dataPayload := strings.TrimPrefix(line, "data: ")
dataPayload = strings.TrimSpace(dataPayload)
if dataPayload == "[DONE]" {
// Skip [DONE] — the outer loop manages stream termination
continue
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
resultBuilder.WriteString(line + "\n")
hasContent = true
continue
}
eventType, _ := event["type"].(string)
// Skip message_start (outer loop sends its own)
if eventType == "message_start" {
continue
}
// Skip message_delta and message_stop (outer loop manages these)
if eventType == "message_delta" || eventType == "message_stop" {
continue
}
// Check if this event belongs to the web_search tool_use block
if wsToolIndex >= 0 {
if idx, ok := event["index"].(float64); ok && int(idx) == wsToolIndex {
// Skip events for the web_search tool_use block
continue
}
}
// Apply index offset for remaining events
if indexOffset > 0 {
switch eventType {
case "content_block_start", "content_block_delta", "content_block_stop":
if idx, ok := event["index"].(float64); ok {
event["index"] = int(idx) + indexOffset
adjusted, err := json.Marshal(event)
if err == nil {
resultBuilder.WriteString("data: " + string(adjusted) + "\n")
hasContent = true
continue
}
}
}
}
resultBuilder.WriteString(line + "\n")
hasContent = true
} else if strings.HasPrefix(line, "event: ") {
// Check if the next data line will be suppressed
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
nextData := strings.TrimPrefix(lines[i+1], "data: ")
nextData = strings.TrimSpace(nextData)
var nextEvent map[string]interface{}
if err := json.Unmarshal([]byte(nextData), &nextEvent); err == nil {
nextType, _ := nextEvent["type"].(string)
if nextType == "message_start" || nextType == "message_delta" || nextType == "message_stop" {
i++ // skip the data line
continue
}
if wsToolIndex >= 0 {
if idx, ok := nextEvent["index"].(float64); ok && int(idx) == wsToolIndex {
i++ // skip the data line
continue
}
}
}
}
resultBuilder.WriteString(line + "\n")
hasContent = true
} else {
resultBuilder.WriteString(line + "\n")
if strings.TrimSpace(line) != "" {
hasContent = true
}
}
}
if hasContent {
filtered = append(filtered, []byte(resultBuilder.String()))
}
}
return filtered
}

View File

@@ -1,11 +1,14 @@
// Package claude provides web search functionality for Kiro translator.
// This file implements detection and MCP request/response types for web search.
// This file implements detection, MCP request/response types, and pure data
// transformation utilities for web search. SSE event generation, stream analysis,
// and HTTP I/O logic reside in the executor package (kiro_executor.go).
package claude
import (
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/google/uuid"
@@ -14,6 +17,26 @@ import (
"github.com/tidwall/sjson"
)
// cachedToolDescription stores the dynamically-fetched web_search tool description.
// Written by the executor via SetWebSearchDescription, read by the translator
// when building the remote_web_search tool for Kiro API requests.
var cachedToolDescription atomic.Value // stores string
// GetWebSearchDescription returns the cached web_search tool description,
// or empty string if not yet fetched. Lock-free via atomic.Value.
func GetWebSearchDescription() string {
if v := cachedToolDescription.Load(); v != nil {
return v.(string)
}
return ""
}
// SetWebSearchDescription stores the dynamically-fetched web_search tool description.
// Called by the executor after fetching from MCP tools/list.
func SetWebSearchDescription(desc string) {
cachedToolDescription.Store(desc)
}
// McpRequest represents a JSON-RPC 2.0 request to Kiro MCP API
type McpRequest struct {
ID string `json:"id"`
@@ -191,36 +214,11 @@ func CreateMcpRequest(query string) (string, *McpRequest) {
return toolUseID, request
}
// GenerateMessageID generates a Claude-style message ID
func GenerateMessageID() string {
return "msg_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:24]
}
// GenerateToolUseID generates a Kiro-style tool use ID (base62-like UUID)
func GenerateToolUseID() string {
return strings.ReplaceAll(uuid.New().String(), "-", "")[:22]
}
// ContainsWebSearchTool checks if the request contains a web_search tool (among any tools).
// Unlike HasWebSearchTool, this detects web_search even in mixed-tool arrays.
func ContainsWebSearchTool(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if !tools.IsArray() {
return false
}
for _, tool := range tools.Array() {
name := strings.ToLower(tool.Get("name").String())
toolType := strings.ToLower(tool.Get("type").String())
if isWebSearchTool(name, toolType) {
return true
}
}
return false
}
// ReplaceWebSearchToolDescription replaces the web_search tool description with
// a minimal version that allows re-search without the restrictive "do not search
// non-coding topics" instruction from the original Kiro tools/list response.
@@ -275,48 +273,6 @@ func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) {
return result, nil
}
// StripWebSearchTool removes web_search tool entries from the request's tools array.
// If the tools array becomes empty after removal, it is removed entirely.
func StripWebSearchTool(body []byte) ([]byte, error) {
tools := gjson.GetBytes(body, "tools")
if !tools.IsArray() {
return body, nil
}
var filtered []json.RawMessage
for _, tool := range tools.Array() {
name := strings.ToLower(tool.Get("name").String())
toolType := strings.ToLower(tool.Get("type").String())
if !isWebSearchTool(name, toolType) {
filtered = append(filtered, json.RawMessage(tool.Raw))
}
}
var result []byte
var err error
if len(filtered) == 0 {
// Remove tools array entirely
result, err = sjson.DeleteBytes(body, "tools")
if err != nil {
return body, fmt.Errorf("failed to delete tools: %w", err)
}
} else {
// Replace with filtered array
filteredJSON, marshalErr := json.Marshal(filtered)
if marshalErr != nil {
return body, fmt.Errorf("failed to marshal filtered tools: %w", marshalErr)
}
result, err = sjson.SetRawBytes(body, "tools", filteredJSON)
if err != nil {
return body, fmt.Errorf("failed to set filtered tools: %w", err)
}
}
return result, nil
}
// FormatSearchContextPrompt formats search results as a structured text block
// for injection into the system prompt.
func FormatSearchContextPrompt(query string, results *WebSearchResults) string {
@@ -365,7 +321,7 @@ func FormatToolResultText(results *WebSearchResults) string {
//
// This produces the exact same GAR request format as the Kiro IDE (HAR captures).
// IMPORTANT: The web_search tool must remain in the "tools" array for this to work.
// Use ReplaceWebSearchToolDescription (not StripWebSearchTool) to keep the tool available.
// Use ReplaceWebSearchToolDescription to keep the tool available with a minimal description.
func InjectToolResultsClaude(claudePayload []byte, toolUseId, query string, results *WebSearchResults) ([]byte, error) {
var payload map[string]interface{}
if err := json.Unmarshal(claudePayload, &payload); err != nil {
@@ -432,8 +388,8 @@ Do NOT apologize for bad results without first attempting a re-search.
return claudePayload, fmt.Errorf("failed to marshal updated payload: %w", err)
}
log.Infof("kiro/websearch: injected tool_use+tool_result (toolUseId=%s, query=%s, messages=%d)",
toolUseId, query, len(messages))
log.Infof("kiro/websearch: injected tool_use+tool_result (toolUseId=%s, messages=%d)",
toolUseId, len(messages))
return result, nil
}
@@ -512,658 +468,28 @@ type SearchIndicator struct {
Results *WebSearchResults
}
// ══════════════════════════════════════════════════════════════════════════════
// SSE Event Generation
// ══════════════════════════════════════════════════════════════════════════════
// SseEvent represents a Server-Sent Event
type SseEvent struct {
Event string
Data interface{}
// BuildMcpEndpoint constructs the MCP endpoint URL for the given AWS region.
// Centralizes the URL pattern used by both handleWebSearch and handleWebSearchStream.
func BuildMcpEndpoint(region string) string {
return fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
}
// ToSSEString converts the event to SSE wire format
func (e *SseEvent) ToSSEString() string {
dataBytes, _ := json.Marshal(e.Data)
return fmt.Sprintf("event: %s\ndata: %s\n\n", e.Event, string(dataBytes))
}
// GenerateWebSearchEvents generates the 11-event SSE sequence for web search.
// Events: message_start, content_block_start(server_tool_use), content_block_delta(input_json),
// content_block_stop, content_block_start(web_search_tool_result), content_block_stop,
// content_block_start(text), content_block_delta(text), content_block_stop, message_delta, message_stop
func GenerateWebSearchEvents(
model string,
query string,
toolUseID string,
searchResults *WebSearchResults,
inputTokens int,
) []SseEvent {
events := make([]SseEvent, 0, 15)
messageID := GenerateMessageID()
// 1. message_start
events = append(events, SseEvent{
Event: "message_start",
Data: map[string]interface{}{
"type": "message_start",
"message": map[string]interface{}{
"id": messageID,
"type": "message",
"role": "assistant",
"model": model,
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": inputTokens,
"output_tokens": 0,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
},
},
},
})
// 2. content_block_start (server_tool_use)
events = append(events, SseEvent{
Event: "content_block_start",
Data: map[string]interface{}{
"type": "content_block_start",
"index": 0,
"content_block": map[string]interface{}{
"id": toolUseID,
"type": "server_tool_use",
"name": "web_search",
"input": map[string]interface{}{},
},
},
})
// 3. content_block_delta (input_json_delta)
inputJSON, _ := json.Marshal(map[string]string{"query": query})
events = append(events, SseEvent{
Event: "content_block_delta",
Data: map[string]interface{}{
"type": "content_block_delta",
"index": 0,
"delta": map[string]interface{}{
"type": "input_json_delta",
"partial_json": string(inputJSON),
},
},
})
// 4. content_block_stop (server_tool_use)
events = append(events, SseEvent{
Event: "content_block_stop",
Data: map[string]interface{}{
"type": "content_block_stop",
"index": 0,
},
})
// 5. content_block_start (web_search_tool_result)
searchContent := make([]map[string]interface{}, 0)
if searchResults != nil {
for _, r := range searchResults.Results {
snippet := ""
if r.Snippet != nil {
snippet = *r.Snippet
}
searchContent = append(searchContent, map[string]interface{}{
"type": "web_search_result",
"title": r.Title,
"url": r.URL,
"encrypted_content": snippet,
"page_age": nil,
})
}
}
events = append(events, SseEvent{
Event: "content_block_start",
Data: map[string]interface{}{
"type": "content_block_start",
"index": 1,
"content_block": map[string]interface{}{
"type": "web_search_tool_result",
"tool_use_id": toolUseID,
"content": searchContent,
},
},
})
// 6. content_block_stop (web_search_tool_result)
events = append(events, SseEvent{
Event: "content_block_stop",
Data: map[string]interface{}{
"type": "content_block_stop",
"index": 1,
},
})
// 7. content_block_start (text)
events = append(events, SseEvent{
Event: "content_block_start",
Data: map[string]interface{}{
"type": "content_block_start",
"index": 2,
"content_block": map[string]interface{}{
"type": "text",
"text": "",
},
},
})
// 8. content_block_delta (text_delta) - generate search summary
summary := generateSearchSummary(query, searchResults)
// Split text into chunks for streaming effect
chunkSize := 100
runes := []rune(summary)
for i := 0; i < len(runes); i += chunkSize {
end := i + chunkSize
if end > len(runes) {
end = len(runes)
}
chunk := string(runes[i:end])
events = append(events, SseEvent{
Event: "content_block_delta",
Data: map[string]interface{}{
"type": "content_block_delta",
"index": 2,
"delta": map[string]interface{}{
"type": "text_delta",
"text": chunk,
},
},
})
}
// 9. content_block_stop (text)
events = append(events, SseEvent{
Event: "content_block_stop",
Data: map[string]interface{}{
"type": "content_block_stop",
"index": 2,
},
})
// 10. message_delta
outputTokens := (len(summary) + 3) / 4 // Simple estimation
events = append(events, SseEvent{
Event: "message_delta",
Data: map[string]interface{}{
"type": "message_delta",
"delta": map[string]interface{}{
"stop_reason": "end_turn",
"stop_sequence": nil,
},
"usage": map[string]interface{}{
"output_tokens": outputTokens,
},
},
})
// 11. message_stop
events = append(events, SseEvent{
Event: "message_stop",
Data: map[string]interface{}{
"type": "message_stop",
},
})
return events
}
// generateSearchSummary generates a text summary of search results
func generateSearchSummary(query string, results *WebSearchResults) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("Here are the search results for \"%s\":\n\n", query))
if results != nil && len(results.Results) > 0 {
for i, r := range results.Results {
sb.WriteString(fmt.Sprintf("%d. **%s**\n", i+1, r.Title))
if r.Snippet != nil {
snippet := *r.Snippet
if len(snippet) > 200 {
snippet = snippet[:200] + "..."
}
sb.WriteString(fmt.Sprintf(" %s\n", snippet))
}
sb.WriteString(fmt.Sprintf(" Source: %s\n\n", r.URL))
}
} else {
sb.WriteString("No results found.\n")
}
sb.WriteString("\nPlease note that these are web search results and may not be fully accurate or up-to-date.")
return sb.String()
}
// GenerateSearchIndicatorEvents generates ONLY the search indicator SSE events
// (server_tool_use + web_search_tool_result) without text summary or message termination.
// These events trigger Claude Code's search indicator UI.
// The caller is responsible for sending message_start before and message_delta/stop after.
func GenerateSearchIndicatorEvents(
query string,
toolUseID string,
searchResults *WebSearchResults,
startIndex int,
) []SseEvent {
events := make([]SseEvent, 0, 4)
// 1. content_block_start (server_tool_use)
events = append(events, SseEvent{
Event: "content_block_start",
Data: map[string]interface{}{
"type": "content_block_start",
"index": startIndex,
"content_block": map[string]interface{}{
"id": toolUseID,
"type": "server_tool_use",
"name": "web_search",
"input": map[string]interface{}{},
},
},
})
// 2. content_block_delta (input_json_delta)
inputJSON, _ := json.Marshal(map[string]string{"query": query})
events = append(events, SseEvent{
Event: "content_block_delta",
Data: map[string]interface{}{
"type": "content_block_delta",
"index": startIndex,
"delta": map[string]interface{}{
"type": "input_json_delta",
"partial_json": string(inputJSON),
},
},
})
// 3. content_block_stop (server_tool_use)
events = append(events, SseEvent{
Event: "content_block_stop",
Data: map[string]interface{}{
"type": "content_block_stop",
"index": startIndex,
},
})
// 4. content_block_start (web_search_tool_result)
searchContent := make([]map[string]interface{}, 0)
if searchResults != nil {
for _, r := range searchResults.Results {
snippet := ""
if r.Snippet != nil {
snippet = *r.Snippet
}
searchContent = append(searchContent, map[string]interface{}{
"type": "web_search_result",
"title": r.Title,
"url": r.URL,
"encrypted_content": snippet,
"page_age": nil,
})
}
}
events = append(events, SseEvent{
Event: "content_block_start",
Data: map[string]interface{}{
"type": "content_block_start",
"index": startIndex + 1,
"content_block": map[string]interface{}{
"type": "web_search_tool_result",
"tool_use_id": toolUseID,
"content": searchContent,
},
},
})
// 5. content_block_stop (web_search_tool_result)
events = append(events, SseEvent{
Event: "content_block_stop",
Data: map[string]interface{}{
"type": "content_block_stop",
"index": startIndex + 1,
},
})
return events
}
// ══════════════════════════════════════════════════════════════════════════════
// Stream Analysis & Manipulation
// ══════════════════════════════════════════════════════════════════════════════
// AdjustStreamIndices adjusts content block indices in SSE event data by adding an offset.
// It also suppresses duplicate message_start events (returns shouldForward=false).
// This is used to combine search indicator events (indices 0,1) with Kiro model response events.
//
// The data parameter is a single SSE "data:" line payload (JSON).
// Returns: adjusted data, shouldForward (false = skip this event).
func AdjustStreamIndices(data []byte, offset int) ([]byte, bool) {
if len(data) == 0 {
return data, true
}
// Quick check: parse the JSON
var event map[string]interface{}
if err := json.Unmarshal(data, &event); err != nil {
// Not valid JSON, pass through
return data, true
}
eventType, _ := event["type"].(string)
// Suppress duplicate message_start events
if eventType == "message_start" {
return data, false
}
// Adjust index for content_block events
switch eventType {
case "content_block_start", "content_block_delta", "content_block_stop":
if idx, ok := event["index"].(float64); ok {
event["index"] = int(idx) + offset
adjusted, err := json.Marshal(event)
if err != nil {
return data, true
}
return adjusted, true
}
}
// Pass through all other events unchanged (message_delta, message_stop, ping, etc.)
return data, true
}
// AdjustSSEChunk processes a raw SSE chunk (potentially containing multiple "event:/data:" pairs)
// and adjusts content block indices. Suppresses duplicate message_start events.
// Returns the adjusted chunk and whether it should be forwarded.
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
chunkStr := string(chunk)
// Fast path: if no "data:" prefix, pass through
if !strings.Contains(chunkStr, "data: ") {
return chunk, true
}
var result strings.Builder
hasContent := false
lines := strings.Split(chunkStr, "\n")
for i := 0; i < len(lines); i++ {
line := lines[i]
if strings.HasPrefix(line, "data: ") {
dataPayload := strings.TrimPrefix(line, "data: ")
dataPayload = strings.TrimSpace(dataPayload)
if dataPayload == "[DONE]" {
result.WriteString(line + "\n")
hasContent = true
continue
}
adjusted, shouldForward := AdjustStreamIndices([]byte(dataPayload), offset)
if !shouldForward {
// Skip this event and its preceding "event:" line
// Also skip the trailing empty line
continue
}
result.WriteString("data: " + string(adjusted) + "\n")
hasContent = true
} else if strings.HasPrefix(line, "event: ") {
// Check if the next data line will be suppressed
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
dataPayload := strings.TrimPrefix(lines[i+1], "data: ")
dataPayload = strings.TrimSpace(dataPayload)
var event map[string]interface{}
if err := json.Unmarshal([]byte(dataPayload), &event); err == nil {
if eventType, ok := event["type"].(string); ok && eventType == "message_start" {
// Skip both the event: and data: lines
i++ // skip the data: line too
continue
}
}
}
result.WriteString(line + "\n")
hasContent = true
} else {
result.WriteString(line + "\n")
if strings.TrimSpace(line) != "" {
hasContent = true
}
}
}
if !hasContent {
return nil, false
}
return []byte(result.String()), true
}
// BufferedStreamResult contains the analysis of buffered SSE chunks from a Kiro API response.
type BufferedStreamResult struct {
// StopReason is the detected stop_reason from the stream (e.g., "end_turn", "tool_use")
StopReason string
// WebSearchQuery is the extracted query if the model requested another web_search
WebSearchQuery string
// WebSearchToolUseId is the tool_use ID from the model's response (needed for toolResults)
WebSearchToolUseId string
// HasWebSearchToolUse indicates whether the model requested web_search
HasWebSearchToolUse bool
// WebSearchToolUseIndex is the content_block index of the web_search tool_use
WebSearchToolUseIndex int
}
// AnalyzeBufferedStream scans buffered SSE chunks to detect stop_reason and web_search tool_use.
// This is used in the search loop to determine if the model wants another search round.
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
// Track tool use state across chunks
var currentToolName string
var currentToolIndex int = -1
var toolInputBuilder strings.Builder
for _, chunk := range chunks {
chunkStr := string(chunk)
lines := strings.Split(chunkStr, "\n")
for _, line := range lines {
if !strings.HasPrefix(line, "data: ") {
continue
}
dataPayload := strings.TrimPrefix(line, "data: ")
dataPayload = strings.TrimSpace(dataPayload)
if dataPayload == "[DONE]" || dataPayload == "" {
continue
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
continue
}
eventType, _ := event["type"].(string)
switch eventType {
case "message_delta":
// Extract stop_reason from message_delta
if delta, ok := event["delta"].(map[string]interface{}); ok {
if sr, ok := delta["stop_reason"].(string); ok && sr != "" {
result.StopReason = sr
}
}
case "content_block_start":
// Detect tool_use content blocks
if cb, ok := event["content_block"].(map[string]interface{}); ok {
if cbType, ok := cb["type"].(string); ok && cbType == "tool_use" {
if name, ok := cb["name"].(string); ok {
currentToolName = strings.ToLower(name)
if idx, ok := event["index"].(float64); ok {
currentToolIndex = int(idx)
}
// Capture tool use ID for toolResults handshake
if id, ok := cb["id"].(string); ok {
result.WebSearchToolUseId = id
}
toolInputBuilder.Reset()
}
}
}
case "content_block_delta":
// Accumulate tool input JSON
if currentToolName != "" {
if delta, ok := event["delta"].(map[string]interface{}); ok {
if deltaType, ok := delta["type"].(string); ok && deltaType == "input_json_delta" {
if partial, ok := delta["partial_json"].(string); ok {
toolInputBuilder.WriteString(partial)
}
}
}
}
case "content_block_stop":
// Finalize tool use detection
if currentToolName == "web_search" || currentToolName == "websearch" || currentToolName == "remote_web_search" {
result.HasWebSearchToolUse = true
result.WebSearchToolUseIndex = currentToolIndex
// Extract query from accumulated input JSON
inputJSON := toolInputBuilder.String()
var input map[string]string
if err := json.Unmarshal([]byte(inputJSON), &input); err == nil {
if q, ok := input["query"]; ok {
result.WebSearchQuery = q
}
}
log.Debugf("kiro/websearch: detected web_search tool_use, query: %s", result.WebSearchQuery)
}
currentToolName = ""
currentToolIndex = -1
toolInputBuilder.Reset()
}
}
}
return result
}
// FilterChunksForClient processes buffered SSE chunks and removes web_search tool_use
// content blocks. This prevents the client from seeing "Tool use" prompts for web_search
// when the proxy is handling the search loop internally.
// Also suppresses message_start and message_delta/message_stop events since those
// are managed by the outer handleWebSearchStream.
func FilterChunksForClient(chunks [][]byte, wsToolIndex int, indexOffset int) [][]byte {
var filtered [][]byte
for _, chunk := range chunks {
chunkStr := string(chunk)
lines := strings.Split(chunkStr, "\n")
var resultBuilder strings.Builder
hasContent := false
for i := 0; i < len(lines); i++ {
line := lines[i]
if strings.HasPrefix(line, "data: ") {
dataPayload := strings.TrimPrefix(line, "data: ")
dataPayload = strings.TrimSpace(dataPayload)
if dataPayload == "[DONE]" {
// Skip [DONE] — the outer loop manages stream termination
continue
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
resultBuilder.WriteString(line + "\n")
hasContent = true
continue
}
eventType, _ := event["type"].(string)
// Skip message_start (outer loop sends its own)
if eventType == "message_start" {
continue
}
// Skip message_delta and message_stop (outer loop manages these)
if eventType == "message_delta" || eventType == "message_stop" {
continue
}
// Check if this event belongs to the web_search tool_use block
if wsToolIndex >= 0 {
if idx, ok := event["index"].(float64); ok && int(idx) == wsToolIndex {
// Skip events for the web_search tool_use block
continue
}
}
// Apply index offset for remaining events
if indexOffset > 0 {
switch eventType {
case "content_block_start", "content_block_delta", "content_block_stop":
if idx, ok := event["index"].(float64); ok {
event["index"] = int(idx) + indexOffset
adjusted, err := json.Marshal(event)
if err == nil {
resultBuilder.WriteString("data: " + string(adjusted) + "\n")
hasContent = true
continue
}
}
}
}
resultBuilder.WriteString(line + "\n")
hasContent = true
} else if strings.HasPrefix(line, "event: ") {
// Check if the next data line will be suppressed
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
nextData := strings.TrimPrefix(lines[i+1], "data: ")
nextData = strings.TrimSpace(nextData)
var nextEvent map[string]interface{}
if err := json.Unmarshal([]byte(nextData), &nextEvent); err == nil {
nextType, _ := nextEvent["type"].(string)
if nextType == "message_start" || nextType == "message_delta" || nextType == "message_stop" {
i++ // skip the data line
continue
}
if wsToolIndex >= 0 {
if idx, ok := nextEvent["index"].(float64); ok && int(idx) == wsToolIndex {
i++ // skip the data line
continue
}
}
}
}
resultBuilder.WriteString(line + "\n")
hasContent = true
} else {
resultBuilder.WriteString(line + "\n")
if strings.TrimSpace(line) != "" {
hasContent = true
}
}
}
if hasContent {
filtered = append(filtered, []byte(resultBuilder.String()))
}
}
return filtered
// ParseSearchResults extracts WebSearchResults from MCP response
func ParseSearchResults(response *McpResponse) *WebSearchResults {
if response == nil || response.Result == nil || len(response.Result.Content) == 0 {
return nil
}
content := response.Result.Content[0]
if content.ContentType != "text" {
return nil
}
var results WebSearchResults
if err := json.Unmarshal([]byte(content.Text), &results); err != nil {
log.Warnf("kiro/websearch: failed to parse search results: %v", err)
return nil
}
return &results
}

View File

@@ -1,270 +0,0 @@
// Package claude provides web search handler for Kiro translator.
// This file implements the MCP API call and response handling.
package claude
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// Cached web_search tool description fetched from MCP tools/list.
// Uses atomic.Pointer[sync.Once] for lock-free reads with retry-on-failure:
// - sync.Once prevents race conditions and deduplicates concurrent calls
// - On failure, a fresh sync.Once is swapped in to allow retry on next call
// - On success, sync.Once stays "done" forever — zero overhead for subsequent calls
var (
cachedToolDescription atomic.Value // stores string
toolDescOnce atomic.Pointer[sync.Once]
fallbackFpOnce sync.Once
fallbackFp *kiroauth.Fingerprint
)
func init() {
toolDescOnce.Store(&sync.Once{})
}
// FetchToolDescription calls MCP tools/list to get the web_search tool description
// and caches it. Safe to call concurrently — only one goroutine fetches at a time.
// If the fetch fails, subsequent calls will retry. On success, no further fetches occur.
// The httpClient parameter allows reusing a shared pooled HTTP client.
func FetchToolDescription(mcpEndpoint, authToken string, httpClient *http.Client, fp *kiroauth.Fingerprint, authAttrs map[string]string) {
toolDescOnce.Load().Do(func() {
handler := NewWebSearchHandler(mcpEndpoint, authToken, httpClient, fp, authAttrs)
reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`)
log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody))
req, err := http.NewRequest("POST", mcpEndpoint, bytes.NewReader(reqBody))
if err != nil {
log.Warnf("kiro/websearch: failed to create tools/list request: %v", err)
toolDescOnce.Store(&sync.Once{}) // allow retry
return
}
// Reuse same headers as CallMcpAPI
handler.setMcpHeaders(req)
resp, err := handler.HTTPClient.Do(req)
if err != nil {
log.Warnf("kiro/websearch: tools/list request failed: %v", err)
toolDescOnce.Store(&sync.Once{}) // allow retry
return
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil || resp.StatusCode != http.StatusOK {
log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode)
toolDescOnce.Store(&sync.Once{}) // allow retry
return
}
log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body))
// Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}}
var result struct {
Result *struct {
Tools []struct {
Name string `json:"name"`
Description string `json:"description"`
} `json:"tools"`
} `json:"result"`
}
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
log.Warnf("kiro/websearch: failed to parse tools/list response")
toolDescOnce.Store(&sync.Once{}) // allow retry
return
}
for _, tool := range result.Result.Tools {
if tool.Name == "web_search" && tool.Description != "" {
cachedToolDescription.Store(tool.Description)
log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description))
return // success — sync.Once stays "done", no more fetches
}
}
// web_search tool not found in response
toolDescOnce.Store(&sync.Once{}) // allow retry
})
}
// GetWebSearchDescription returns the cached web_search tool description,
// or empty string if not yet fetched. Lock-free via atomic.Value.
func GetWebSearchDescription() string {
if v := cachedToolDescription.Load(); v != nil {
return v.(string)
}
return ""
}
// WebSearchHandler handles web search requests via Kiro MCP API
type WebSearchHandler struct {
McpEndpoint string
HTTPClient *http.Client
AuthToken string
Fingerprint *kiroauth.Fingerprint // optional, for dynamic headers
AuthAttrs map[string]string // optional, for custom headers from auth.Attributes
}
// NewWebSearchHandler creates a new WebSearchHandler.
// If httpClient is nil, a default client with 30s timeout is used.
// If fingerprint is nil, a random one-off fingerprint is generated.
// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse.
func NewWebSearchHandler(mcpEndpoint, authToken string, httpClient *http.Client, fp *kiroauth.Fingerprint, authAttrs map[string]string) *WebSearchHandler {
if httpClient == nil {
httpClient = &http.Client{
Timeout: 30 * time.Second,
}
}
if fp == nil {
// Use a shared fallback fingerprint for callers without token context
fallbackFpOnce.Do(func() {
mgr := kiroauth.NewFingerprintManager()
fallbackFp = mgr.GetFingerprint("mcp-fallback")
})
fp = fallbackFp
}
return &WebSearchHandler{
McpEndpoint: mcpEndpoint,
HTTPClient: httpClient,
AuthToken: authToken,
Fingerprint: fp,
AuthAttrs: authAttrs,
}
}
// setMcpHeaders sets standard MCP API headers on the request,
// aligned with the GAR request pattern in kiro_executor.go.
func (h *WebSearchHandler) setMcpHeaders(req *http.Request) {
fp := h.Fingerprint
// 1. Content-Type & Accept (aligned with GAR)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "*/*")
// 2. Kiro-specific headers (aligned with GAR)
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
req.Header.Set("x-amzn-codewhisperer-optout", "true")
// 3. Dynamic fingerprint headers
req.Header.Set("User-Agent", fp.BuildUserAgent())
req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent())
// 4. AWS SDK identifiers (casing aligned with GAR)
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
// 5. Authentication
req.Header.Set("Authorization", "Bearer "+h.AuthToken)
// 6. Custom headers from auth attributes
util.ApplyCustomHeadersFromAttrs(req, h.AuthAttrs)
}
// mcpMaxRetries is the maximum number of retries for MCP API calls.
const mcpMaxRetries = 2
// CallMcpAPI calls the Kiro MCP API with the given request.
// Includes retry logic with exponential backoff for retryable errors,
// aligned with the GAR request retry pattern.
func (h *WebSearchHandler) CallMcpAPI(request *McpRequest) (*McpResponse, error) {
requestBody, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal MCP request: %w", err)
}
log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.McpEndpoint, len(requestBody))
var lastErr error
for attempt := 0; attempt <= mcpMaxRetries; attempt++ {
if attempt > 0 {
backoff := time.Duration(1<<attempt) * time.Second
if backoff > 10*time.Second {
backoff = 10 * time.Second
}
log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr)
time.Sleep(backoff)
}
req, err := http.NewRequest("POST", h.McpEndpoint, bytes.NewReader(requestBody))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
h.setMcpHeaders(req)
resp, err := h.HTTPClient.Do(req)
if err != nil {
lastErr = fmt.Errorf("MCP API request failed: %w", err)
continue // network error → retry
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
lastErr = fmt.Errorf("failed to read MCP response: %w", err)
continue // read error → retry
}
log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body))
// Retryable HTTP status codes (aligned with GAR: 502, 503, 504)
if resp.StatusCode >= 502 && resp.StatusCode <= 504 {
lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body))
continue
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body))
}
var mcpResponse McpResponse
if err := json.Unmarshal(body, &mcpResponse); err != nil {
return nil, fmt.Errorf("failed to parse MCP response: %w", err)
}
if mcpResponse.Error != nil {
code := -1
if mcpResponse.Error.Code != nil {
code = *mcpResponse.Error.Code
}
msg := "Unknown error"
if mcpResponse.Error.Message != nil {
msg = *mcpResponse.Error.Message
}
return nil, fmt.Errorf("MCP error %d: %s", code, msg)
}
return &mcpResponse, nil
}
return nil, lastErr
}
// ParseSearchResults extracts WebSearchResults from MCP response
func ParseSearchResults(response *McpResponse) *WebSearchResults {
if response == nil || response.Result == nil || len(response.Result.Content) == 0 {
return nil
}
content := response.Result.Content[0]
if content.ContentType != "text" {
return nil
}
var results WebSearchResults
if err := json.Unmarshal([]byte(content.Text), &results); err != nil {
log.Warnf("kiro/websearch: failed to parse search results: %v", err)
return nil
}
return &results
}

View File

@@ -53,19 +53,25 @@ var KnownCommandTools = map[string]bool{
"execute_python": true,
}
// RequiredFieldsByTool maps tool names to their required fields.
// If any of these fields are missing, the tool input is considered truncated.
var RequiredFieldsByTool = map[string][]string{
"Write": {"file_path", "content"},
"write_to_file": {"path", "content"},
"fsWrite": {"path", "content"},
"create_file": {"path", "content"},
"edit_file": {"path"},
"apply_diff": {"path", "diff"},
"str_replace_editor": {"path", "old_str", "new_str"},
"Bash": {"command"},
"execute": {"command"},
"run_command": {"command"},
// RequiredFieldsByTool maps tool names to their required field groups.
// Each outer element is a required group; each inner slice lists alternative field names (OR logic).
// A group is satisfied when ANY one of its alternatives exists in the parsed input.
// All groups must be satisfied for the tool input to be considered valid.
//
// Example:
// {{"cmd", "command"}} means the tool needs EITHER "cmd" OR "command".
// {{"file_path"}, {"content"}} means the tool needs BOTH "file_path" AND "content".
var RequiredFieldsByTool = map[string][][]string{
"Write": {{"file_path"}, {"content"}},
"write_to_file": {{"path"}, {"content"}},
"fsWrite": {{"path"}, {"content"}},
"create_file": {{"path"}, {"content"}},
"edit_file": {{"path"}},
"apply_diff": {{"path"}, {"diff"}},
"str_replace_editor": {{"path"}, {"old_str"}, {"new_str"}},
"Bash": {{"cmd", "command"}},
"execute": {{"command"}},
"run_command": {{"command"}},
}
// DetectTruncation checks if the tool use input appears to be truncated.
@@ -104,9 +110,9 @@ func DetectTruncation(toolName, toolUseID, rawInput string, parsedInput map[stri
// Scenario 3: JSON parsed but critical fields are missing
if parsedInput != nil {
requiredFields, hasRequirements := RequiredFieldsByTool[toolName]
requiredGroups, hasRequirements := RequiredFieldsByTool[toolName]
if hasRequirements {
missingFields := findMissingRequiredFields(parsedInput, requiredFields)
missingFields := findMissingRequiredFields(parsedInput, requiredGroups)
if len(missingFields) > 0 {
info.IsTruncated = true
info.TruncationType = TruncationTypeMissingFields
@@ -253,12 +259,21 @@ func extractParsedFieldNames(parsed map[string]interface{}) map[string]string {
return fields
}
// findMissingRequiredFields checks which required fields are missing from the parsed input.
func findMissingRequiredFields(parsed map[string]interface{}, required []string) []string {
// findMissingRequiredFields checks which required field groups are unsatisfied.
// Each group is a slice of alternative field names; the group is satisfied when ANY alternative exists.
// Returns the list of unsatisfied groups (represented by their alternatives joined with "/").
func findMissingRequiredFields(parsed map[string]interface{}, requiredGroups [][]string) []string {
var missing []string
for _, field := range required {
if _, exists := parsed[field]; !exists {
missing = append(missing, field)
for _, group := range requiredGroups {
satisfied := false
for _, field := range group {
if _, exists := parsed[field]; exists {
satisfied = true
break
}
}
if !satisfied {
missing = append(missing, strings.Join(group, "/"))
}
}
return missing

View File

@@ -36,8 +36,14 @@ func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
if currentRole == lastRole {
// Merge content from current message into last message
mergedContent := mergeMessageContent(lastMsg, msg)
// Create a new merged message JSON
mergedMsg := createMergedMessage(lastRole, mergedContent)
var mergedToolCalls []interface{}
if currentRole == "assistant" {
// Preserve assistant tool_calls when adjacent assistant messages are merged.
mergedToolCalls = mergeToolCalls(lastMsg.Get("tool_calls"), msg.Get("tool_calls"))
}
// Create a new merged message JSON.
mergedMsg := createMergedMessage(lastRole, mergedContent, mergedToolCalls)
merged[len(merged)-1] = gjson.Parse(mergedMsg)
} else {
merged = append(merged, msg)
@@ -121,12 +127,34 @@ func blockToMap(block gjson.Result) map[string]interface{} {
return result
}
// createMergedMessage creates a JSON string for a merged message
func createMergedMessage(role string, content string) string {
// createMergedMessage creates a JSON string for a merged message.
// toolCalls is optional and only emitted for assistant role.
func createMergedMessage(role string, content string, toolCalls []interface{}) string {
msg := map[string]interface{}{
"role": role,
"content": json.RawMessage(content),
}
if role == "assistant" && len(toolCalls) > 0 {
msg["tool_calls"] = toolCalls
}
result, _ := json.Marshal(msg)
return string(result)
}
}
// mergeToolCalls combines tool_calls from two assistant messages while preserving order.
func mergeToolCalls(tc1, tc2 gjson.Result) []interface{} {
var merged []interface{}
if tc1.IsArray() {
for _, tc := range tc1.Array() {
merged = append(merged, tc.Value())
}
}
if tc2.IsArray() {
for _, tc := range tc2.Array() {
merged = append(merged, tc.Value())
}
}
return merged
}

View File

@@ -0,0 +1,106 @@
package common
import (
"strings"
"testing"
"github.com/tidwall/gjson"
)
func parseMessages(t *testing.T, raw string) []gjson.Result {
t.Helper()
parsed := gjson.Parse(raw)
if !parsed.IsArray() {
t.Fatalf("expected JSON array, got: %s", raw)
}
return parsed.Array()
}
func TestMergeAdjacentMessages_AssistantMergePreservesToolCalls(t *testing.T) {
messages := parseMessages(t, `[
{"role":"assistant","content":"part1"},
{
"role":"assistant",
"content":"part2",
"tool_calls":[
{
"id":"call_1",
"type":"function",
"function":{"name":"Read","arguments":"{}"}
}
]
},
{"role":"tool","tool_call_id":"call_1","content":"ok"}
]`)
merged := MergeAdjacentMessages(messages)
if len(merged) != 2 {
t.Fatalf("expected 2 messages after merge, got %d", len(merged))
}
assistant := merged[0]
if assistant.Get("role").String() != "assistant" {
t.Fatalf("expected first message role assistant, got %q", assistant.Get("role").String())
}
toolCalls := assistant.Get("tool_calls")
if !toolCalls.IsArray() || len(toolCalls.Array()) != 1 {
t.Fatalf("expected assistant.tool_calls length 1, got: %s", toolCalls.Raw)
}
if toolCalls.Array()[0].Get("id").String() != "call_1" {
t.Fatalf("expected tool call id call_1, got %q", toolCalls.Array()[0].Get("id").String())
}
contentRaw := assistant.Get("content").Raw
if !strings.Contains(contentRaw, "part1") || !strings.Contains(contentRaw, "part2") {
t.Fatalf("expected merged content to contain both parts, got: %s", contentRaw)
}
if merged[1].Get("role").String() != "tool" {
t.Fatalf("expected second message role tool, got %q", merged[1].Get("role").String())
}
}
func TestMergeAdjacentMessages_AssistantMergeCombinesMultipleToolCalls(t *testing.T) {
messages := parseMessages(t, `[
{
"role":"assistant",
"content":"first",
"tool_calls":[
{"id":"call_1","type":"function","function":{"name":"Read","arguments":"{}"}}
]
},
{
"role":"assistant",
"content":"second",
"tool_calls":[
{"id":"call_2","type":"function","function":{"name":"Write","arguments":"{}"}}
]
}
]`)
merged := MergeAdjacentMessages(messages)
if len(merged) != 1 {
t.Fatalf("expected 1 message after merge, got %d", len(merged))
}
toolCalls := merged[0].Get("tool_calls").Array()
if len(toolCalls) != 2 {
t.Fatalf("expected 2 merged tool calls, got %d", len(toolCalls))
}
if toolCalls[0].Get("id").String() != "call_1" || toolCalls[1].Get("id").String() != "call_2" {
t.Fatalf("unexpected merged tool call ids: %q, %q", toolCalls[0].Get("id").String(), toolCalls[1].Get("id").String())
}
}
func TestMergeAdjacentMessages_ToolMessagesRemainUnmerged(t *testing.T) {
messages := parseMessages(t, `[
{"role":"tool","tool_call_id":"call_1","content":"r1"},
{"role":"tool","tool_call_id":"call_2","content":"r2"}
]`)
merged := MergeAdjacentMessages(messages)
if len(merged) != 2 {
t.Fatalf("expected tool messages to remain separate, got %d", len(merged))
}
}

View File

@@ -234,16 +234,16 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s
// Kiro API supports official thinking/reasoning mode via <thinking_mode> tag.
// When set to "enabled", Kiro returns reasoning content as official reasoningContentEvent
// rather than inline <thinking> tags in assistantResponseEvent.
// We use a high max_thinking_length to allow extensive reasoning.
// Use a conservative thinking budget to reduce latency/cost spikes in long sessions.
if thinkingEnabled {
thinkingHint := `<thinking_mode>enabled</thinking_mode>
<max_thinking_length>200000</max_thinking_length>`
<max_thinking_length>16000</max_thinking_length>`
if systemPrompt != "" {
systemPrompt = thinkingHint + "\n\n" + systemPrompt
} else {
systemPrompt = thinkingHint
}
log.Debugf("kiro-openai: injected thinking prompt (official mode)")
log.Infof("kiro-openai: injected thinking prompt (official mode), has_tools: %v", len(kiroTools) > 0)
}
// Process messages and build history
@@ -578,6 +578,7 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
// Truncate history if too long to prevent Kiro API errors
history = truncateHistoryIfNeeded(history)
history, currentToolResults = filterOrphanedToolResults(history, currentToolResults)
return history, currentUserMsg, currentToolResults
}
@@ -593,6 +594,61 @@ func truncateHistoryIfNeeded(history []KiroHistoryMessage) []KiroHistoryMessage
return history[len(history)-kiroMaxHistoryMessages:]
}
func filterOrphanedToolResults(history []KiroHistoryMessage, currentToolResults []KiroToolResult) ([]KiroHistoryMessage, []KiroToolResult) {
// Remove tool results with no matching tool_use in retained history.
// This happens after truncation when the assistant turn that produced tool_use
// is dropped but a later user/tool_result survives.
validToolUseIDs := make(map[string]bool)
for _, h := range history {
if h.AssistantResponseMessage == nil {
continue
}
for _, tu := range h.AssistantResponseMessage.ToolUses {
validToolUseIDs[tu.ToolUseID] = true
}
}
for i, h := range history {
if h.UserInputMessage == nil || h.UserInputMessage.UserInputMessageContext == nil {
continue
}
ctx := h.UserInputMessage.UserInputMessageContext
if len(ctx.ToolResults) == 0 {
continue
}
filtered := make([]KiroToolResult, 0, len(ctx.ToolResults))
for _, tr := range ctx.ToolResults {
if validToolUseIDs[tr.ToolUseID] {
filtered = append(filtered, tr)
continue
}
log.Debugf("kiro-openai: dropping orphaned tool_result in history[%d]: toolUseId=%s (no matching tool_use)", i, tr.ToolUseID)
}
ctx.ToolResults = filtered
if len(ctx.ToolResults) == 0 && len(ctx.Tools) == 0 {
h.UserInputMessage.UserInputMessageContext = nil
}
}
if len(currentToolResults) > 0 {
filtered := make([]KiroToolResult, 0, len(currentToolResults))
for _, tr := range currentToolResults {
if validToolUseIDs[tr.ToolUseID] {
filtered = append(filtered, tr)
continue
}
log.Debugf("kiro-openai: dropping orphaned tool_result in currentMessage: toolUseId=%s (no matching tool_use)", tr.ToolUseID)
}
if len(filtered) != len(currentToolResults) {
log.Infof("kiro-openai: dropped %d orphaned tool_result(s) from currentMessage", len(currentToolResults)-len(filtered))
}
currentToolResults = filtered
}
return history, currentToolResults
}
// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results
func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
content := msg.Get("content")
@@ -831,7 +887,6 @@ func hasThinkingTagInBody(body []byte) bool {
return strings.Contains(bodyStr, "<thinking_mode>") || strings.Contains(bodyStr, "<max_thinking_length>")
}
// extractToolChoiceHint extracts tool_choice from OpenAI request and returns a system prompt hint.
// OpenAI tool_choice values:
// - "none": Don't use any tools

Some files were not shown because too many files have changed in this diff Show More