Compare commits

...

97 Commits

Author SHA1 Message Date
Luis Pater
21d2329947 Merge pull request #261 from router-for-me/plus
v6.8.26
2026-02-23 00:15:36 +08:00
Luis Pater
0993413bab Merge branch 'main' into plus 2026-02-23 00:15:22 +08:00
Luis Pater
713388dd7b Fixed: #1675
fix(gemini): add model definitions for Gemini 3.1 Pro High and Image
2026-02-23 00:12:57 +08:00
Luis Pater
e6c7af0fa9 Merge pull request #1522 from soilSpoon/feature/canceled
feature(proxy): Adds special handling for client cancellations in proxy error handler
2026-02-22 22:02:59 +08:00
Luis Pater
837aa6e3aa Merge branch 'router-for-me:main' into main 2026-02-22 21:52:53 +08:00
Luis Pater
d210be06c2 fix(gemini): update min Thinking value and add Gemini 3.1 Pro Preview model definition 2026-02-22 21:51:32 +08:00
Luis Pater
af8e9ef458 Merge branch 'router-for-me:main' into main 2026-02-21 21:09:52 +08:00
Luis Pater
cec6f993ad Merge pull request #256 from kavore/fix/oauth-copilot-claude-aliases
fix: add default copilot claude model aliases for oauth routing
2026-02-21 21:09:43 +08:00
Luis Pater
950de29f48 Merge pull request #255 from ladeng07/main
feat(registry): add GPT-4o model variants for GitHub Copilot
2026-02-21 21:09:06 +08:00
Luis Pater
d6ec33e8e1 Merge pull request #1662 from matchch/contribute/cache-user-id
feat: add cache-user-id toggle for Claude cloaking
2026-02-21 20:51:30 +08:00
Luis Pater
081cfe806e fix(gemini): correct Created timestamps for Gemini 3.1 Pro Preview model definitions 2026-02-21 20:47:47 +08:00
hkfires
c1c62a6c04 feat(gemini): add Gemini 3.1 Pro Preview model definitions 2026-02-21 20:42:29 +08:00
matchch
2fdf5d2793 feat: add cache-user-id toggle for Claude cloaking
Default to generating a fresh random user_id per request instead of
reusing cached IDs. Add cache-user-id config option to opt in to the
previous caching behavior.

- Add CacheUserID field to CloakConfig
- Extract user_id cache logic to dedicated file
- Generate fresh user_id by default, cache only when enabled
- Add tests for both paths
2026-02-21 12:31:20 +08:00
kavore
b3da00d2ed fix: add default copilot claude model aliases for oauth routing 2026-02-20 21:59:21 +03:00
LMark
740277a9f2 refactor(registry): deduplicate GitHub Copilot GPT-4o model definitions 2026-02-21 02:32:06 +08:00
LMark
f91807b6b9 Add GPT-4o model variants while keeping Gemini 3.1 Pro preview 2026-02-21 01:41:01 +08:00
Luis Pater
57d18bb226 Merge branch 'router-for-me:main' into main 2026-02-20 22:42:01 +08:00
Luis Pater
10b9c6cb8a Merge pull request #252 from DragonBaiMo/fix/kiro-thinking-stream-dedup
fix(kiro): stop duplicated thinking on OpenAI and preserve Claude multi-turn thinking
2026-02-20 22:41:48 +08:00
Luis Pater
b24786f8a7 Merge pull request #250 from TonyRL/feat/copilot-gemini-3.1
feat(registry): add Gemini 3.1 Pro to GitHub Copilot provider
2026-02-20 22:40:41 +08:00
Luis Pater
7b0eb41ebc Merge pull request #1660 from Grivn/fix/claude-token-url
fix(claude): use api.anthropic.com for OAuth token exchange
2026-02-20 21:52:08 +08:00
DragonBaiMo
70949929db fix(kiro): deduplicate thinking stream emission 2026-02-20 20:34:40 +08:00
DragonBaiMo
7c9c89dace fix(kiro): keep thinking enabled across request formats 2026-02-20 20:34:40 +08:00
Grivn
ef5901c81b fix(claude): use api.anthropic.com for OAuth token exchange
console.anthropic.com is now protected by a Cloudflare managed challenge
that blocks all non-browser POST requests to /v1/oauth/token, causing
`-claude-login` to fail with a 403 error.

Switch to api.anthropic.com which hosts the same OAuth token endpoint
without the Cloudflare managed challenge.

Fixes #1659

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 20:11:27 +08:00
Luis Pater
d4829c82f7 Merge pull request #1652 from thebtf/fix/claude-translator-arguments
fix(translator): handle tool call arguments in codex→claude streaming translator
2026-02-20 19:50:20 +08:00
Luis Pater
a5f4166a9b Merge pull request #1644 from possible055/main
feat: add Gemini 3.1 Pro Preview model definition
2026-02-20 19:44:59 +08:00
Tony
f2b1ec4f9e feat(registry): add Gemini 3.1 Pro to GitHub Copilot provider 2026-02-20 04:23:42 +08:00
Kirill Turanskiy
1cc21cc45b fix: prevent duplicate function call arguments when delta events precede done
Non-spark codex models (gpt-5.3-codex, gpt-5.2-codex) stream function call
arguments via multiple delta events followed by a done event. The done handler
unconditionally emitted the full arguments, duplicating what deltas already
streamed. This produced invalid double JSON that Claude Code couldn't parse,
causing tool calls to fail with missing parameters and infinite retry loops.

Add HasReceivedArgumentsDelta flag to track whether delta events were received.
The done handler now only emits arguments when no deltas preceded it (spark
models), while delta-based streaming continues to work for non-spark models.
2026-02-19 23:18:14 +03:00
Kirill Turanskiy
07cf616e2b fix: handle response.function_call_arguments.done in codex→claude streaming translator
Some Codex models (e.g. gpt-5.3-codex-spark) send function call arguments
in a single "done" event without preceding "delta" events. The streaming
translator only handled "delta" events, causing tool call arguments to be
lost — resulting in empty tool inputs and infinite retry loops in clients
like Claude Code.

Emit the full arguments from the "done" event as a single input_json_delta
so downstream clients receive the complete tool input.
2026-02-19 23:18:14 +03:00
Luis Pater
2b8c466e88 refactor(executor, handlers): replace channel-based streams with StreamResult for consistency
- Updated `ExecuteStream` functions in executors to use `StreamResult` instead of channels.
- Enhanced upstream header handling in OpenAI handlers.
- Improved maintainability and alignment across executors and handlers.
2026-02-19 22:07:14 +08:00
Luis Pater
ca2174ea48 Merge pull request #249 from router-for-me/plus
v6.8.22
2026-02-19 21:58:42 +08:00
Luis Pater
c09fb2a79d Merge branch 'main' into plus 2026-02-19 21:58:04 +08:00
Luis Pater
4445a165e9 test(handlers): add tests for passthrough headers behavior in WriteErrorResponse 2026-02-19 21:49:44 +08:00
Luis Pater
e92e2af71a Merge branch 'codex/pr-1626' into dev 2026-02-19 21:33:23 +08:00
Luis Pater
a6bdd9a652 feat: add passthrough headers configuration
- Introduced `passthrough-headers` option in configuration to control forwarding of upstream response headers.
- Updated handlers to respect the passthrough headers setting.
- Added tests to verify behavior when passthrough is enabled or disabled.
2026-02-19 21:31:29 +08:00
Luis Pater
349a6349b3 Merge pull request #1645 from tinyc0der/fix/antigravity-tool-result-json
fix(antigravity): prevent invalid JSON when tool_result has no content
2026-02-19 21:01:25 +08:00
TinyCoder
00822770ec fix(antigravity): prevent invalid JSON when tool_result has no content
sjson.SetRaw with an empty string produces malformed JSON (e.g. "result":}).
This happens when a Claude tool_result block has no content field, causing
functionResponseResult.Raw to be "". Guard against this by falling back to
sjson.Set with an empty string only when .Raw is empty.
2026-02-19 17:08:39 +07:00
apparition
1a0ceda0fc feat: add Gemini 3.1 Pro Preview model definition 2026-02-19 17:43:08 +08:00
Luis Pater
72add453d2 docs: add OmniRoute to README 2026-02-19 13:23:25 +08:00
Luis Pater
2789396435 fix: ensure connection-scoped headers are filtered in upstream requests
- Added `connectionScopedHeaders` utility to respect "Connection" header directives.
- Updated `FilterUpstreamHeaders` to remove connection-scoped headers dynamically.
- Refactored and tested upstream header filtering with additional validations.
- Adjusted upstream header handling during retries to replace headers safely.
2026-02-19 13:19:10 +08:00
Luis Pater
61da7bd981 Merge PR #1626 into codex/pr-1626 2026-02-19 04:49:14 +08:00
Luis Pater
ae4c502792 Merge pull request #248 from router-for-me/plus
v6.8.21
2026-02-19 04:42:44 +08:00
Luis Pater
ec6068060b Merge branch 'main' into plus 2026-02-19 04:42:35 +08:00
Luis Pater
ecb01d3dcd Merge pull request #244 from PancakeZik/feat/sonnet-4-6
feat: add Claude Sonnet 4.6 model support for Kiro provider
2026-02-19 04:31:20 +08:00
Luis Pater
22c0c00bd4 Merge branch 'main' into feat/sonnet-4-6 2026-02-19 04:31:07 +08:00
Luis Pater
9eb3e7a6c4 Merge pull request #243 from gl11tchy/feat/claude-sonnet-4-6
feat(registry): add Claude Sonnet 4.6 model definitions
2026-02-19 04:29:39 +08:00
Luis Pater
357c191510 Merge pull request #242 from ultraplan-bit/main
Improve Copilot provider based on ericc-ch/copilot-api comparison
2026-02-19 04:27:02 +08:00
Luis Pater
5db244af76 Merge pull request #240 from TonyRL/feat/copilot-sonnet-4.6
feat(registry): add Sonnet 4.6 to GitHub Copilot provider
2026-02-19 04:26:28 +08:00
Luis Pater
dc375d1b74 Merge pull request #239 from TonyRL/feat/copilot-codex-5.3
feat(registry): add GPT-5.3 Codex to GitHub Copilot provider
2026-02-19 04:25:25 +08:00
Luis Pater
9c040445af Merge pull request #1635 from thebtf/fix/openai-translator-tool-streaming
fix: handle tool call argument streaming in Codex→OpenAI translator
2026-02-19 04:22:12 +08:00
Luis Pater
fff866424e Merge pull request #1628 from thebtf/fix/masquerading-headers
fix: update Claude masquerading headers and configurable defaults
2026-02-19 04:19:59 +08:00
Luis Pater
2d12becfd6 Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping
fix: clamp reasoning_effort to valid OpenAI-format values
2026-02-19 04:15:19 +08:00
Luis Pater
252f7e0751 Merge pull request #1625 from thebtf/feat/tool-prefix-config
feat: add per-auth tool_prefix_disabled option
2026-02-19 04:07:22 +08:00
Luis Pater
b2b17528cb Merge branch 'pr-1624' into dev
# Conflicts:
#	internal/runtime/executor/claude_executor.go
#	internal/runtime/executor/claude_executor_test.go
2026-02-19 04:05:04 +08:00
Luis Pater
55f938164b Merge pull request #1618 from alexey-yanchenko/fix/completions-usage
Fix empty usage in /v1/completions
2026-02-19 03:57:11 +08:00
Luis Pater
76294f0c59 Merge pull request #1608 from thebtf/fix/tool-reference-proxy-prefix-mainline
fix: add proxy_ prefix handling for tool_reference content blocks
2026-02-19 03:50:34 +08:00
Luis Pater
2bcee78c6e feat(tui): add standalone mode and API-based log polling
- Implemented `--standalone` mode to launch an embedded server for TUI.
- Enhanced TUI client to support API-based log polling when log hooks are unavailable.
- Added authentication gate for password input and connection handling.
- Improved localization and UX for logs, authentication, and status bar rendering.
2026-02-19 03:19:18 +08:00
Luis Pater
7fe8246a9f Merge branch 'tui' into dev 2026-02-19 03:18:24 +08:00
Luis Pater
93fe58e31e feat(tui): add standalone mode and API-based log polling
- Implemented `--standalone` mode to launch an embedded server for TUI.
- Enhanced TUI client to support API-based log polling when log hooks are unavailable.
- Added authentication gate for password input and connection handling.
- Improved localization and UX for logs, authentication, and status bar rendering.
2026-02-19 03:18:08 +08:00
Luis Pater
e5b5dc870f chore(executor): remove unused Openai-Beta header from Codex executor 2026-02-19 02:19:48 +08:00
Luis Pater
a54877c023 Merge branch 'dev' 2026-02-19 02:03:41 +08:00
Luis Pater
bb86a0c0c4 feat(logging, executor): add request logging tests and WebSocket-based Codex executor
- Introduced unit tests for request logging middleware to enhance coverage.
- Added WebSocket-based Codex executor to support Responses API upgrade.
- Updated middleware logic to selectively capture request bodies for memory efficiency.
- Enhanced Codex configuration handling with new WebSocket attributes.
2026-02-19 01:57:02 +08:00
Kirill Turanskiy
5fa23c7f41 fix: handle tool call argument streaming in Codex→OpenAI translator
The OpenAI Chat Completions translator was silently dropping
response.function_call_arguments.delta and
response.function_call_arguments.done Codex SSE events, meaning
tool call arguments were never streamed incrementally to clients.

Add proper handling mirroring the proven Claude translator pattern:

- response.output_item.added: announce tool call (id, name, empty args)
- response.function_call_arguments.delta: stream argument chunks
- response.function_call_arguments.done: emit full args if no deltas
- response.output_item.done: defensive fallback for backward compat

State tracking via HasReceivedArgumentsDelta and HasToolCallAnnounced
ensures no duplicate argument emission and correct behavior for models
like codex-spark that skip delta events entirely.
2026-02-18 19:09:05 +03:00
gl11tchy
f9a09b7f23 style: sort model entries per review feedback 2026-02-18 15:06:28 +00:00
Joao
b0cde626fe feat: add Claude Sonnet 4.6 model support for Kiro provider 2026-02-18 13:51:23 +00:00
gl11tchy
e42ef9a95d feat(registry): add Claude Sonnet 4.6 model definitions
Add claude-sonnet-4-6 to:
- Claude OAuth provider (model_definitions_static_data.go)
- Antigravity model config (thinking + non-thinking entries)
- GitHub Copilot provider (model_definitions.go)

Ref: https://docs.anthropic.com/en/docs/about-claude/models
2026-02-18 13:43:22 +00:00
ultraplan-bit
abf1629ec7 Merge branch 'main' of https://github.com/ultraplan-bit/CLIProxyAPIPlus 2026-02-18 08:56:06 +08:00
Kirill Turanskiy
73dc0b10b8 fix: update Claude masquerading headers and make them configurable
Update hardcoded X-Stainless-* and User-Agent defaults to match
Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (verified via
diagnostic proxy capture 2026-02-17).

Changes:
- X-Stainless-Os/Arch: dynamic via runtime.GOOS/GOARCH
- X-Stainless-Package-Version: 0.55.1 → 0.74.0
- X-Stainless-Timeout: 60 → 600
- User-Agent: claude-cli/1.0.83 (external, cli) → claude-cli/2.1.44 (external, sdk-cli)

Add claude-header-defaults config section so values can be updated
without recompilation when Claude Code releases new versions.
2026-02-18 03:38:51 +03:00
Kirill Turanskiy
2ea95266e3 fix: clamp reasoning_effort to valid OpenAI-format values
CPA-internal thinking levels like 'xhigh' and 'minimal' are not accepted
by OpenAI-format providers (MiniMax, etc.). The OpenAI applier now maps
non-standard levels to the nearest valid reasoning_effort value before
writing to the request body:

  xhigh   → high
  minimal → low
  auto    → medium
2026-02-18 03:36:42 +03:00
Tony
922d4141c0 feat(registry): add Sonnet 4.6 to GitHub Copilot provider 2026-02-18 05:17:23 +08:00
Kirill Turanskiy
1f8f198c45 feat: passthrough upstream response headers to clients
CPA previously stripped ALL response headers from upstream AI provider
APIs, preventing clients from seeing rate-limit info, request IDs,
server-timing and other useful headers.

Changes:
- Add Headers field to Response and StreamResult structs
- Add FilterUpstreamHeaders helper (hop-by-hop + security denylist)
- Add WriteUpstreamHeaders helper (respects CPA-set headers)
- ExecuteWithAuthManager/ExecuteCountWithAuthManager now return headers
- ExecuteStreamWithAuthManager returns headers from initial connection
- All 11 provider executors populate Response.Headers
- All handler call sites write filtered upstream headers before response

Filtered headers (not forwarded):
- RFC 7230 hop-by-hop: Connection, Transfer-Encoding, Keep-Alive, etc.
- Security: Set-Cookie
- CPA-managed: Content-Length, Content-Encoding
2026-02-18 00:16:22 +03:00
Tony
c55275342c feat(registry): add GPT-5.3 Codex to GitHub Copilot provider 2026-02-18 03:04:27 +08:00
Kirill Turanskiy
9261b0c20b feat: add per-auth tool_prefix_disabled option
Allow disabling the proxy_ tool name prefix on a per-account basis.
Users who route their own Anthropic account through CPA can set
"tool_prefix_disabled": true in their OAuth auth JSON to send tool
names unchanged to Anthropic.

Default behavior is fully preserved — prefix is applied unless
explicitly disabled.

Changes:
- Add ToolPrefixDisabled() accessor to Auth (reads metadata key
  "tool_prefix_disabled" or "tool-prefix-disabled")
- Gate all 6 prefix apply/strip points with the new flag
- Add unit tests for the accessor
2026-02-17 21:48:19 +03:00
Kirill Turanskiy
7cc725496e fix: skip proxy_ prefix for built-in tools in message history
The proxy_ prefix logic correctly skips built-in tools (those with a
non-empty "type" field) in tools[] definitions but does not skip them
in messages[].content[] tool_use blocks or tool_choice. This causes
web_search in conversation history to become proxy_web_search, which
Anthropic does not recognize.

Fix: collect built-in tool names from tools[] into a set and also
maintain a hardcoded fallback set (web_search, code_execution,
text_editor, computer) for cases where the built-in tool appears in
history but not in the current request's tools[] array. Skip prefixing
in messages and tool_choice when name matches a built-in.
2026-02-17 21:42:32 +03:00
ultraplan-bit
5726a99c80 Improve Copilot provider based on ericc-ch/copilot-api comparison
- Fix X-Initiator detection: check for any assistant/tool role
  in messages instead of only the last message role, matching
  the correct agent detection for multi-turn tool conversations
- Add x-github-api-version: 2025-04-01 header for API compatibility
- Support Business/Enterprise accounts by using Endpoints.API from
  the Copilot token response instead of hardcoded base URL
- Fix Responses API vision detection: detect vision content before
  input normalization removes the messages array
- Add 8 test cases covering the above fixes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 22:11:17 +08:00
ultraplan-bit
b5756bf729 Fix Copilot 0x model incorrectly consuming premium requests
Change Openai-Intent header from "conversation-edits" to
"conversation-panel" to avoid triggering GitHub's premium
execution path, which caused included models (0x multiplier)
to be billed as premium requests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 21:17:18 +08:00
Alexey Yanchenko
709d999f9f Add usage to /v1/completions 2026-02-17 17:21:03 +07:00
Kirill Turanskiy
24c18614f0 fix: skip built-in tools in tool_reference prefix + refactor to switch
- Collect built-in tool names (those with a "type" field like
  web_search, code_execution) and skip prefixing tool_reference
  blocks that reference them, preventing name mismatch.
- Refactor if-else if chains to switch statements in all three
  prefix functions for idiomatic Go style.
2026-02-16 19:37:11 +03:00
Kirill Turanskiy
603f06a762 fix: handle tool_reference nested inside tool_result.content[]
tool_reference blocks can appear nested inside tool_result.content[]
arrays, not just at the top level of messages[].content[]. The prefix
logic now iterates into tool_result blocks with array content to find
and prefix/strip nested tool_reference.tool_name fields.
2026-02-16 19:06:24 +03:00
Kirill Turanskiy
98f0a3e3bd fix: add proxy_ prefix handling for tool_reference content blocks (#1)
applyClaudeToolPrefix, stripClaudeToolPrefixFromResponse, and
stripClaudeToolPrefixFromStreamLine now handle "tool_reference" blocks
(field "tool_name") in addition to "tool_use" blocks (field "name").

Without this fix, tool_reference blocks in conversation history retain
their original unprefixed names while tool definitions carry the proxy_
prefix, causing Anthropic API 400 errors: "Tool reference 'X' not found
in available tools."

Co-authored-by: Kirill Turanskiy <kt@novamedia.ru>
2026-02-16 19:06:24 +03:00
Luis Pater
e186ccb0d4 Merge pull request #234 from detroittommy879/feature/add-kilocode-provider
Add Kilo Code provider with dynamic model fetching
2026-02-16 23:54:29 +08:00
Luis Pater
8fc0b08b70 Merge pull request #233 from ultraplan-bit/fix/copilot-codex-responses-translation
Fix Copilot codex model Responses API translation for Claude Code
2026-02-16 23:51:42 +08:00
DetroitTommy
d328e54e4b refactor(kilo): address code review suggestions for robustness 2026-02-15 17:26:29 -05:00
DetroitTommy
5a7932cba4 Added Kilo Code as a provider, with auth. It fetches the free models, which were tested and work; paid models still need to be tried by someone, so only the free ones are known to work 2026-02-15 14:54:20 -05:00
DetroitTommy
1dbeb0827a added kilocode auth, needs adjusting 2026-02-15 13:44:26 -05:00
lhpqaq
2c8821891c fix(tui): update with review 2026-02-16 00:24:25 +08:00
haopeng
0a2555b0f3 Update internal/tui/auth_tab.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-16 00:11:31 +08:00
lhpqaq
020df41efe chore(tui): update readme, fix usage 2026-02-16 00:04:04 +08:00
ultraplan-bit
f8f8cf17ce Fix Copilot codex model Responses API translation for Claude Code
- Add response.function_call_arguments.delta handler for tool call parameters
- Rewrite normalizeGitHubCopilotResponsesInput to produce structured input
  array (message/function_call/function_call_output) instead of flattened
  text, fixing infinite loop in multi-turn tool-use conversations
- Skip flattenAssistantContent for messages containing tool_use blocks,
  preventing function_call items from being destroyed
- Add reasoning/thinking stream & non-stream support
- Fix stop_reason mapping (max_tokens/stop) and cached token reporting
- Update test to match new array-based input format

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 18:04:45 +08:00
lhpqaq
f31f7f701a feat(tui): add i18n 2026-02-15 15:42:59 +08:00
lhpqaq
54ad7c1b6b feat(tui): add manager tui 2026-02-15 14:52:40 +08:00
이대희
a45c6defa7 Merge remote-tracking branch 'upstream/main' into feature/canceled 2026-02-13 15:07:32 +09:00
이대희
40bee3e8d9 Merge branch 'main' into feature/canceled 2026-02-13 13:37:55 +09:00
이대희
93147dddeb Improves error handling for canceled requests
Adds explicit handling for context.Canceled errors in the reverse proxy error handler to return 499 status code without logging, which is more appropriate for client-side cancellations during polling.

Also adds a test case to verify this behavior.
2026-02-12 10:39:45 +09:00
이대희
c0f9b15a58 Merge remote-tracking branch 'upstream/main' into feature/canceled 2026-02-12 10:33:49 +09:00
이대희
6f2fbdcbae Update internal/api/modules/amp/proxy.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-12 10:30:05 +09:00
이대희
ce0c6aa82b Update internal/api/modules/amp/proxy.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-10 19:07:49 +09:00
이대희
3c85d2a4d7 feature(proxy): Adds special handling for client cancellations in proxy error handler
Silences logging for client cancellations during polling to reduce noise in logs.
Client-side cancellations are common during long-running operations and should not be treated as errors.
2026-02-10 18:02:08 +09:00
95 changed files with 11094 additions and 332 deletions

1
.gitignore vendored
View File

@@ -3,6 +3,7 @@ cli-proxy-api
cliproxy
*.exe
# Configuration
config.yaml
.env

View File

@@ -8,6 +8,7 @@ import (
"errors"
"flag"
"fmt"
"io"
"io/fs"
"net/url"
"os"
@@ -26,6 +27,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
@@ -72,6 +74,7 @@ func main() {
var codexLogin bool
var claudeLogin bool
var qwenLogin bool
var kiloLogin bool
var iflowLogin bool
var iflowCookie bool
var noBrowser bool
@@ -88,6 +91,8 @@ func main() {
var vertexImport string
var configPath string
var password string
var tuiMode bool
var standalone bool
var noIncognito bool
var useIncognito bool
@@ -96,6 +101,7 @@ func main() {
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
@@ -114,6 +120,8 @@ func main() {
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
flag.StringVar(&password, "password", "", "")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
flag.CommandLine.Usage = func() {
out := flag.CommandLine.Output()
@@ -499,6 +507,8 @@ func main() {
cmd.DoClaudeLogin(cfg, options)
} else if qwenLogin {
cmd.DoQwenLogin(cfg, options)
} else if kiloLogin {
cmd.DoKiloLogin(cfg, options)
} else if iflowLogin {
cmd.DoIFlowLogin(cfg, options)
} else if iflowCookie {
@@ -536,15 +546,89 @@ func main() {
cmd.WaitForCloudDeploy()
return
}
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
if tuiMode {
if standalone {
// Standalone mode: start an embedded local server and connect TUI client to it.
managementasset.StartAutoUpdater(context.Background(), configFilePath)
hook := tui.NewLogHook(2000)
hook.SetFormatter(&logging.LogFormatter{})
log.AddHook(hook)
// Initialize and start the Kiro token background refresh
if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg)
defer kiro.StopGlobalRefreshManager()
origStdout := os.Stdout
origStderr := os.Stderr
origLogOutput := log.StandardLogger().Out
log.SetOutput(io.Discard)
devNull, errOpenDevNull := os.Open(os.DevNull)
if errOpenDevNull == nil {
os.Stdout = devNull
os.Stderr = devNull
}
restoreIO := func() {
os.Stdout = origStdout
os.Stderr = origStderr
log.SetOutput(origLogOutput)
if devNull != nil {
_ = devNull.Close()
}
}
localMgmtPassword := fmt.Sprintf("tui-%d-%d", os.Getpid(), time.Now().UnixNano())
if password == "" {
password = localMgmtPassword
}
cancel, done := cmd.StartServiceBackground(cfg, configFilePath, password)
client := tui.NewClient(cfg.Port, password)
ready := false
backoff := 100 * time.Millisecond
for i := 0; i < 30; i++ {
if _, errGetConfig := client.GetConfig(); errGetConfig == nil {
ready = true
break
}
time.Sleep(backoff)
if backoff < time.Second {
backoff = time.Duration(float64(backoff) * 1.5)
}
}
if !ready {
restoreIO()
cancel()
<-done
fmt.Fprintf(os.Stderr, "TUI error: embedded server is not ready\n")
return
}
if errRun := tui.Run(cfg.Port, password, hook, origStdout); errRun != nil {
restoreIO()
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
} else {
restoreIO()
}
cancel()
<-done
} else {
// Default TUI mode: pure management client.
// The proxy server must already be running.
if errRun := tui.Run(cfg.Port, password, nil, os.Stdout); errRun != nil {
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
}
}
} else {
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg)
defer kiro.StopGlobalRefreshManager()
}
cmd.StartService(cfg, configFilePath, password)
}
cmd.StartService(cfg, configFilePath, password)
}
}

View File

@@ -1,6 +1,6 @@
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
host: ""
host: ''
# Server port
port: 8317
@@ -8,8 +8,8 @@ port: 8317
# TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
tls:
enable: false
cert: ""
key: ""
cert: ''
key: ''
# Management API settings
remote-management:
@@ -20,22 +20,22 @@ remote-management:
# Management key. If a plaintext value is provided here, it will be hashed on startup.
# All management requests (even from localhost) require this key.
# Leave empty to disable the Management API entirely (404 for all /v0/management routes).
secret-key: ""
secret-key: ''
# Disable the bundled management control panel asset download and HTTP route when true.
disable-control-panel: false
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
# Authentication directory (supports ~ for home directory)
auth-dir: "~/.cli-proxy-api"
auth-dir: '~/.cli-proxy-api'
# API keys for authentication
api-keys:
- "your-api-key-1"
- "your-api-key-2"
- "your-api-key-3"
- 'your-api-key-1'
- 'your-api-key-2'
- 'your-api-key-3'
# Enable debug logging
debug: false
@@ -43,7 +43,7 @@ debug: false
# Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety.
pprof:
enable: false
addr: "127.0.0.1:8316"
addr: '127.0.0.1:8316'
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
commercial-mode: false
@@ -68,11 +68,15 @@ error-logs-max-files: 10
usage-statistics-enabled: false
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
proxy-url: ""
proxy-url: ''
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
force-model-prefix: false
# When true, forward filtered upstream response headers to downstream clients.
# Default is false (disabled).
passthrough-headers: false
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
request-retry: 3
@@ -86,7 +90,7 @@ quota-exceeded:
# Routing strategy for selecting credentials when multiple match.
routing:
strategy: "round-robin" # round-robin (default), fill-first
strategy: 'round-robin' # round-robin (default), fill-first
# When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false
@@ -160,6 +164,15 @@ nonstream-keepalive-interval: 0
# sensitive-words: # optional: words to obfuscate with zero-width characters
# - "API"
# - "proxy"
# cache-user-id: true # optional: default is false; set to true to reuse the cached user_id per API key instead of generating a random one for each request
# Default headers for Claude API requests. Update when Claude Code releases new versions.
# These are used as fallbacks when the client does not send its own headers.
# claude-header-defaults:
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
# package-version: "0.74.0"
# runtime-version: "v24.3.0"
# timeout: "600"
# Kiro (AWS CodeWhisperer) configuration
# Note: Kiro API currently only operates in us-east-1 region
@@ -171,6 +184,21 @@ nonstream-keepalive-interval: 0
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override
# Kilocode (OAuth-based code assistant)
# Note: Kilocode uses OAuth device flow authentication.
# Use the CLI command: ./server --kilo-login
# This will save credentials to the auth directory (default: ~/.cli-proxy-api/)
# oauth-model-alias:
# kilo:
# - name: "minimax/minimax-m2.5:free"
# alias: "minimax-m2.5"
# - name: "z-ai/glm-5:free"
# alias: "glm-5"
# oauth-excluded-models:
# kilo:
# - "kilo-claude-opus-4-6" # exclude specific models (exact match)
# - "*:free" # wildcard matching suffix (e.g. all free models)
# OpenAI compatibility providers
# openai-compatibility:
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.

View File

@@ -159,13 +159,13 @@ func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request,
return clipexec.Response{}, errors.New("count tokens not implemented")
}
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) {
ch := make(chan clipexec.StreamChunk, 1)
go func() {
defer close(ch)
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
}()
return ch, nil
return &clipexec.StreamResult{Chunks: ch}, nil
}
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {

View File

@@ -58,7 +58,7 @@ func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, c
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
}
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (*clipexec.StreamResult, error) {
return nil, errors.New("echo executor: ExecuteStream not implemented")
}

21
go.mod
View File

@@ -4,6 +4,10 @@ go 1.26.0
require (
github.com/andybalholm/brotli v1.0.6
github.com/atotto/clipboard v0.1.4
github.com/charmbracelet/bubbles v1.0.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/fsnotify/fsnotify v1.9.0
github.com/fxamacker/cbor/v2 v2.9.0
github.com/gin-gonic/gin v1.10.1
@@ -33,8 +37,16 @@ require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/ProtonMail/go-crypto v1.3.0 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/ansi v0.11.6 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.9.0 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
github.com/cloudflare/circl v1.6.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
@@ -42,6 +54,7 @@ require (
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-git/gcfg/v2 v2.0.2 // indirect
@@ -58,19 +71,27 @@ require (
github.com/kevinburke/ssh_config v1.4.0 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/minio/sha256-simd v1.0.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pjbgf/sha1cd v0.5.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/sergi/go-diff v1.4.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/sys v0.38.0 // indirect

45
go.sum
View File

@@ -10,10 +10,34 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
@@ -33,6 +57,8 @@ github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
@@ -101,8 +127,14 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
@@ -114,6 +146,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
@@ -124,6 +162,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
@@ -161,6 +201,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
@@ -168,12 +210,15 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -29,6 +29,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
@@ -813,6 +814,87 @@ func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
}
// PatchAuthFileFields applies a partial update to a single auth entry.
//
// The JSON body must contain "name" and at least one of "prefix", "proxy_url"
// or "priority". The name is first tried as an auth ID via GetByID and, when
// that lookup misses, compared against the FileName of every listed auth.
// A priority of 0 removes the "priority" metadata key; any other value is
// stored under that key.
//
// Responses: 503 when no auth manager is configured, 400 for an unparsable
// body, a blank name or an empty field set, 404 when no auth matches, 500 when
// the manager rejects the update, and 200 {"status":"ok"} on success.
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
	if h.authManager == nil {
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
		return
	}

	var payload struct {
		Name     string  `json:"name"`
		Prefix   *string `json:"prefix"`
		ProxyURL *string `json:"proxy_url"`
		Priority *int    `json:"priority"`
	}
	if errBind := c.ShouldBindJSON(&payload); errBind != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
		return
	}

	lookup := strings.TrimSpace(payload.Name)
	if lookup == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
		return
	}

	ctx := c.Request.Context()

	// Resolve the target: an ID hit wins, otherwise fall back to a file-name scan.
	var target *coreauth.Auth
	if byID, found := h.authManager.GetByID(lookup); found {
		target = byID
	} else {
		for _, candidate := range h.authManager.List() {
			if candidate.FileName == lookup {
				target = candidate
				break
			}
		}
	}
	if target == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
		return
	}

	// Pointer fields distinguish "absent" from "set to the zero value".
	if payload.Prefix == nil && payload.ProxyURL == nil && payload.Priority == nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"})
		return
	}

	if payload.Prefix != nil {
		target.Prefix = *payload.Prefix
	}
	if payload.ProxyURL != nil {
		target.ProxyURL = *payload.ProxyURL
	}
	if payload.Priority != nil {
		if target.Metadata == nil {
			target.Metadata = make(map[string]any)
		}
		if priority := *payload.Priority; priority != 0 {
			target.Metadata["priority"] = priority
		} else {
			delete(target.Metadata, "priority")
		}
	}

	target.UpdatedAt = time.Now()
	if _, errUpdate := h.authManager.Update(ctx, target); errUpdate != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", errUpdate)})
		return
	}

	c.JSON(http.StatusOK, gin.H{"status": "ok"})
}
func (h *Handler) disableAuth(ctx context.Context, id string) {
if h == nil || h.authManager == nil {
return
@@ -2733,3 +2815,88 @@ func generateKiroPKCE() (verifier, challenge string, err error) {
return verifier, challenge, nil
}
// RequestKiloToken starts a Kilo device-flow login and immediately answers
// with the verification URL, the user code and a session state string.
//
// After the device flow has been initiated, an OAuth session is registered
// under the generated state and a goroutine polls for the token. Once a token
// arrives the goroutine fetches the profile and defaults (both best-effort),
// saves the credential record through saveTokenRecord and marks the session
// complete. Polling or save failures are recorded on the session via
// SetOAuthSessionError.
//
// Responses: 500 when the device flow cannot be initiated, otherwise 200 with
// "status", "url", "state", "user_code" and "verification_uri".
func (h *Handler) RequestKiloToken(c *gin.Context) {
	// A background context is used instead of the request context because the
	// goroutine below keeps using it after this handler has returned.
	ctx := context.Background()

	fmt.Println("Initializing Kilo authentication...")

	state := fmt.Sprintf("kil-%d", time.Now().UnixNano())
	kilocodeAuth := kilo.NewKiloAuth()

	resp, err := kilocodeAuth.InitiateDeviceFlow(ctx)
	if err != nil {
		log.Errorf("Failed to initiate device flow: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate device flow"})
		return
	}

	RegisterOAuthSession(state, "kilo")

	go func() {
		fmt.Printf("Please visit %s and enter code: %s\n", resp.VerificationURL, resp.Code)

		status, err := kilocodeAuth.PollForToken(ctx, resp.Code)
		if err != nil {
			SetOAuthSessionError(state, "Authentication failed")
			fmt.Printf("Authentication failed: %v\n", err)
			return
		}

		// Profile lookup is best-effort: fall back to the e-mail reported by
		// the device-flow status so the login can still complete.
		profile, err := kilocodeAuth.GetProfile(ctx, status.Token)
		if err != nil {
			log.Warnf("Failed to fetch profile: %v", err)
			profile = &kilo.Profile{Email: status.UserEmail}
		}

		// The first organization, when present, is used as the organization ID.
		var orgID string
		if len(profile.Orgs) > 0 {
			orgID = profile.Orgs[0].ID
		}

		// Defaults are best-effort as well; the failure is logged instead of
		// being silently dropped, and empty defaults are used.
		defaults, err := kilocodeAuth.GetDefaults(ctx, status.Token, orgID)
		if err != nil {
			log.Warnf("Failed to fetch defaults: %v", err)
			defaults = &kilo.Defaults{}
		}

		ts := &kilo.KiloTokenStorage{
			Token:          status.Token,
			OrganizationID: orgID,
			Model:          defaults.Model,
			Email:          status.UserEmail,
			Type:           "kilo",
		}

		fileName := kilo.CredentialFileName(status.UserEmail)
		record := &coreauth.Auth{
			ID:       fileName,
			Provider: "kilo",
			FileName: fileName,
			Storage:  ts,
			Metadata: map[string]any{
				"email":           status.UserEmail,
				"organization_id": orgID,
				"model":           defaults.Model,
			},
		}

		savedPath, errSave := h.saveTokenRecord(ctx, record)
		if errSave != nil {
			log.Errorf("Failed to save authentication tokens: %v", errSave)
			SetOAuthSessionError(state, "Failed to save authentication tokens")
			return
		}

		fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
		CompleteOAuthSession(state)
		CompleteOAuthSessionsByProvider("kilo")
	}()

	c.JSON(http.StatusOK, gin.H{
		"status":           "ok",
		"url":              resp.VerificationURL,
		"state":            state,
		"user_code":        resp.Code,
		"verification_uri": resp.VerificationURL,
	})
}

View File

@@ -15,10 +15,12 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
// It captures detailed information about the request and response, including headers and body,
// and uses the provided RequestLogger to record this data. When logging is disabled in the
// logger, it still captures data so that upstream errors can be persisted.
// and uses the provided RequestLogger to record this data. When full request logging is disabled,
// body capture is limited to small known-size payloads to avoid large per-request memory spikes.
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return func(c *gin.Context) {
if logger == nil {
@@ -26,7 +28,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return
}
if c.Request.Method == http.MethodGet {
if shouldSkipMethodForRequestLogging(c.Request) {
c.Next()
return
}
@@ -37,8 +39,10 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return
}
loggerEnabled := logger.IsEnabled()
// Capture request information
requestInfo, err := captureRequestInfo(c)
requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
if err != nil {
// Log error but continue processing
// In a real implementation, you might want to use a proper logger here
@@ -48,7 +52,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
// Create response writer wrapper
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
if !logger.IsEnabled() {
if !loggerEnabled {
wrapper.logOnErrorOnly = true
}
c.Writer = wrapper
@@ -64,10 +68,47 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
}
}
// shouldSkipMethodForRequestLogging reports whether request logging should be
// bypassed for req. A nil request is skipped. GET requests are skipped unless
// isResponsesWebsocketUpgrade accepts them. Every other method is logged.
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
	switch {
	case req == nil:
		return true
	case req.Method == http.MethodGet:
		return !isResponsesWebsocketUpgrade(req)
	default:
		return false
	}
}
// isResponsesWebsocketUpgrade reports whether req targets exactly the path
// "/v1/responses" and carries an Upgrade header equal to "websocket". The
// header value is trimmed of surrounding whitespace and compared
// case-insensitively. A nil request or a request without a URL yields false.
func isResponsesWebsocketUpgrade(req *http.Request) bool {
	if req == nil || req.URL == nil || req.URL.Path != "/v1/responses" {
		return false
	}
	upgrade := strings.TrimSpace(req.Header.Get("Upgrade"))
	return strings.EqualFold(upgrade, "websocket")
}
// shouldCaptureRequestBody decides whether the request body should be read
// for logging. With the logger enabled the answer is always true. Otherwise
// the body is captured only when the request and its body are non-nil, the
// Content-Type does not start with "multipart/form-data" (case-insensitive,
// surrounding whitespace ignored), and ContentLength is positive and no larger
// than maxErrorOnlyCapturedRequestBodyBytes.
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
	if loggerEnabled {
		return true
	}
	if req == nil || req.Body == nil {
		return false
	}

	mediaType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
	isMultipart := strings.HasPrefix(mediaType, "multipart/form-data")
	// A non-positive ContentLength (unknown or empty) is never captured.
	sizeKnownAndSmall := req.ContentLength > 0 && req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes

	return !isMultipart && sizeKnownAndSmall
}
// captureRequestInfo extracts relevant information from the incoming HTTP request.
// It captures the URL, method, headers, and body. The request body is read and then
// restored so that it can be processed by subsequent handlers.
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
// Capture URL with sensitive query parameters masked
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
url := c.Request.URL.Path
@@ -86,7 +127,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
// Capture request body
var body []byte
if c.Request.Body != nil {
if captureBody && c.Request.Body != nil {
// Read the body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {

View File

@@ -0,0 +1,138 @@
package middleware
import (
"io"
"net/http"
"net/url"
"strings"
"testing"
)
// TestShouldSkipMethodForRequestLogging walks a table of requests and checks
// the skip decision for each: nil, POST, plain GET, a websocket upgrade on
// /v1/responses, and a GET on /v1/responses without the Upgrade header.
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
	// request builds a minimal request; header may be nil.
	request := func(method, path string, header http.Header) *http.Request {
		return &http.Request{
			Method: method,
			URL:    &url.URL{Path: path},
			Header: header,
		}
	}

	cases := []struct {
		name string
		req  *http.Request
		skip bool
	}{
		{"nil request", nil, true},
		{"post request should not skip", request(http.MethodPost, "/v1/responses", nil), false},
		{"plain get should skip", request(http.MethodGet, "/v1/models", http.Header{}), true},
		{
			"responses websocket upgrade should not skip",
			request(http.MethodGet, "/v1/responses", http.Header{"Upgrade": []string{"websocket"}}),
			false,
		},
		{"responses get without upgrade should skip", request(http.MethodGet, "/v1/responses", http.Header{}), true},
	}

	for _, tc := range cases {
		if got := shouldSkipMethodForRequestLogging(tc.req); got != tc.skip {
			t.Fatalf("%s: got skip=%t, want %t", tc.name, got, tc.skip)
		}
	}
}
// TestShouldCaptureRequestBody walks a table covering the enabled-logger
// shortcut and, for error-only mode, the nil request, small, oversized,
// unknown-size and multipart cases.
func TestShouldCaptureRequestBody(t *testing.T) {
	// request builds a request with the given payload, declared length and
	// Content-Type header.
	request := func(payload string, length int64, contentType string) *http.Request {
		return &http.Request{
			Body:          io.NopCloser(strings.NewReader(payload)),
			ContentLength: length,
			Header:        http.Header{"Content-Type": []string{contentType}},
		}
	}

	cases := []struct {
		name          string
		loggerEnabled bool
		req           *http.Request
		want          bool
	}{
		{"logger enabled always captures", true, request("{}", -1, "application/json"), true},
		{"nil request", false, nil, false},
		{"small known size json in error-only mode", false, request("{}", 2, "application/json"), true},
		{
			"large known size skipped in error-only mode",
			false,
			request("x", maxErrorOnlyCapturedRequestBodyBytes+1, "application/json"),
			false,
		},
		{"unknown size skipped in error-only mode", false, request("x", -1, "application/json"), false},
		{"multipart skipped in error-only mode", false, request("x", 1, "multipart/form-data; boundary=abc"), false},
	}

	for _, tc := range cases {
		if got := shouldCaptureRequestBody(tc.loggerEnabled, tc.req); got != tc.want {
			t.Fatalf("%s: got %t, want %t", tc.name, got, tc.want)
		}
	}
}

View File

@@ -14,6 +14,8 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
)
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct {
URL string // URL is the request URL.
@@ -223,8 +225,8 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
// Only fall back to request payload hints when Content-Type is not set yet.
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
bodyStr := string(w.requestInfo.Body)
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
return bytes.Contains(w.requestInfo.Body, []byte(`"stream": true`)) ||
bytes.Contains(w.requestInfo.Body, []byte(`"stream":true`))
}
return false
@@ -310,7 +312,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
return nil
}
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
}
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
@@ -361,16 +363,32 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
return time.Time{}
}
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
// extractRequestBody picks the request body to hand to the logger.
//
// A value stored on the gin context under requestBodyOverrideContextKey wins
// when it is a non-empty []byte (returned as a copy) or a string that is not
// blank after trimming (returned as its bytes, untrimmed). Any other override
// value is ignored. Without a usable override the body captured in
// w.requestInfo is returned, or nil when nothing was captured.
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
	if c != nil {
		if raw, found := c.Get(requestBodyOverrideContextKey); found {
			if asBytes, ok := raw.([]byte); ok && len(asBytes) > 0 {
				return bytes.Clone(asBytes)
			}
			if asString, ok := raw.(string); ok && strings.TrimSpace(asString) != "" {
				return []byte(asString)
			}
		}
	}

	if w.requestInfo == nil || len(w.requestInfo.Body) == 0 {
		return nil
	}
	return w.requestInfo.Body
}
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
if w.requestInfo == nil {
return nil
}
var requestBody []byte
if len(w.requestInfo.Body) > 0 {
requestBody = w.requestInfo.Body
}
if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
}); ok {

View File

@@ -0,0 +1,43 @@
package middleware
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// TestExtractRequestBodyPrefersOverride checks that the captured request body
// is returned while no override is set, and that a []byte override stored on
// the gin context takes precedence once present.
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
	gin.SetMode(gin.TestMode)
	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())

	w := &ResponseWriterWrapper{
		requestInfo: &RequestInfo{Body: []byte("original-body")},
	}

	// No override yet: the captured body is used.
	if got := string(w.extractRequestBody(ctx)); got != "original-body" {
		t.Fatalf("request body = %q, want %q", got, "original-body")
	}

	// With an override in place it must win over the captured body.
	ctx.Set(requestBodyOverrideContextKey, []byte("override-body"))
	if got := string(w.extractRequestBody(ctx)); got != "override-body" {
		t.Fatalf("request body = %q, want %q", got, "override-body")
	}
}
// TestExtractRequestBodySupportsStringOverride checks that an override stored
// as a string is returned as the request body, even when the wrapper holds no
// captured request info.
func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
	gin.SetMode(gin.TestMode)
	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Set(requestBodyOverrideContextKey, "override-as-string")

	w := &ResponseWriterWrapper{}
	if got := string(w.extractRequestBody(ctx)); got != "override-as-string" {
		t.Fatalf("request body = %q, want %q", got, "override-as-string")
	}
}

View File

@@ -215,7 +215,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
// Don't log as error for context canceled - it's usually client closing connection
if errors.Is(err, context.Canceled) {
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
return
} else {
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
}

View File

@@ -493,6 +493,30 @@ func TestReverseProxy_ErrorHandler(t *testing.T) {
}
}
// TestReverseProxy_ErrorHandler_ContextCanceled verifies that the proxy's
// ErrorHandler writes no response body when it is invoked with
// context.Canceled on a request whose context is already canceled.
func TestReverseProxy_ErrorHandler_ContextCanceled(t *testing.T) {
	proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource(""))
	if err != nil {
		t.Fatal(err)
	}

	// Cancel up front so the request carries an already-canceled context.
	canceledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "/test", nil).WithContext(canceledCtx)

	// Call the handler directly instead of going through a round trip.
	proxy.ErrorHandler(recorder, request, context.Canceled)

	// A canceled request must not produce the JSON error payload.
	if got := recorder.Body.Bytes(); len(got) > 0 {
		t.Fatalf("expected empty body for canceled context, got: %s", got)
	}
}
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
// Upstream returns gzipped JSON without Content-Encoding header
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@@ -285,8 +285,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
optionState.routerConfigurator(engine, s.handlers, cfg)
}
// Register management routes when configuration or environment secrets are available.
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret
// Register management routes when configuration or environment secrets are available,
// or when a local management password is provided (e.g. TUI mode).
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
s.managementRoutesEnabled.Store(hasManagementSecret)
if hasManagementSecret {
s.registerManagementRoutes()
@@ -329,6 +330,7 @@ func (s *Server) setupRoutes() {
v1.POST("/completions", openaiHandlers.Completions)
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
v1.POST("/responses", openaiResponsesHandlers.Responses)
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
}
@@ -642,6 +644,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
mgmt.PATCH("/auth-files/fields", s.mgmt.PatchAuthFileFields)
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
@@ -649,6 +652,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken)
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)

View File

@@ -20,7 +20,7 @@ import (
// OAuth configuration constants for Claude/Anthropic
const (
AuthURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
TokenURL = "https://api.anthropic.com/v1/oauth/token"
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
RedirectURI = "http://localhost:54545/callback"
)

View File

@@ -0,0 +1,168 @@
// Package kilo provides authentication and token management functionality
// for Kilo AI services.
package kilo
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
)
const (
// BaseURL is the base URL for the Kilo AI API.
BaseURL = "https://api.kilo.ai/api"
)
// DeviceAuthResponse represents the response from initiating device flow.
// InitiateDeviceFlow decodes the JSON body of its POST request into this type.
type DeviceAuthResponse struct {
// Code is decoded from the "code" JSON field.
// NOTE(review): presumably the value later handed to PollForToken — confirm against callers.
Code string `json:"code"`
// VerificationURL is decoded from the "verificationUrl" JSON field.
VerificationURL string `json:"verificationUrl"`
// ExpiresIn is decoded from the "expiresIn" JSON field.
// NOTE(review): the unit is not shown in this file (presumably seconds) — confirm against the Kilo API.
ExpiresIn int `json:"expiresIn"`
}
// DeviceStatusResponse represents the response when polling for device flow status.
// PollForToken decodes each poll response into this type.
type DeviceStatusResponse struct {
// Status is decoded from the "status" JSON field. PollForToken recognises
// "approved", "denied", "expired" and "pending"; any other value is treated as an error.
Status string `json:"status"`
// Token is decoded from the "token" JSON field.
Token string `json:"token"`
// UserEmail is decoded from the "userEmail" JSON field.
UserEmail string `json:"userEmail"`
}
// Profile represents the user profile from Kilo AI.
// GetProfile decodes the JSON body of the profile endpoint into this type.
type Profile struct {
// Email is decoded from the "email" JSON field.
Email string `json:"email"`
// Orgs is decoded from the "organizations" JSON field.
Orgs []Organization `json:"organizations"`
}
// Organization represents a Kilo AI organization.
// It is the element type of Profile.Orgs.
type Organization struct {
// ID is decoded from the "id" JSON field.
// NOTE(review): presumably the orgID accepted by GetDefaults — confirm against callers.
ID string `json:"id"`
// Name is decoded from the "name" JSON field.
Name string `json:"name"`
}
// Defaults represents default settings for an organization or user.
// GetDefaults decodes the JSON body of the defaults endpoint into this type.
type Defaults struct {
// Model is decoded from the "model" JSON field.
Model string `json:"model"`
}
// KiloAuth provides methods for handling the Kilo AI authentication flow.
// Construct it with NewKiloAuth; the zero value has a nil client and its
// methods dereference that client.
type KiloAuth struct {
// client is the HTTP client used for every request issued by KiloAuth methods.
client *http.Client
}
// NewKiloAuth returns a KiloAuth backed by its own HTTP client.
// The client is configured with a 30 second overall timeout per request.
func NewKiloAuth() *KiloAuth {
	httpClient := &http.Client{Timeout: 30 * time.Second}
	return &KiloAuth{client: httpClient}
}
// InitiateDeviceFlow starts the device authentication flow.
//
// It sends a body-less POST (Content-Type application/json) to
// BaseURL+"/device-auth/codes" and decodes the JSON response into a
// DeviceAuthResponse.
//
// The request is bound to ctx, so cancelling ctx or exceeding its deadline
// aborts the call; the client's own timeout still applies on top of that.
//
// An error is returned when the request cannot be built or sent, when the
// response status is neither 200 OK nor 201 Created, or when the body is not
// valid JSON for DeviceAuthResponse.
func (k *KiloAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceAuthResponse, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, BaseURL+"/device-auth/codes", nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create device flow request: %w", err)
	}
	// Keep the same Content-Type the previous client.Post call sent.
	req.Header.Set("Content-Type", "application/json")

	resp, err := k.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to initiate device flow: status %d", resp.StatusCode)
	}

	var data DeviceAuthResponse
	if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
		return nil, err
	}
	return &data, nil
}
// PollForToken polls for the device flow completion.
//
// Every 5 seconds (the first poll happens after the first tick) it queries the
// status of code and acts on the reported status:
//   - "approved": the decoded response is returned.
//   - "denied" or "expired": an error naming the status is returned.
//   - "pending": polling continues on the next tick.
//   - anything else: an "unknown status" error is returned.
//
// Polling stops with ctx.Err() as soon as ctx is done. Any transport or JSON
// decoding error from a single poll ends the loop and is returned as-is.
func (k *KiloAuth) PollForToken(ctx context.Context, code string) (*DeviceStatusResponse, error) {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-ticker.C:
		}

		data, err := k.fetchDeviceStatus(ctx, code)
		if err != nil {
			return nil, err
		}

		switch data.Status {
		case "approved":
			return data, nil
		case "denied", "expired":
			return nil, fmt.Errorf("device flow %s", data.Status)
		case "pending":
			// Not decided yet; wait for the next tick.
		default:
			return nil, fmt.Errorf("unknown status: %s", data.Status)
		}
	}
}

// fetchDeviceStatus performs a single status lookup for code and decodes the
// response body. It lives in its own function so the response body is closed
// at the end of every poll instead of piling up deferred closes until
// PollForToken returns, and so the request honours ctx.
func (k *KiloAuth) fetchDeviceStatus(ctx context.Context, code string) (*DeviceStatusResponse, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, BaseURL+"/device-auth/codes/"+code, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create device status request: %w", err)
	}

	resp, err := k.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// The status code is intentionally not inspected here: the outcome is
	// taken from the "status" field of the JSON body, as before.
	var data DeviceStatusResponse
	if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
		return nil, err
	}
	return &data, nil
}
// GetProfile retrieves the profile of the user identified by token.
//
// It issues a GET to BaseURL+"/profile" with token as a Bearer credential and
// decodes the JSON body into a Profile. Any status other than 200 OK is
// reported as an error.
func (k *KiloAuth) GetProfile(ctx context.Context, token string) (*Profile, error) {
	request, errReq := http.NewRequestWithContext(ctx, "GET", BaseURL+"/profile", nil)
	if errReq != nil {
		return nil, fmt.Errorf("failed to create get profile request: %w", errReq)
	}
	request.Header.Set("Authorization", "Bearer "+token)

	response, errDo := k.client.Do(request)
	if errDo != nil {
		return nil, errDo
	}
	defer response.Body.Close()

	if status := response.StatusCode; status != http.StatusOK {
		return nil, fmt.Errorf("failed to get profile: status %d", status)
	}

	result := new(Profile)
	if errDecode := json.NewDecoder(response.Body).Decode(result); errDecode != nil {
		return nil, errDecode
	}
	return result, nil
}
// GetDefaults retrieves default settings, optionally scoped to an organization.
//
// With an empty orgID the request goes to BaseURL+"/defaults"; otherwise it
// goes to BaseURL+"/organizations/"+orgID+"/defaults". The request carries
// token as a Bearer credential, and any status other than 200 OK is reported
// as an error.
func (k *KiloAuth) GetDefaults(ctx context.Context, token, orgID string) (*Defaults, error) {
	var endpoint string
	switch orgID {
	case "":
		endpoint = BaseURL + "/defaults"
	default:
		endpoint = BaseURL + "/organizations/" + orgID + "/defaults"
	}

	request, errReq := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
	if errReq != nil {
		return nil, fmt.Errorf("failed to create get defaults request: %w", errReq)
	}
	request.Header.Set("Authorization", "Bearer "+token)

	response, errDo := k.client.Do(request)
	if errDo != nil {
		return nil, errDo
	}
	defer response.Body.Close()

	if status := response.StatusCode; status != http.StatusOK {
		return nil, fmt.Errorf("failed to get defaults: status %d", status)
	}

	result := new(Defaults)
	if errDecode := json.NewDecoder(response.Body).Decode(result); errDecode != nil {
		return nil, errDecode
	}
	return result, nil
}

View File

@@ -0,0 +1,60 @@
// Package kilo provides authentication and token management functionality
// for Kilo AI services.
package kilo
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
log "github.com/sirupsen/logrus"
)
// KiloTokenStorage stores token information for Kilo AI authentication.
// SaveTokenToFile serializes it as JSON using the field tags below.
type KiloTokenStorage struct {
// Token is the Kilo access token (JSON key "kilocodeToken").
Token string `json:"kilocodeToken"`
// OrganizationID is the Kilo organization ID (JSON key "kilocodeOrganizationId").
OrganizationID string `json:"kilocodeOrganizationId"`
// Model is the default model to use (JSON key "kilocodeModel").
Model string `json:"kilocodeModel"`
// Email is the email address of the authenticated user.
Email string `json:"email"`
// Type indicates the authentication provider type, always "kilo" for this storage.
// SaveTokenToFile overwrites it with "kilo" before writing.
Type string `json:"type"`
}
// SaveTokenToFile serializes the Kilo token storage to a JSON file.
//
// It sets ts.Type to "kilo", creates the parent directory of authFilePath
// (mode 0700) when it is missing, and then writes ts as JSON, truncating any
// existing file. Because the file holds an access token, a newly created file
// gets mode 0600 so it is readable by the owner only; an existing file keeps
// its current mode.
//
// Returned errors wrap the underlying cause, so callers can inspect it with
// errors.Is / errors.As. A failure to close the file is logged rather than
// returned.
func (ts *KiloTokenStorage) SaveTokenToFile(authFilePath string) error {
	misc.LogSavingCredentials(authFilePath)
	ts.Type = "kilo"

	if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}

	// Same flags as os.Create, but with owner-only permissions for the secret.
	f, err := os.OpenFile(authFilePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return fmt.Errorf("failed to create token file: %w", err)
	}
	defer func() {
		if errClose := f.Close(); errClose != nil {
			log.Errorf("failed to close file: %v", errClose)
		}
	}()

	if err = json.NewEncoder(f).Encode(ts); err != nil {
		return fmt.Errorf("failed to write token to file: %w", err)
	}
	return nil
}
// CredentialFileName builds the file name under which Kilo credentials for
// email are persisted: "kilo-<email>.json". The email is inserted verbatim.
func CredentialFileName(email string) string {
	const pattern = "kilo-%s.json"
	name := fmt.Sprintf(pattern, email)
	return name
}

View File

@@ -22,6 +22,7 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewKimiAuthenticator(),
sdkAuth.NewKiroAuthenticator(),
sdkAuth.NewGitHubCopilotAuthenticator(),
sdkAuth.NewKiloAuthenticator(),
)
return manager
}

View File

@@ -0,0 +1,54 @@
package cmd
import (
"context"
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
)
// DoKiloLogin runs the Kilo login through the shared authentication manager
// and reports the outcome on standard output.
//
// When options is nil, default (zero) options are used. When no Prompt
// callback is supplied, a fallback is installed that prints the prompt and
// reads one whitespace-trimmed token from standard input.
//
// Parameters:
//   - cfg: The application configuration
//   - options: Login options including browser behavior and prompts
func DoKiloLogin(cfg *config.Config, options *LoginOptions) {
	if options == nil {
		options = &LoginOptions{}
	}

	ask := options.Prompt
	if ask == nil {
		// Fallback prompt: read a single value from stdin, ignoring scan errors
		// so that an empty answer yields an empty string.
		ask = func(message string) (string, error) {
			fmt.Print(message)
			var answer string
			fmt.Scanln(&answer)
			return strings.TrimSpace(answer), nil
		}
	}

	loginOpts := &sdkAuth.LoginOptions{
		NoBrowser:    options.NoBrowser,
		CallbackPort: options.CallbackPort,
		Metadata:     map[string]string{},
		Prompt:       ask,
	}

	_, location, errLogin := newAuthManager().Login(context.Background(), "kilo", cfg, loginOpts)
	if errLogin != nil {
		fmt.Printf("Kilo authentication failed: %v\n", errLogin)
		return
	}

	if location != "" {
		fmt.Printf("Authentication saved to %s\n", location)
	}
	fmt.Println("Kilo authentication successful!")
}

View File

@@ -55,6 +55,34 @@ func StartService(cfg *config.Config, configPath string, localPassword string) {
}
}
// StartServiceBackground builds the proxy service and runs it in a background
// goroutine.
//
// It returns a cancel function that cancels the context passed to the
// service's Run, and a done channel that is closed once Run has returned.
// If the service cannot be built, the error is logged and the done channel is
// returned already closed.
func StartServiceBackground(cfg *config.Config, configPath string, localPassword string) (cancel func(), done <-chan struct{}) {
	runCtx, stop := context.WithCancel(context.Background())
	finished := make(chan struct{})

	service, errBuild := cliproxy.NewBuilder().
		WithConfig(cfg).
		WithConfigPath(configPath).
		WithLocalManagementPassword(localPassword).
		Build()
	if errBuild != nil {
		log.Errorf("failed to build proxy service: %v", errBuild)
		close(finished)
		return stop, finished
	}

	go func() {
		defer close(finished)
		errRun := service.Run(runCtx)
		// A nil error or a plain cancellation is a normal shutdown.
		if errRun == nil || errors.Is(errRun, context.Canceled) {
			return
		}
		log.Errorf("proxy service exited with error: %v", errRun)
	}()

	return stop, finished
}
// WaitForCloudDeploy waits indefinitely for shutdown signals in cloud deploy mode
// when no configuration file is available.
func WaitForCloudDeploy() {

View File

@@ -97,6 +97,10 @@ type Config struct {
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
// ClaudeHeaderDefaults configures default header values for Claude API requests.
// These are used as fallbacks when the client does not send its own headers.
ClaudeHeaderDefaults ClaudeHeaderDefaults `yaml:"claude-header-defaults" json:"claude-header-defaults"`
// OpenAICompatibility defines OpenAI API compatibility configurations for external providers.
OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"`
@@ -130,6 +134,15 @@ type Config struct {
legacyMigrationPending bool `yaml:"-" json:"-"`
}
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
// when the client does not send them. Update these when Claude Code releases a new version.
// All fields are plain strings read from the YAML/JSON keys given in their tags.
type ClaudeHeaderDefaults struct {
// UserAgent is configured under "user-agent".
UserAgent string `yaml:"user-agent" json:"user-agent"`
// PackageVersion is configured under "package-version".
// NOTE(review): the header this maps to is not visible here — confirm in the Claude executor.
PackageVersion string `yaml:"package-version" json:"package-version"`
// RuntimeVersion is configured under "runtime-version".
// NOTE(review): the header this maps to is not visible here — confirm in the Claude executor.
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
// Timeout is configured under "timeout" and kept as a string.
// NOTE(review): unit/format is not visible here — confirm in the Claude executor.
Timeout string `yaml:"timeout" json:"timeout"`
}
// TLSConfig holds HTTPS server settings.
type TLSConfig struct {
// Enable toggles HTTPS server mode.
@@ -301,6 +314,10 @@ type CloakConfig struct {
// SensitiveWords is a list of words to obfuscate with zero-width characters.
// This can help bypass certain content filters.
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
// CacheUserID controls whether Claude user_id values are cached per API key.
// When false, a fresh random user_id is generated for every request.
CacheUserID *bool `yaml:"cache-user-id,omitempty" json:"cache-user-id,omitempty"`
}
// ClaudeKey represents the configuration for a Claude API key,
@@ -368,6 +385,9 @@ type CodexKey struct {
// If empty, the default Codex API URL will be used.
BaseURL string `yaml:"base-url" json:"base-url"`
// Websockets enables the Responses API websocket transport for this credential.
Websockets bool `yaml:"websockets,omitempty" json:"websockets,omitempty"`
// ProxyURL overrides the global proxy setting for this API key if provided.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
@@ -743,22 +763,24 @@ func (cfg *Config) SanitizeOAuthModelAlias() {
return
}
// Inject default Kiro aliases if no user-configured kiro aliases exist
// Inject channel defaults when the channel is absent in user config.
// Presence is checked case-insensitively and includes explicit nil/empty markers.
if cfg.OAuthModelAlias == nil {
cfg.OAuthModelAlias = make(map[string][]OAuthModelAlias)
}
if _, hasKiro := cfg.OAuthModelAlias["kiro"]; !hasKiro {
// Check case-insensitive too
found := false
hasChannel := func(channel string) bool {
for k := range cfg.OAuthModelAlias {
if strings.EqualFold(strings.TrimSpace(k), "kiro") {
found = true
break
if strings.EqualFold(strings.TrimSpace(k), channel) {
return true
}
}
if !found {
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
}
return false
}
if !hasChannel("kiro") {
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
}
if !hasChannel("github-copilot") {
cfg.OAuthModelAlias["github-copilot"] = defaultGitHubCopilotAliases()
}
if len(cfg.OAuthModelAlias) == 0 {

View File

@@ -42,6 +42,21 @@ func defaultKiroAliases() []OAuthModelAlias {
}
}
// defaultGitHubCopilotAliases returns the default oauth-model-alias entries for
// the github-copilot channel. Each entry maps a dotted Claude model name
// (e.g. "claude-opus-4.6") to its hyphen-style alias (e.g. "claude-opus-4-6")
// with Fork enabled, which keeps compatibility with clients (e.g. Claude Code)
// that use Anthropic-style model IDs.
func defaultGitHubCopilotAliases() []OAuthModelAlias {
	// Each pair is {model name, alias}.
	pairs := [...][2]string{
		{"claude-haiku-4.5", "claude-haiku-4-5"},
		{"claude-opus-4.1", "claude-opus-4-1"},
		{"claude-opus-4.5", "claude-opus-4-5"},
		{"claude-opus-4.6", "claude-opus-4-6"},
		{"claude-sonnet-4.5", "claude-sonnet-4-5"},
		{"claude-sonnet-4.6", "claude-sonnet-4-6"},
	}

	aliases := make([]OAuthModelAlias, 0, len(pairs))
	for _, pair := range pairs {
		aliases = append(aliases, OAuthModelAlias{Name: pair[0], Alias: pair[1], Fork: true})
	}
	return aliases
}
// defaultAntigravityAliases returns the default oauth-model-alias configuration
// for the antigravity channel when neither field exists.
func defaultAntigravityAliases() []OAuthModelAlias {

View File

@@ -107,6 +107,44 @@ func TestSanitizeOAuthModelAlias_InjectsDefaultKiroAliases(t *testing.T) {
}
}
// TestSanitizeOAuthModelAlias_InjectsDefaultGitHubCopilotAliases verifies that
// sanitizing a config without a github-copilot channel injects the default
// hyphen-style Claude aliases, each with Fork enabled.
func TestSanitizeOAuthModelAlias_InjectsDefaultGitHubCopilotAliases(t *testing.T) {
	// Only an unrelated channel is configured.
	cfg := &Config{
		OAuthModelAlias: map[string][]OAuthModelAlias{
			"codex": {{Name: "gpt-5", Alias: "g5"}},
		},
	}

	cfg.SanitizeOAuthModelAlias()

	injected := cfg.OAuthModelAlias["github-copilot"]
	if len(injected) == 0 {
		t.Fatal("expected default github-copilot aliases to be injected")
	}

	seen := make(map[string]bool, len(injected))
	for _, entry := range injected {
		seen[entry.Alias] = true
		if !entry.Fork {
			t.Fatalf("expected all default github-copilot aliases to have fork=true, got fork=false for %q", entry.Alias)
		}
	}

	for _, want := range []string{
		"claude-haiku-4-5",
		"claude-opus-4-1",
		"claude-opus-4-5",
		"claude-opus-4-6",
		"claude-sonnet-4-5",
		"claude-sonnet-4-6",
	} {
		if !seen[want] {
			t.Fatalf("expected default github-copilot alias %q to be present", want)
		}
	}
}
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
// When user has configured kiro aliases, defaults should NOT be injected
cfg := &Config{
@@ -128,6 +166,26 @@ func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
}
}
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserGitHubCopilotAliases(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"github-copilot": {
{Name: "claude-opus-4.6", Alias: "my-opus", Fork: true},
},
},
}
cfg.SanitizeOAuthModelAlias()
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
if len(copilotAliases) != 1 {
t.Fatalf("expected 1 user-configured github-copilot alias, got %d", len(copilotAliases))
}
if copilotAliases[0].Alias != "my-opus" {
t.Fatalf("expected user alias to be preserved, got %q", copilotAliases[0].Alias)
}
}
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) {
// When user explicitly deletes kiro aliases (key exists with nil value),
// defaults should NOT be re-injected on subsequent sanitize calls (#222).
@@ -154,6 +212,24 @@ func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing
}
}
func TestSanitizeOAuthModelAlias_GitHubCopilotDoesNotReinjectAfterExplicitDeletion(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"github-copilot": nil, // explicitly deleted
},
}
cfg.SanitizeOAuthModelAlias()
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
if len(copilotAliases) != 0 {
t.Fatalf("expected github-copilot aliases to remain empty after explicit deletion, got %d aliases", len(copilotAliases))
}
if _, exists := cfg.OAuthModelAlias["github-copilot"]; !exists {
t.Fatal("expected github-copilot key to be preserved as nil marker after sanitization")
}
}
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) {
// Same as above but with empty slice instead of nil (PUT with empty body).
cfg := &Config{

View File

@@ -20,6 +20,10 @@ type SDKConfig struct {
// APIKeys is a list of keys for authenticating clients to this proxy server.
APIKeys []string `yaml:"api-keys" json:"api-keys"`
// PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients.
// Default is false (disabled).
PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"`
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`

View File

@@ -27,4 +27,7 @@ const (
// Kiro represents the AWS CodeWhisperer (Kiro) provider identifier.
Kiro = "kiro"
// Kilo represents the Kilo AI provider identifier.
Kilo = "kilo"
)

View File

@@ -0,0 +1,21 @@
// Package registry provides model definitions for various AI service providers.
package registry
// GetKiloModels returns the static model definitions for the Kilo provider.
// The list currently holds a single entry, "kilo/auto". A fresh slice and
// ModelInfo are allocated on every call.
func GetKiloModels() []*ModelInfo {
	// --- Base Models ---
	auto := &ModelInfo{
		ID:                  "kilo/auto",
		Object:              "model",
		Created:             1732752000,
		OwnedBy:             "kilo",
		Type:                "kilo",
		DisplayName:         "Kilo Auto",
		Description:         "Automatic model selection by Kilo",
		ContextLength:       200000,
		MaxCompletionTokens: 64000,
		Thinking: &ThinkingSupport{
			Min:            1024,
			Max:            32000,
			ZeroAllowed:    true,
			DynamicAllowed: true,
		},
	}
	return []*ModelInfo{auto}
}

View File

@@ -21,6 +21,7 @@ import (
// - iflow
// - kimi
// - kiro
// - kilo
// - github-copilot
// - kiro
// - amazonq
@@ -50,6 +51,8 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
return GetGitHubCopilotModels()
case "kiro":
return GetKiroModels()
case "kilo":
return GetKiloModels()
case "amazonq":
return GetAmazonQModels()
case "antigravity":
@@ -99,6 +102,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
GetKimiModels(),
GetGitHubCopilotModels(),
GetKiroModels(),
GetKiloModels(),
GetAmazonQModels(),
}
for _, models := range allModels {
@@ -125,7 +129,19 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
func GetGitHubCopilotModels() []*ModelInfo {
now := int64(1732752000) // 2024-11-27
return []*ModelInfo{
gpt4oEntries := []struct {
ID string
DisplayName string
Description string
}{
{ID: "gpt-4o-2024-11-20", DisplayName: "GPT-4o (2024-11-20)", Description: "OpenAI GPT-4o 2024-11-20 via GitHub Copilot"},
{ID: "gpt-4o-2024-08-06", DisplayName: "GPT-4o (2024-08-06)", Description: "OpenAI GPT-4o 2024-08-06 via GitHub Copilot"},
{ID: "gpt-4o-2024-05-13", DisplayName: "GPT-4o (2024-05-13)", Description: "OpenAI GPT-4o 2024-05-13 via GitHub Copilot"},
{ID: "gpt-4o", DisplayName: "GPT-4o", Description: "OpenAI GPT-4o via GitHub Copilot"},
{ID: "gpt-4-o-preview", DisplayName: "GPT-4-o Preview", Description: "OpenAI GPT-4-o Preview via GitHub Copilot"},
}
models := []*ModelInfo{
{
ID: "gpt-4.1",
Object: "model",
@@ -137,6 +153,23 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 128000,
MaxCompletionTokens: 16384,
},
}
for _, entry := range gpt4oEntries {
models = append(models, &ModelInfo{
ID: entry.ID,
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: entry.DisplayName,
Description: entry.Description,
ContextLength: 128000,
MaxCompletionTokens: 16384,
})
}
return append(models, []*ModelInfo{
{
ID: "gpt-5",
Object: "model",
@@ -254,6 +287,19 @@ func GetGitHubCopilotModels() []*ModelInfo {
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.3-codex",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.3 Codex",
Description: "OpenAI GPT-5.3 Codex via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "claude-haiku-4.5",
Object: "model",
@@ -326,6 +372,18 @@ func GetGitHubCopilotModels() []*ModelInfo {
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-sonnet-4.6",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Claude Sonnet 4.6",
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "gemini-2.5-pro",
Object: "model",
@@ -348,6 +406,17 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 1048576,
MaxCompletionTokens: 65536,
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Gemini 3.1 Pro (Preview)",
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -382,7 +451,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
}
}...)
}
// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
@@ -413,6 +482,18 @@ func GetKiroModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-sonnet-4-6",
Object: "model",
Created: 1739836800, // 2025-02-18
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4.6",
Description: "Claude Sonnet 4.6 via Kiro (1.3x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-opus-4-5",
Object: "model",
@@ -555,6 +636,18 @@ func GetKiroModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-sonnet-4-6-agentic",
Object: "model",
Created: 1739836800, // 2025-02-18
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4.6 (Agentic)",
Description: "Claude Sonnet 4.6 optimized for coding agents (chunked writes)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-opus-4-5-agentic",
Object: "model",

View File

@@ -28,6 +28,17 @@ func GetClaudeModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-sonnet-4-6",
Object: "model",
Created: 1771372800, // 2026-02-17
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.6 Sonnet",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-opus-4-6",
Object: "model",
@@ -40,6 +51,18 @@ func GetClaudeModels() []*ModelInfo {
MaxCompletionTokens: 128000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-sonnet-4-6",
Object: "model",
Created: 1771286400, // 2026-02-17
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.6 Sonnet",
Description: "Best combination of speed and intelligence",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-opus-4-5-20251101",
Object: "model",
@@ -173,6 +196,21 @@ func GetGeminiModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -283,6 +321,21 @@ func GetGeminiVertexModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-pro-image-preview",
Object: "model",
@@ -425,6 +478,21 @@ func GetGeminiCLIModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -506,6 +574,21 @@ func GetAIStudioModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: 1771459200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-pro-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Pro Preview",
Description: "Gemini 3.1 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
@@ -892,11 +975,14 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"gpt-oss-120b-medium": {},
"tab_flash_lite_preview": {},
}

View File

@@ -164,12 +164,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, &param)
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming request to the AI Studio API.
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -254,7 +254,6 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(first wsrelay.StreamEvent) {
defer close(out)
var param any
@@ -318,7 +317,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
}
}(firstEvent)
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: firstEvent.Headers.Clone(), Chunks: out}, nil
}
// CountTokens counts tokens for the given request using the AI Studio API.

View File

@@ -232,7 +232,7 @@ attemptLoop:
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
}
@@ -436,7 +436,7 @@ attemptLoop:
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
@@ -645,7 +645,7 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
}
// ExecuteStream performs a streaming request to the Antigravity API.
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -775,7 +775,6 @@ attemptLoop:
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(resp *http.Response) {
defer close(out)
defer func() {
@@ -820,7 +819,7 @@ attemptLoop:
reporter.ensurePublished(ctx)
}
}(httpResp)
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
switch {
@@ -968,7 +967,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
}
lastStatus = httpResp.StatusCode

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net/http"
"runtime"
"strings"
"time"
@@ -116,7 +117,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
@@ -134,7 +135,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -143,7 +144,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
if err != nil {
return resp, err
}
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -208,7 +209,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} else {
reporter.publish(ctx, parseClaudeUsage(data))
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
}
var param any
@@ -222,11 +223,11 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
data,
&param,
)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -257,7 +258,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
@@ -275,7 +276,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -284,7 +285,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if err != nil {
return nil, err
}
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -329,7 +330,6 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -348,7 +348,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
// Forward the line as-is to preserve SSE format
@@ -375,7 +375,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
chunks := sdktranslator.TranslateStream(
@@ -398,7 +398,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
@@ -423,7 +423,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
if isClaudeOAuthToken(apiKey) {
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
body = applyClaudeToolPrefix(body, claudeToolPrefix)
}
@@ -432,7 +432,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
if err != nil {
return cliproxyexecutor.Response{}, err
}
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
@@ -487,7 +487,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "input_tokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
}
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
@@ -638,7 +638,49 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
return body, nil
}
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
// mapStainlessOS converts runtime.GOOS into the operating-system label sent
// in the X-Stainless-Os request header. Platforms without a dedicated label
// are reported as "Other::" followed by the raw GOOS value.
func mapStainlessOS() string {
	var label string
	switch goos := runtime.GOOS; goos {
	case "linux":
		label = "Linux"
	case "darwin":
		label = "MacOS"
	case "windows":
		label = "Windows"
	case "freebsd":
		label = "FreeBSD"
	default:
		// Unknown platforms keep their GOOS name behind a fixed marker.
		label = "Other::" + goos
	}
	return label
}
// mapStainlessArch converts runtime.GOARCH into the architecture label sent
// in the X-Stainless-Arch request header. Architectures without a dedicated
// label are reported as "other::" followed by the raw GOARCH value.
func mapStainlessArch() string {
	var label string
	switch goarch := runtime.GOARCH; goarch {
	case "arm64":
		label = "arm64"
	case "amd64":
		label = "x64"
	case "386":
		label = "x86"
	default:
		// Unknown architectures keep their GOARCH name behind a fixed marker.
		label = "other::" + goarch
	}
	return label
}
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
hdrDefault := func(cfgVal, fallback string) string {
if cfgVal != "" {
return cfgVal
}
return fallback
}
var hd config.ClaudeHeaderDefaults
if cfg != nil {
hd = cfg.ClaudeHeaderDefaults
}
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
if isAnthropicBase && useAPIKey {
@@ -685,16 +727,17 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
// Values below match Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (captured 2026-02-17).
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", "v24.3.0")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", "0.55.1")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", "arm64")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", "MacOS")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", "60")
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "claude-cli/1.0.83 (external, cli)")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.44 (external, sdk-cli)"))
r.Header.Set("Connection", "keep-alive")
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
if stream {
@@ -702,6 +745,8 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
} else {
r.Header.Set("Accept", "application/json")
}
// Keep OS/Arch mapping dynamic (not configurable).
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
@@ -753,11 +798,21 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
return body
}
// Collect built-in tool names (those with a non-empty "type" field) so we can
// skip them consistently in both tools and message history.
builtinTools := map[string]bool{}
for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} {
builtinTools[name] = true
}
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
tools.ForEach(func(index, tool gjson.Result) bool {
// Skip built-in tools (web_search, code_execution, etc.) which have
// a "type" field and require their name to remain unchanged.
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
if n := tool.Get("name").String(); n != "" {
builtinTools[n] = true
}
return true
}
name := tool.Get("name").String()
@@ -772,7 +827,7 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
name := gjson.GetBytes(body, "tool_choice.name").String()
if name != "" && !strings.HasPrefix(name, prefix) {
if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] {
body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
}
}
@@ -784,15 +839,38 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
return true
}
content.ForEach(func(contentIndex, part gjson.Result) bool {
if part.Get("type").String() != "tool_use" {
return true
partType := part.Get("type").String()
switch partType {
case "tool_use":
name := part.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) || builtinTools[name] {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+name)
case "tool_reference":
toolName := part.Get("tool_name").String()
if toolName == "" || strings.HasPrefix(toolName, prefix) || builtinTools[toolName] {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+toolName)
case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[]
nestedContent := part.Get("content")
if nestedContent.Exists() && nestedContent.IsArray() {
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
if nestedPart.Get("type").String() == "tool_reference" {
nestedToolName := nestedPart.Get("tool_name").String()
if nestedToolName != "" && !strings.HasPrefix(nestedToolName, prefix) && !builtinTools[nestedToolName] {
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, prefix+nestedToolName)
}
}
return true
})
}
}
name := part.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, prefix+name)
return true
})
return true
@@ -811,15 +889,38 @@ func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
return body
}
content.ForEach(func(index, part gjson.Result) bool {
if part.Get("type").String() != "tool_use" {
return true
partType := part.Get("type").String()
switch partType {
case "tool_use":
name := part.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
case "tool_reference":
toolName := part.Get("tool_name").String()
if !strings.HasPrefix(toolName, prefix) {
return true
}
path := fmt.Sprintf("content.%d.tool_name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(toolName, prefix))
case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[]
nestedContent := part.Get("content")
if nestedContent.Exists() && nestedContent.IsArray() {
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
if nestedPart.Get("type").String() == "tool_reference" {
nestedToolName := nestedPart.Get("tool_name").String()
if strings.HasPrefix(nestedToolName, prefix) {
nestedPath := fmt.Sprintf("content.%d.content.%d.tool_name", index.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, strings.TrimPrefix(nestedToolName, prefix))
}
}
return true
})
}
}
name := part.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return true
}
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(name, prefix))
return true
})
return body
@@ -834,15 +935,34 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
return line
}
contentBlock := gjson.GetBytes(payload, "content_block")
if !contentBlock.Exists() || contentBlock.Get("type").String() != "tool_use" {
if !contentBlock.Exists() {
return line
}
name := contentBlock.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return line
}
updated, err := sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
if err != nil {
blockType := contentBlock.Get("type").String()
var updated []byte
var err error
switch blockType {
case "tool_use":
name := contentBlock.Get("name").String()
if !strings.HasPrefix(name, prefix) {
return line
}
updated, err = sjson.SetBytes(payload, "content_block.name", strings.TrimPrefix(name, prefix))
if err != nil {
return line
}
case "tool_reference":
toolName := contentBlock.Get("tool_name").String()
if !strings.HasPrefix(toolName, prefix) {
return line
}
updated, err = sjson.SetBytes(payload, "content_block.tool_name", strings.TrimPrefix(toolName, prefix))
if err != nil {
return line
}
default:
return line
}
@@ -862,10 +982,10 @@ func getClientUserAgent(ctx context.Context) string {
}
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
// Returns (cloakMode, strictMode, sensitiveWords).
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
// Returns (cloakMode, strictMode, sensitiveWords, cacheUserID).
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bool) {
if auth == nil || auth.Attributes == nil {
return "auto", false, nil
return "auto", false, nil, false
}
cloakMode := auth.Attributes["cloak_mode"]
@@ -883,7 +1003,9 @@ func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
}
}
return cloakMode, strictMode, sensitiveWords
cacheUserID := strings.EqualFold(strings.TrimSpace(auth.Attributes["cloak_cache_user_id"]), "true")
return cloakMode, strictMode, sensitiveWords, cacheUserID
}
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
@@ -916,16 +1038,24 @@ func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *c
}
// injectFakeUserID generates and injects a fake user ID into the request metadata.
func injectFakeUserID(payload []byte) []byte {
// When useCache is false, a new user ID is generated for every call.
func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
generateID := func() string {
if useCache {
return cachedUserID(apiKey)
}
return generateFakeUserID()
}
metadata := gjson.GetBytes(payload, "metadata")
if !metadata.Exists() {
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
return payload
}
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
if existingUserID == "" || !isValidUserID(existingUserID) {
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
}
return payload
}
@@ -962,7 +1092,7 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
// applyCloaking applies cloaking transformations to the payload based on config and client.
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string) []byte {
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string, apiKey string) []byte {
clientUserAgent := getClientUserAgent(ctx)
// Get cloak config from ClaudeKey configuration
@@ -972,16 +1102,20 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
var cloakMode string
var strictMode bool
var sensitiveWords []string
var cacheUserID bool
if cloakCfg != nil {
cloakMode = cloakCfg.Mode
strictMode = cloakCfg.StrictMode
sensitiveWords = cloakCfg.SensitiveWords
if cloakCfg.CacheUserID != nil {
cacheUserID = *cloakCfg.CacheUserID
}
}
// Fallback to auth attributes if no config found
if cloakMode == "" {
attrMode, attrStrict, attrWords := getCloakConfigFromAuth(auth)
attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
cloakMode = attrMode
if !strictMode {
strictMode = attrStrict
@@ -989,6 +1123,12 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
if len(sensitiveWords) == 0 {
sensitiveWords = attrWords
}
if cloakCfg == nil || cloakCfg.CacheUserID == nil {
cacheUserID = attrCache
}
} else if cloakCfg == nil || cloakCfg.CacheUserID == nil {
_, _, _, attrCache := getCloakConfigFromAuth(auth)
cacheUserID = attrCache
}
// Determine if cloaking should be applied
@@ -1002,7 +1142,7 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
}
// Inject fake user ID
payload = injectFakeUserID(payload)
payload = injectFakeUserID(payload, apiKey, cacheUserID)
// Apply sensitive word obfuscation
if len(sensitiveWords) > 0 {

View File

@@ -2,9 +2,18 @@ package executor
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
func TestApplyClaudeToolPrefix(t *testing.T) {
@@ -25,6 +34,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
}
}
// TestApplyClaudeToolPrefix_WithToolReference checks that top-level
// tool_reference blocks in message content get the prefix added to tool_name,
// and that a tool_name which already carries the prefix is left unchanged.
func TestApplyClaudeToolPrefix_WithToolReference(t *testing.T) {
	request := []byte(`{"tools":[{"name":"alpha"}],"messages":[{"role":"user","content":[{"type":"tool_reference","tool_name":"beta"},{"type":"tool_reference","tool_name":"proxy_gamma"}]}]}`)
	prefixed := applyClaudeToolPrefix(request, "proxy_")

	// Checked in order; the first mismatch aborts the test.
	checks := []struct {
		path string
		want string
	}{
		{"messages.0.content.0.tool_name", "proxy_beta"},
		{"messages.0.content.1.tool_name", "proxy_gamma"},
	}
	for _, c := range checks {
		if got := gjson.GetBytes(prefixed, c.path).String(); got != c.want {
			t.Fatalf("%s = %q, want %q", c.path, got, c.want)
		}
	}
}
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
@@ -37,6 +58,97 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
}
}
// TestApplyClaudeToolPrefix_BuiltinToolSkipped checks that a typed (built-in)
// tool keeps its name both in the tools list and in tool_use history, while
// an untyped custom tool is prefixed in both places.
func TestApplyClaudeToolPrefix_BuiltinToolSkipped(t *testing.T) {
	request := []byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search", "max_uses": 5},
{"name": "Read"}
],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}},
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}}
]}
]
}`)
	prefixed := applyClaudeToolPrefix(request, "proxy_")

	// Checked in order; the first mismatch aborts the test.
	checks := []struct {
		path string
		want string
	}{
		{"tools.0.name", "web_search"},
		{"messages.0.content.0.name", "web_search"},
		{"tools.1.name", "proxy_Read"},
		{"messages.0.content.1.name", "proxy_Read"},
	}
	for _, c := range checks {
		if got := gjson.GetBytes(prefixed, c.path).String(); got != c.want {
			t.Fatalf("%s = %q, want %q", c.path, got, c.want)
		}
	}
}
// TestApplyClaudeToolPrefix_KnownBuiltinInHistoryOnly checks that a
// "web_search" tool_use in message history stays unprefixed even though the
// tools list does not declare it, while the declared custom tool is prefixed.
func TestApplyClaudeToolPrefix_KnownBuiltinInHistoryOnly(t *testing.T) {
	request := []byte(`{
"tools": [
{"name": "Read"}
],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}}
]}
]
}`)
	prefixed := applyClaudeToolPrefix(request, "proxy_")

	// Checked in order; the first mismatch aborts the test.
	checks := []struct {
		path string
		want string
	}{
		{"messages.0.content.0.name", "web_search"},
		{"tools.0.name", "proxy_Read"},
	}
	for _, c := range checks {
		if got := gjson.GetBytes(prefixed, c.path).String(); got != c.want {
			t.Fatalf("%s = %q, want %q", c.path, got, c.want)
		}
	}
}
// TestApplyClaudeToolPrefix_CustomToolsPrefixed checks that untyped custom
// tools are prefixed in the tools list and in matching tool_use history
// entries.
func TestApplyClaudeToolPrefix_CustomToolsPrefixed(t *testing.T) {
	request := []byte(`{
"tools": [{"name": "Read"}, {"name": "Write"}],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}},
{"type": "tool_use", "name": "Write", "id": "w1", "input": {}}
]}
]
}`)
	prefixed := applyClaudeToolPrefix(request, "proxy_")

	// Checked in order; the first mismatch aborts the test.
	checks := []struct {
		path string
		want string
	}{
		{"tools.0.name", "proxy_Read"},
		{"tools.1.name", "proxy_Write"},
		{"messages.0.content.0.name", "proxy_Read"},
		{"messages.0.content.1.name", "proxy_Write"},
	}
	for _, c := range checks {
		if got := gjson.GetBytes(prefixed, c.path).String(); got != c.want {
			t.Fatalf("%s = %q, want %q", c.path, got, c.want)
		}
	}
}
// TestApplyClaudeToolPrefix_ToolChoiceBuiltin checks that tool_choice.name is
// left unprefixed when it names a typed (built-in) tool from the tools list.
func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
	const (
		path = "tool_choice.name"
		want = "web_search"
	)
	request := []byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search"},
{"name": "Read"}
],
"tool_choice": {"type": "tool", "name": "web_search"}
}`)

	prefixed := applyClaudeToolPrefix(request, "proxy_")
	got := gjson.GetBytes(prefixed, path).String()
	if got != want {
		t.Fatalf("tool_choice.name = %q, want %q", got, want)
	}
}
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
@@ -49,6 +161,18 @@ func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
}
}
// TestStripClaudeToolPrefixFromResponse_WithToolReference checks that the
// prefix is removed from tool_name on top-level tool_reference blocks in a
// response, and that an unprefixed tool_name is returned unchanged.
func TestStripClaudeToolPrefixFromResponse_WithToolReference(t *testing.T) {
	response := []byte(`{"content":[{"type":"tool_reference","tool_name":"proxy_alpha"},{"type":"tool_reference","tool_name":"bravo"}]}`)
	stripped := stripClaudeToolPrefixFromResponse(response, "proxy_")

	// Checked in order; the first mismatch aborts the test.
	checks := []struct {
		path string
		want string
	}{
		{"content.0.tool_name", "alpha"},
		{"content.1.tool_name", "bravo"},
	}
	for _, c := range checks {
		if got := gjson.GetBytes(stripped, c.path).String(); got != c.want {
			t.Fatalf("%s = %q, want %q", c.path, got, c.want)
		}
	}
}
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
@@ -61,3 +185,166 @@ func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
t.Fatalf("content_block.name = %q, want %q", got, "alpha")
}
}
// TestStripClaudeToolPrefixFromStreamLine_WithToolReference checks that an
// SSE content_block_start line carrying a tool_reference block has the prefix
// removed from content_block.tool_name.
func TestStripClaudeToolPrefixFromStreamLine_WithToolReference(t *testing.T) {
	const want = "beta"
	sseLine := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_reference","tool_name":"proxy_beta"},"index":0}`)

	stripped := stripClaudeToolPrefixFromStreamLine(sseLine, "proxy_")

	// Drop the optional "data:" SSE field name so only the JSON payload remains.
	jsonPayload := bytes.TrimSpace(bytes.TrimPrefix(bytes.TrimSpace(stripped), []byte("data:")))

	got := gjson.GetBytes(jsonPayload, "content_block.tool_name").String()
	if got != want {
		t.Fatalf("content_block.tool_name = %q, want %q", got, want)
	}
}
// TestApplyClaudeToolPrefix_NestedToolReference checks that a tool_reference
// block nested inside a tool_result's content array gets the prefix added to
// its tool_name.
func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
	const want = "proxy_mcp__nia__manage_resource"
	request := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"mcp__nia__manage_resource"}]}]}]}`)

	prefixed := applyClaudeToolPrefix(request, "proxy_")

	if got := gjson.GetBytes(prefixed, "messages.0.content.0.content.0.tool_name").String(); got != want {
		t.Fatalf("nested tool_reference tool_name = %q, want %q", got, want)
	}
}
// TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled sends one
// request per model through the executor with CacheUserID switched on in the
// key's cloak config, and checks that the fake upstream received the same,
// valid metadata.user_id on both requests.
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
	resetUserIDCache()

	// Every request reaching the fake upstream is recorded in arrival order.
	type upstreamHit struct {
		model  string
		userID string
	}
	var hits []upstreamHit

	handler := func(w http.ResponseWriter, r *http.Request) {
		raw, _ := io.ReadAll(r.Body)
		hit := upstreamHit{
			model:  gjson.GetBytes(raw, "model").String(),
			userID: gjson.GetBytes(raw, "metadata.user_id").String(),
		}
		hits = append(hits, hit)
		t.Logf("HTTP Server received request: model=%s, user_id=%s, url=%s", hit.model, hit.userID, r.URL.String())
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()
	t.Logf("End-to-end test: Fake HTTP server started at %s", server.URL)

	useCache := true
	cfg := &config.Config{
		ClaudeKey: []config.ClaudeKey{
			{
				APIKey:  "key-123",
				BaseURL: server.URL,
				Cloak:   &config.CloakConfig{CacheUserID: &useCache},
			},
		},
	}
	executor := NewClaudeExecutor(cfg)
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"api_key":  "key-123",
		"base_url": server.URL,
	}}
	basePayload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)

	for _, model := range []string{"claude-3-5-sonnet", "claude-3-5-haiku"} {
		t.Logf("Sending request for model: %s", model)
		withModel, _ := sjson.SetBytes(basePayload, "model", model)
		request := cliproxyexecutor.Request{Model: model, Payload: withModel}
		options := cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}
		if _, err := executor.Execute(context.Background(), auth, request, options); err != nil {
			t.Fatalf("Execute(%s) error: %v", model, err)
		}
	}

	if len(hits) != 2 {
		t.Fatalf("expected 2 requests, got %d", len(hits))
	}
	first, second := hits[0], hits[1]
	if first.userID == "" || second.userID == "" {
		t.Fatal("expected user_id to be populated")
	}
	t.Logf("user_id[0] (model=%s): %s", first.model, first.userID)
	t.Logf("user_id[1] (model=%s): %s", second.model, second.userID)
	if first.userID != second.userID {
		t.Fatalf("expected user_id to be reused across models, got %q and %q", first.userID, second.userID)
	}
	if !isValidUserID(first.userID) {
		t.Fatalf("user_id %q is not valid", first.userID)
	}
	t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", first.userID)
}
// TestClaudeExecutor_GeneratesNewUserIDByDefault sends the same request twice
// through an executor built from an empty config and checks that the fake
// upstream received two different, valid metadata.user_id values.
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
	resetUserIDCache()

	// user_id values seen by the fake upstream, in arrival order.
	var seenIDs []string
	handler := func(w http.ResponseWriter, r *http.Request) {
		raw, _ := io.ReadAll(r.Body)
		seenIDs = append(seenIDs, gjson.GetBytes(raw, "metadata.user_id").String())
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	executor := NewClaudeExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"api_key":  "key-123",
		"base_url": server.URL,
	}}
	requestBody := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)

	for attempt := 0; attempt < 2; attempt++ {
		request := cliproxyexecutor.Request{
			Model:   "claude-3-5-sonnet",
			Payload: requestBody,
		}
		options := cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}
		if _, err := executor.Execute(context.Background(), auth, request, options); err != nil {
			t.Fatalf("Execute call %d error: %v", attempt, err)
		}
	}

	if len(seenIDs) != 2 {
		t.Fatalf("expected 2 requests, got %d", len(seenIDs))
	}
	first, second := seenIDs[0], seenIDs[1]
	if first == "" || second == "" {
		t.Fatal("expected user_id to be populated")
	}
	if first == second {
		t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", first)
	}
	if !(isValidUserID(first) && isValidUserID(second)) {
		t.Fatalf("user_ids should be valid, got %q and %q", first, second)
	}
}
// TestStripClaudeToolPrefixFromResponse_NestedToolReference checks that the
// "proxy_" prefix is removed from a tool_reference's tool_name when that
// reference is nested inside a tool_result content array of a response.
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
	const (
		prefix = "proxy_"
		want   = "mcp__nia__manage_resource"
	)
	response := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)

	stripped := stripClaudeToolPrefixFromResponse(response, prefix)

	if got := gjson.GetBytes(stripped, "content.0.content.0.tool_name").String(); got != want {
		t.Fatalf("nested tool_reference tool_name = %q, want %q", got, want)
	}
}
// TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent checks that a
// tool_result whose content is a plain string (rather than an array of blocks)
// comes back from applyClaudeToolPrefix with that string unchanged.
func TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent(t *testing.T) {
	const want = "plain string result"
	request := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"plain string result"}]}]}`)

	prefixed := applyClaudeToolPrefix(request, "proxy_")

	content := gjson.GetBytes(prefixed, "messages.0.content.0.content").String()
	if content == want {
		return
	}
	t.Fatalf("string content should remain unchanged = %q", content)
}
// TestApplyClaudeToolPrefix_SkipsBuiltinToolReference checks that a nested
// tool_reference naming "web_search" is left unprefixed when the request's
// tools array declares web_search with type "web_search_20250305".
func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
	const builtinName = "web_search"
	request := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"}],"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_reference","tool_name":"web_search"}]}]}]}`)

	prefixed := applyClaudeToolPrefix(request, "proxy_")

	if name := gjson.GetBytes(prefixed, "messages.0.content.0.content.0.tool_name").String(); name != builtinName {
		t.Fatalf("built-in tool_reference should not be prefixed, got %q", name)
	}
}

View File

@@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
@@ -273,11 +273,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
}
@@ -362,7 +362,6 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -397,7 +396,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
@@ -643,7 +642,6 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
}
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
misc.EnsureHeader(r.Header, ginHeaders, "Openai-Beta", "responses=experimental")
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)

File diff suppressed because it is too large Load Diff

View File

@@ -225,7 +225,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
reporter.publish(ctx, parseGeminiCLIUsage(data))
var param any
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -256,7 +256,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
}
// ExecuteStream performs a streaming request to the Gemini CLI API.
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -382,7 +382,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(resp *http.Response, reqBody []byte, attemptModel string) {
defer close(out)
defer func() {
@@ -441,7 +440,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}
}(httpResp, append([]byte(nil), payload...), attemptModel)
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
if len(lastBody) > 0 {
@@ -546,7 +545,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
}
lastStatus = resp.StatusCode
lastBody = append([]byte(nil), data...)

View File

@@ -205,12 +205,12 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming request to the Gemini API.
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -298,7 +298,6 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -335,7 +334,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens counts tokens for the given request using the Gemini API.
@@ -416,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
count := gjson.GetBytes(data, "totalTokens").Int()
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
}
// Refresh refreshes the authentication credentials (no-op for Gemini API key).

View File

@@ -253,7 +253,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
}
// ExecuteStream performs a streaming request to the Vertex AI API.
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
to := sdktranslator.FromString("gemini")
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -524,12 +524,12 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -618,7 +618,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -650,11 +649,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -743,7 +742,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -775,7 +773,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// countTokensWithServiceAccount counts tokens using service account credentials.
@@ -859,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
}
// countTokensWithAPIKey handles token counting using API key credentials.
@@ -943,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
}
// vertexCreds extracts project, location and raw service account JSON from auth metadata.

View File

@@ -39,7 +39,8 @@ const (
copilotEditorVersion = "vscode/1.107.0"
copilotPluginVersion = "copilot-chat/0.35.0"
copilotIntegrationID = "vscode-chat"
copilotOpenAIIntent = "conversation-edits"
copilotOpenAIIntent = "conversation-panel"
copilotGitHubAPIVer = "2025-04-01"
)
// GitHubCopilotExecutor handles requests to the GitHub Copilot API.
@@ -51,8 +52,9 @@ type GitHubCopilotExecutor struct {
// cachedAPIToken stores a cached Copilot API token with its expiry.
type cachedAPIToken struct {
token string
expiresAt time.Time
token string
apiEndpoint string
expiresAt time.Time
}
// NewGitHubCopilotExecutor constructs a new executor instance.
@@ -75,7 +77,7 @@ func (e *GitHubCopilotExecutor) PrepareRequest(req *http.Request, auth *cliproxy
if ctx == nil {
ctx = context.Background()
}
apiToken, errToken := e.ensureAPIToken(ctx, auth)
apiToken, _, errToken := e.ensureAPIToken(ctx, auth)
if errToken != nil {
return errToken
}
@@ -101,7 +103,7 @@ func (e *GitHubCopilotExecutor) HttpRequest(ctx context.Context, auth *cliproxya
// Execute handles non-streaming requests to GitHub Copilot.
func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
apiToken, errToken := e.ensureAPIToken(ctx, auth)
apiToken, baseURL, errToken := e.ensureAPIToken(ctx, auth)
if errToken != nil {
return resp, errToken
}
@@ -124,6 +126,9 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
body = e.normalizeModel(req.Model, body)
body = flattenAssistantContent(body)
// Detect vision content before input normalization removes messages
hasVision := detectVisionContent(body)
thinkingProvider := "openai"
if useResponses {
thinkingProvider = "codex"
@@ -147,7 +152,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
if useResponses {
path = githubCopilotResponsesPath
}
url := githubCopilotBaseURL + path
url := baseURL + path
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return resp, err
@@ -155,7 +160,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
e.applyHeaders(httpReq, apiToken, body)
// Add Copilot-Vision-Request header if the request contains vision content
if detectVisionContent(body) {
if hasVision {
httpReq.Header.Set("Copilot-Vision-Request", "true")
}
@@ -227,8 +232,8 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
}
// ExecuteStream handles streaming requests to GitHub Copilot.
func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
apiToken, errToken := e.ensureAPIToken(ctx, auth)
func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
apiToken, baseURL, errToken := e.ensureAPIToken(ctx, auth)
if errToken != nil {
return nil, errToken
}
@@ -251,6 +256,9 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
body = e.normalizeModel(req.Model, body)
body = flattenAssistantContent(body)
// Detect vision content before input normalization removes messages
hasVision := detectVisionContent(body)
thinkingProvider := "openai"
if useResponses {
thinkingProvider = "codex"
@@ -278,7 +286,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
if useResponses {
path = githubCopilotResponsesPath
}
url := githubCopilotBaseURL + path
url := baseURL + path
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return nil, err
@@ -286,7 +294,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
e.applyHeaders(httpReq, apiToken, body)
// Add Copilot-Vision-Request header if the request contains vision content
if detectVisionContent(body) {
if hasVision {
httpReq.Header.Set("Copilot-Vision-Request", "true")
}
@@ -333,7 +341,6 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
@@ -386,7 +393,10 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{
Headers: httpResp.Header.Clone(),
Chunks: out,
}, nil
}
// CountTokens is not supported for GitHub Copilot.
@@ -418,22 +428,22 @@ func (e *GitHubCopilotExecutor) Refresh(ctx context.Context, auth *cliproxyauth.
}
// ensureAPIToken gets or refreshes the Copilot API token.
func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, error) {
func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, string, error) {
if auth == nil {
return "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
return "", "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
}
// Get the GitHub access token
accessToken := metaStringValue(auth.Metadata, "access_token")
if accessToken == "" {
return "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"}
return "", "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"}
}
// Check for cached API token using thread-safe access
e.mu.RLock()
if cached, ok := e.cache[accessToken]; ok && cached.expiresAt.After(time.Now().Add(tokenExpiryBuffer)) {
e.mu.RUnlock()
return cached.token, nil
return cached.token, cached.apiEndpoint, nil
}
e.mu.RUnlock()
@@ -441,7 +451,13 @@ func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *clipro
copilotAuth := copilotauth.NewCopilotAuth(e.cfg)
apiToken, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken)
if err != nil {
return "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)}
return "", "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)}
}
// Use endpoint from token response, fall back to default
apiEndpoint := githubCopilotBaseURL
if apiToken.Endpoints.API != "" {
apiEndpoint = strings.TrimRight(apiToken.Endpoints.API, "/")
}
// Cache the token with thread-safe access
@@ -451,12 +467,13 @@ func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *clipro
}
e.mu.Lock()
e.cache[accessToken] = &cachedAPIToken{
token: apiToken.Token,
expiresAt: expiresAt,
token: apiToken.Token,
apiEndpoint: apiEndpoint,
expiresAt: expiresAt,
}
e.mu.Unlock()
return apiToken.Token, nil
return apiToken.Token, apiEndpoint, nil
}
// applyHeaders sets the required headers for GitHub Copilot API requests.
@@ -469,16 +486,17 @@ func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, b
r.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
r.Header.Set("Openai-Intent", copilotOpenAIIntent)
r.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
r.Header.Set("X-Github-Api-Version", copilotGitHubAPIVer)
r.Header.Set("X-Request-Id", uuid.NewString())
initiator := "user"
if len(body) > 0 {
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
arr := messages.Array()
if len(arr) > 0 {
lastRole := arr[len(arr)-1].Get("role").String()
if lastRole != "" && lastRole != "user" {
for _, msg := range messages.Array() {
role := msg.Get("role").String()
if role == "assistant" || role == "tool" {
initiator = "agent"
break
}
}
}
@@ -550,6 +568,17 @@ func flattenAssistantContent(body []byte) []byte {
if !content.Exists() || !content.IsArray() {
continue
}
// Skip flattening if the content contains non-text blocks (tool_use, thinking, etc.)
hasNonText := false
for _, part := range content.Array() {
if t := part.Get("type").String(); t != "" && t != "text" {
hasNonText = true
break
}
}
if hasNonText {
continue
}
var textParts []string
for _, part := range content.Array() {
if part.Get("type").String() == "text" {
@@ -597,31 +626,173 @@ func normalizeGitHubCopilotChatTools(body []byte) []byte {
func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
input := gjson.GetBytes(body, "input")
if input.Exists() {
if input.Type == gjson.String {
// If input is already a string or array, keep it as-is.
if input.Type == gjson.String || input.IsArray() {
return body
}
inputString := input.Raw
if input.Type != gjson.JSON {
inputString = input.String()
}
body, _ = sjson.SetBytes(body, "input", inputString)
// Non-string/non-array input: stringify as fallback.
body, _ = sjson.SetBytes(body, "input", input.Raw)
return body
}
var parts []string
// Convert Claude messages format to OpenAI Responses API input array.
// This preserves the conversation structure (roles, tool calls, tool results)
// which is critical for multi-turn tool-use conversations.
inputArr := "[]"
// System messages → developer role
if system := gjson.GetBytes(body, "system"); system.Exists() {
if text := strings.TrimSpace(collectTextFromNode(system)); text != "" {
parts = append(parts, text)
var systemParts []string
if system.IsArray() {
for _, part := range system.Array() {
if txt := part.Get("text").String(); txt != "" {
systemParts = append(systemParts, txt)
}
}
} else if system.Type == gjson.String {
systemParts = append(systemParts, system.String())
}
if len(systemParts) > 0 {
msg := `{"type":"message","role":"developer","content":[]}`
for _, txt := range systemParts {
part := `{"type":"input_text","text":""}`
part, _ = sjson.Set(part, "text", txt)
msg, _ = sjson.SetRaw(msg, "content.-1", part)
}
inputArr, _ = sjson.SetRaw(inputArr, "-1", msg)
}
}
// Messages → structured input items
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
for _, msg := range messages.Array() {
if text := strings.TrimSpace(collectTextFromNode(msg.Get("content"))); text != "" {
parts = append(parts, text)
role := msg.Get("role").String()
content := msg.Get("content")
if !content.Exists() {
continue
}
// Simple string content
if content.Type == gjson.String {
textType := "input_text"
if role == "assistant" {
textType = "output_text"
}
item := `{"type":"message","role":"","content":[]}`
item, _ = sjson.Set(item, "role", role)
part := fmt.Sprintf(`{"type":"%s","text":""}`, textType)
part, _ = sjson.Set(part, "text", content.String())
item, _ = sjson.SetRaw(item, "content.-1", part)
inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
continue
}
if !content.IsArray() {
continue
}
// Array content: split into message parts vs tool items
var msgParts []string
for _, c := range content.Array() {
cType := c.Get("type").String()
switch cType {
case "text":
textType := "input_text"
if role == "assistant" {
textType = "output_text"
}
part := fmt.Sprintf(`{"type":"%s","text":""}`, textType)
part, _ = sjson.Set(part, "text", c.Get("text").String())
msgParts = append(msgParts, part)
case "image":
source := c.Get("source")
if source.Exists() {
data := source.Get("data").String()
if data == "" {
data = source.Get("base64").String()
}
mediaType := source.Get("media_type").String()
if mediaType == "" {
mediaType = source.Get("mime_type").String()
}
if mediaType == "" {
mediaType = "application/octet-stream"
}
if data != "" {
part := `{"type":"input_image","image_url":""}`
part, _ = sjson.Set(part, "image_url", fmt.Sprintf("data:%s;base64,%s", mediaType, data))
msgParts = append(msgParts, part)
}
}
case "tool_use":
// Flush any accumulated message parts first
if len(msgParts) > 0 {
item := `{"type":"message","role":"","content":[]}`
item, _ = sjson.Set(item, "role", role)
for _, p := range msgParts {
item, _ = sjson.SetRaw(item, "content.-1", p)
}
inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
msgParts = nil
}
fc := `{"type":"function_call","call_id":"","name":"","arguments":""}`
fc, _ = sjson.Set(fc, "call_id", c.Get("id").String())
fc, _ = sjson.Set(fc, "name", c.Get("name").String())
if inputRaw := c.Get("input"); inputRaw.Exists() {
fc, _ = sjson.Set(fc, "arguments", inputRaw.Raw)
}
inputArr, _ = sjson.SetRaw(inputArr, "-1", fc)
case "tool_result":
// Flush any accumulated message parts first
if len(msgParts) > 0 {
item := `{"type":"message","role":"","content":[]}`
item, _ = sjson.Set(item, "role", role)
for _, p := range msgParts {
item, _ = sjson.SetRaw(item, "content.-1", p)
}
inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
msgParts = nil
}
fco := `{"type":"function_call_output","call_id":"","output":""}`
fco, _ = sjson.Set(fco, "call_id", c.Get("tool_use_id").String())
// Extract output text
resultContent := c.Get("content")
if resultContent.Type == gjson.String {
fco, _ = sjson.Set(fco, "output", resultContent.String())
} else if resultContent.IsArray() {
var resultParts []string
for _, rc := range resultContent.Array() {
if txt := rc.Get("text").String(); txt != "" {
resultParts = append(resultParts, txt)
}
}
fco, _ = sjson.Set(fco, "output", strings.Join(resultParts, "\n"))
} else if resultContent.Exists() {
fco, _ = sjson.Set(fco, "output", resultContent.String())
}
inputArr, _ = sjson.SetRaw(inputArr, "-1", fco)
case "thinking":
// Skip thinking blocks - not part of the API input
}
}
// Flush remaining message parts
if len(msgParts) > 0 {
item := `{"type":"message","role":"","content":[]}`
item, _ = sjson.Set(item, "role", role)
for _, p := range msgParts {
item, _ = sjson.SetRaw(item, "content.-1", p)
}
inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
}
}
}
body, _ = sjson.SetBytes(body, "input", strings.Join(parts, "\n"))
body, _ = sjson.SetRawBytes(body, "input", []byte(inputArr))
// Remove messages/system since we've converted them to input
body, _ = sjson.DeleteBytes(body, "messages")
body, _ = sjson.DeleteBytes(body, "system")
return body
}
@@ -747,6 +918,8 @@ type githubCopilotResponsesStreamState struct {
TextBlockIndex int
NextContentIndex int
HasToolUse bool
ReasoningActive bool
ReasoningIndex int
OutputIndexToTool map[int]*githubCopilotResponsesStreamToolState
ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
}
@@ -761,6 +934,33 @@ func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
if output := root.Get("output"); output.Exists() && output.IsArray() {
for _, item := range output.Array() {
switch item.Get("type").String() {
case "reasoning":
var thinkingText string
if summary := item.Get("summary"); summary.Exists() && summary.IsArray() {
var parts []string
for _, part := range summary.Array() {
if txt := part.Get("text").String(); txt != "" {
parts = append(parts, txt)
}
}
thinkingText = strings.Join(parts, "")
}
if thinkingText == "" {
if content := item.Get("content"); content.Exists() && content.IsArray() {
var parts []string
for _, part := range content.Array() {
if txt := part.Get("text").String(); txt != "" {
parts = append(parts, txt)
}
}
thinkingText = strings.Join(parts, "")
}
}
if thinkingText != "" {
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", thinkingText)
out, _ = sjson.SetRaw(out, "content.-1", block)
}
case "message":
if content := item.Get("content"); content.Exists() && content.IsArray() {
for _, part := range content.Array() {
@@ -798,10 +998,19 @@ func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
inputTokens := root.Get("usage.input_tokens").Int()
outputTokens := root.Get("usage.output_tokens").Int()
cachedTokens := root.Get("usage.input_tokens_details.cached_tokens").Int()
if cachedTokens > 0 && inputTokens >= cachedTokens {
inputTokens -= cachedTokens
}
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
if cachedTokens > 0 {
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
}
if hasToolUse {
out, _ = sjson.Set(out, "stop_reason", "tool_use")
} else if sr := root.Get("stop_reason").String(); sr == "max_tokens" || sr == "stop" {
out, _ = sjson.Set(out, "stop_reason", sr)
} else {
out, _ = sjson.Set(out, "stop_reason", "end_turn")
}
@@ -892,6 +1101,31 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n")
}
case "response.reasoning_summary_part.added":
ensureMessageStart()
state.ReasoningActive = true
state.ReasoningIndex = state.NextContentIndex
state.NextContentIndex++
thinkingStart := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
thinkingStart, _ = sjson.Set(thinkingStart, "index", state.ReasoningIndex)
results = append(results, "event: content_block_start\ndata: "+thinkingStart+"\n\n")
case "response.reasoning_summary_text.delta":
if state.ReasoningActive {
delta := gjson.GetBytes(payload, "delta").String()
if delta != "" {
thinkingDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
thinkingDelta, _ = sjson.Set(thinkingDelta, "index", state.ReasoningIndex)
thinkingDelta, _ = sjson.Set(thinkingDelta, "delta.thinking", delta)
results = append(results, "event: content_block_delta\ndata: "+thinkingDelta+"\n\n")
}
}
case "response.reasoning_summary_part.done":
if state.ReasoningActive {
thinkingStop := `{"type":"content_block_stop","index":0}`
thinkingStop, _ = sjson.Set(thinkingStop, "index", state.ReasoningIndex)
results = append(results, "event: content_block_stop\ndata: "+thinkingStop+"\n\n")
state.ReasoningActive = false
}
case "response.output_item.added":
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
break
@@ -938,6 +1172,23 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
case "response.function_call_arguments.delta":
// Copilot sends tool call arguments via this event type (not response.output_item.delta).
// Data format: {"delta":"...", "item_id":"...", "output_index":N, ...}
itemID := gjson.GetBytes(payload, "item_id").String()
outputIndex := int(gjson.GetBytes(payload, "output_index").Int())
tool := resolveTool(itemID, outputIndex)
if tool == nil {
break
}
partial := gjson.GetBytes(payload, "delta").String()
if partial == "" {
break
}
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
case "response.output_item.done":
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
break
@@ -956,11 +1207,22 @@ func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []st
stopReason := "end_turn"
if state.HasToolUse {
stopReason = "tool_use"
} else if sr := gjson.GetBytes(payload, "response.stop_reason").String(); sr == "max_tokens" || sr == "stop" {
stopReason = sr
}
inputTokens := gjson.GetBytes(payload, "response.usage.input_tokens").Int()
outputTokens := gjson.GetBytes(payload, "response.usage.output_tokens").Int()
cachedTokens := gjson.GetBytes(payload, "response.usage.input_tokens_details.cached_tokens").Int()
if cachedTokens > 0 && inputTokens >= cachedTokens {
inputTokens -= cachedTokens
}
messageDelta := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
messageDelta, _ = sjson.Set(messageDelta, "delta.stop_reason", stopReason)
messageDelta, _ = sjson.Set(messageDelta, "usage.input_tokens", gjson.GetBytes(payload, "response.usage.input_tokens").Int())
messageDelta, _ = sjson.Set(messageDelta, "usage.output_tokens", gjson.GetBytes(payload, "response.usage.output_tokens").Int())
messageDelta, _ = sjson.Set(messageDelta, "usage.input_tokens", inputTokens)
messageDelta, _ = sjson.Set(messageDelta, "usage.output_tokens", outputTokens)
if cachedTokens > 0 {
messageDelta, _ = sjson.Set(messageDelta, "usage.cache_read_input_tokens", cachedTokens)
}
results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n")
results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
state.MessageStopSent = true

View File

@@ -1,6 +1,7 @@
package executor
import (
"net/http"
"strings"
"testing"
@@ -103,11 +104,18 @@ func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAnd
body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`)
got := normalizeGitHubCopilotResponsesInput(body)
in := gjson.GetBytes(got, "input")
if in.Type != gjson.String {
t.Fatalf("input type = %v, want string", in.Type)
if !in.IsArray() {
t.Fatalf("input type = %v, want array", in.Type)
}
if !strings.Contains(in.String(), "sys text") || !strings.Contains(in.String(), "user text") || !strings.Contains(in.String(), "assistant text") {
t.Fatalf("input = %q, want merged text", in.String())
raw := in.Raw
if !strings.Contains(raw, "sys text") || !strings.Contains(raw, "user text") || !strings.Contains(raw, "assistant text") {
t.Fatalf("input = %s, want structured array with all texts", raw)
}
if gjson.GetBytes(got, "messages").Exists() {
t.Fatal("messages should be removed after conversion")
}
if gjson.GetBytes(got, "system").Exists() {
t.Fatal("system should be removed after conversion")
}
}
@@ -240,3 +248,86 @@ func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
}
}
// --- Tests for X-Initiator detection logic (Problem L) ---
// TestApplyHeaders_XInitiator_UserOnly checks that a conversation containing
// only system and user messages is labelled with X-Initiator "user".
func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
	t.Parallel()
	e := &GitHubCopilotExecutor{}
	req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
	if err != nil {
		// Fail clearly here instead of handing a nil request to applyHeaders.
		t.Fatalf("http.NewRequest: %v", err)
	}
	body := []byte(`{"messages":[{"role":"system","content":"sys"},{"role":"user","content":"hello"}]}`)
	e.applyHeaders(req, "token", body)
	if got := req.Header.Get("X-Initiator"); got != "user" {
		t.Fatalf("X-Initiator = %q, want user", got)
	}
}
// TestApplyHeaders_XInitiator_AgentWithAssistantAndUserToolResult checks that
// X-Initiator is "agent" when the history contains an assistant message, even
// though the last message has the user role.
func TestApplyHeaders_XInitiator_AgentWithAssistantAndUserToolResult(t *testing.T) {
	t.Parallel()
	e := &GitHubCopilotExecutor{}
	req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
	if err != nil {
		// Fail clearly here instead of handing a nil request to applyHeaders.
		t.Fatalf("http.NewRequest: %v", err)
	}
	// Claude Code typical flow: last message is user (tool result), but has assistant in history
	body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
	e.applyHeaders(req, "token", body)
	if got := req.Header.Get("X-Initiator"); got != "agent" {
		t.Fatalf("X-Initiator = %q, want agent (assistant exists in messages)", got)
	}
}
// TestApplyHeaders_XInitiator_AgentWithToolRole checks that a message with the
// "tool" role is enough for X-Initiator to be "agent".
func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
	t.Parallel()
	e := &GitHubCopilotExecutor{}
	req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
	if err != nil {
		// Fail clearly here instead of handing a nil request to applyHeaders.
		t.Fatalf("http.NewRequest: %v", err)
	}
	body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
	e.applyHeaders(req, "token", body)
	if got := req.Header.Get("X-Initiator"); got != "agent" {
		t.Fatalf("X-Initiator = %q, want agent (tool role exists)", got)
	}
}
// --- Tests for x-github-api-version header (Problem M) ---
// TestApplyHeaders_GitHubAPIVersion checks that applyHeaders sets the
// X-Github-Api-Version header to 2025-04-01 even when the body is nil.
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
	t.Parallel()
	e := &GitHubCopilotExecutor{}
	req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
	if err != nil {
		// Fail clearly here instead of handing a nil request to applyHeaders.
		t.Fatalf("http.NewRequest: %v", err)
	}
	e.applyHeaders(req, "token", nil)
	if got := req.Header.Get("X-Github-Api-Version"); got != "2025-04-01" {
		t.Fatalf("X-Github-Api-Version = %q, want 2025-04-01", got)
	}
}
// --- Tests for vision detection (Problem P) ---
// TestDetectVisionContent_WithImageURL feeds a user message whose content mixes
// a text part with an image_url part and expects vision content to be reported.
func TestDetectVisionContent_WithImageURL(t *testing.T) {
	t.Parallel()

	const payload = `{"messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`

	detected := detectVisionContent([]byte(payload))
	if detected {
		return
	}
	t.Fatal("expected vision content to be detected")
}
// TestDetectVisionContent_WithImageType feeds a user message carrying a content
// part of type "image" and expects vision content to be reported.
func TestDetectVisionContent_WithImageType(t *testing.T) {
	t.Parallel()

	const payload = `{"messages":[{"role":"user","content":[{"type":"image","source":{"data":"abc","media_type":"image/png"}}]}]}`

	detected := detectVisionContent([]byte(payload))
	if detected {
		return
	}
	t.Fatal("expected image type to be detected")
}
// TestDetectVisionContent_NoVision feeds a text-only user message and expects
// no vision content to be reported.
func TestDetectVisionContent_NoVision(t *testing.T) {
	t.Parallel()

	const payload = `{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`

	detected := detectVisionContent([]byte(payload))
	if !detected {
		return
	}
	t.Fatal("expected no vision content")
}
// TestDetectVisionContent_NoMessages feeds a body that has an "input" array but
// no "messages" field and expects no vision content to be reported.
func TestDetectVisionContent_NoMessages(t *testing.T) {
	t.Parallel()

	// Body shaped like a normalized Responses API request: there is no
	// "messages" key, so detection is expected to come back false.
	const payload = `{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`

	detected := detectVisionContent([]byte(payload))
	if !detected {
		return
	}
	t.Fatal("expected no vision content when messages field is absent")
}

View File

@@ -169,12 +169,12 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming chat completion request.
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -262,7 +262,6 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -294,7 +293,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
reporter.ensurePublished(ctx)
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {

View File

@@ -0,0 +1,460 @@
package executor
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// KiloExecutor handles requests to Kilo API.
// Its methods hand the stored configuration to the proxy-aware HTTP client and
// to the upstream request/response recording helpers.
type KiloExecutor struct {
// cfg is the configuration supplied to NewKiloExecutor.
cfg *config.Config
}
// NewKiloExecutor returns a KiloExecutor that uses cfg for every request it
// issues.
func NewKiloExecutor(cfg *config.Config) *KiloExecutor {
	executor := new(KiloExecutor)
	executor.cfg = cfg
	return executor
}
// Identifier returns the unique identifier for this executor, which is always
// the string "kilo".
func (e *KiloExecutor) Identifier() string {
	const providerKey = "kilo"
	return providerKey
}
// PrepareRequest decorates req with the Kilo bearer token and any custom
// headers carried in the auth attributes.
//
// A nil req is a no-op. An error is returned when no non-blank access token can
// be resolved from auth.
func (e *KiloExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
	if req == nil {
		return nil
	}

	token, _ := kiloCredentials(auth)
	if len(strings.TrimSpace(token)) == 0 {
		return fmt.Errorf("kilo: missing access token")
	}
	req.Header.Set("Authorization", "Bearer "+token)

	// A nil auth simply means there are no custom attributes to apply.
	var customAttrs map[string]string
	if auth != nil {
		customAttrs = auth.Attributes
	}
	util.ApplyCustomHeadersFromAttrs(req, customAttrs)

	return nil
}
// HttpRequest sends a caller-built HTTP request through the proxy-aware client
// after PrepareRequest has added the Kilo credentials.
//
// A nil req is rejected. When ctx is nil the request's own context is used.
func (e *KiloExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
	if req == nil {
		return nil, fmt.Errorf("kilo executor: request is nil")
	}
	if ctx == nil {
		ctx = req.Context()
	}

	// Work on a context-bound copy so the caller's request is left untouched.
	outbound := req.WithContext(ctx)
	if errPrepare := e.PrepareRequest(outbound, auth); errPrepare != nil {
		return nil, errPrepare
	}

	client := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	return client.Do(outbound)
}
// Execute performs a non-streaming request.
//
// The payload is translated from opts.SourceFormat to the "openai" format,
// posted to the Kilo chat completions endpoint, and the response body is
// translated back to the source format. The returned Response carries the
// translated payload together with a copy of the upstream response headers.
//
// An error is returned when no access token is available, when thinking
// settings cannot be applied, when the HTTP round trip or body read fails, or
// (as a statusErr) when the upstream answers with a non-2xx status.
func (e *KiloExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
	baseModel := thinking.ParseSuffix(req.Model).ModelName

	reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
	// trackFailure receives a pointer to the named return so it observes the
	// final error value of this call.
	defer reporter.trackFailure(ctx, &err)

	accessToken, orgID := kiloCredentials(auth)
	if accessToken == "" {
		return resp, fmt.Errorf("kilo: missing access token")
	}

	from := opts.SourceFormat
	to := sdktranslator.FromString("openai")
	endpoint := "/api/openrouter/chat/completions"

	// Use the original client request as the baseline when it is available,
	// falling back to the (possibly already modified) request payload.
	originalPayload := req.Payload
	if len(opts.OriginalRequest) > 0 {
		originalPayload = opts.OriginalRequest
	}
	originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
	translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
	requestedModel := payloadRequestedModel(opts, req.Model)
	translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)

	translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
	if err != nil {
		return resp, err
	}

	url := "https://api.kilo.ai" + endpoint
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
	if err != nil {
		return resp, err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+accessToken)
	if orgID != "" {
		httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
	}
	httpReq.Header.Set("User-Agent", "cli-proxy-kilo")

	// Custom headers from the auth attributes are applied after the defaults.
	var attrs map[string]string
	if auth != nil {
		attrs = auth.Attributes
	}
	util.ApplyCustomHeadersFromAttrs(httpReq, attrs)

	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL:       url,
		Method:    http.MethodPost,
		Headers:   httpReq.Header.Clone(),
		Body:      translated,
		Provider:  e.Identifier(),
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})

	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	httpResp, err := httpClient.Do(httpReq)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return resp, err
	}
	defer httpResp.Body.Close()
	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())

	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
		// Best-effort read: whatever part of the error body arrived is
		// surfaced in the statusErr message.
		b, _ := io.ReadAll(httpResp.Body)
		appendAPIResponseChunk(ctx, e.cfg, b)
		err = statusErr{code: httpResp.StatusCode, msg: string(b)}
		return resp, err
	}

	body, err := io.ReadAll(httpResp.Body)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return resp, err
	}
	appendAPIResponseChunk(ctx, e.cfg, body)
	reporter.publish(ctx, parseOpenAIUsage(body))
	reporter.ensurePublished(ctx)

	// TranslateNonStream is given req.Model (the name as requested, including
	// any suffix) rather than baseModel.
	var param any
	out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
	// Return the upstream headers alongside the payload so callers can
	// inspect them.
	resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
	return resp, nil
}
// ExecuteStream performs a streaming request.
//
// The payload is translated from opts.SourceFormat to the "openai" format and
// posted to the Kilo chat completions endpoint with SSE headers. On a 2xx reply
// a goroutine scans the response line by line, translates every "data:" line
// back to the source format and delivers the chunks on the returned
// StreamResult, which also carries a copy of the upstream response headers.
//
// The goroutine stops as soon as ctx is done, so an abandoned consumer cannot
// leave it blocked on a channel send. The channel is closed and the response
// body released when the goroutine exits.
//
// An error is returned when no access token is available, when thinking
// settings cannot be applied, when the HTTP round trip fails, or (as a
// statusErr) when the upstream answers with a non-2xx status.
func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
	baseModel := thinking.ParseSuffix(req.Model).ModelName

	reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
	// trackFailure receives a pointer to the named return so it observes the
	// final error value of this call.
	defer reporter.trackFailure(ctx, &err)

	accessToken, orgID := kiloCredentials(auth)
	if accessToken == "" {
		return nil, fmt.Errorf("kilo: missing access token")
	}

	from := opts.SourceFormat
	to := sdktranslator.FromString("openai")
	endpoint := "/api/openrouter/chat/completions"

	// Use the original client request as the baseline when it is available,
	// falling back to the (possibly already modified) request payload.
	originalPayload := req.Payload
	if len(opts.OriginalRequest) > 0 {
		originalPayload = opts.OriginalRequest
	}
	originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
	translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
	requestedModel := payloadRequestedModel(opts, req.Model)
	translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)

	translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
	if err != nil {
		return nil, err
	}

	url := "https://api.kilo.ai" + endpoint
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
	if err != nil {
		return nil, err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+accessToken)
	if orgID != "" {
		httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
	}
	httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
	httpReq.Header.Set("Accept", "text/event-stream")
	httpReq.Header.Set("Cache-Control", "no-cache")

	// Custom headers from the auth attributes are applied after the defaults.
	var attrs map[string]string
	if auth != nil {
		attrs = auth.Attributes
	}
	util.ApplyCustomHeadersFromAttrs(httpReq, attrs)

	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL:       url,
		Method:    http.MethodPost,
		Headers:   httpReq.Header.Clone(),
		Body:      translated,
		Provider:  e.Identifier(),
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})

	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	httpResp, err := httpClient.Do(httpReq)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return nil, err
	}
	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())

	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
		// Best-effort read: whatever part of the error body arrived is
		// surfaced in the statusErr message.
		b, _ := io.ReadAll(httpResp.Body)
		appendAPIResponseChunk(ctx, e.cfg, b)
		httpResp.Body.Close()
		err = statusErr{code: httpResp.StatusCode, msg: string(b)}
		return nil, err
	}

	out := make(chan cliproxyexecutor.StreamChunk)
	go func() {
		defer close(out)
		defer httpResp.Body.Close()
		// Deferred last so it runs first: usage is published on every exit
		// path, before the body and the channel are closed.
		defer reporter.ensurePublished(ctx)

		// send delivers one chunk unless the request context ends first. It
		// reports false when the goroutine should stop producing.
		send := func(chunk cliproxyexecutor.StreamChunk) bool {
			select {
			case out <- chunk:
				return true
			case <-ctx.Done():
				return false
			}
		}

		scanner := bufio.NewScanner(httpResp.Body)
		// Raise the maximum token size to 50 MiB so long SSE lines are not
		// rejected by the scanner's default limit.
		scanner.Buffer(nil, 52_428_800)
		var param any
		for scanner.Scan() {
			line := scanner.Bytes()
			appendAPIResponseChunk(ctx, e.cfg, line)
			if detail, ok := parseOpenAIStreamUsage(line); ok {
				reporter.publish(ctx, detail)
			}
			// Only non-empty "data:" lines are translated and forwarded.
			if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
				continue
			}
			// scanner.Bytes is only valid until the next Scan, so the
			// translator gets its own copy of the line.
			chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param)
			for i := range chunks {
				if !send(cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}) {
					return
				}
			}
		}
		if errScan := scanner.Err(); errScan != nil {
			recordAPIResponseError(ctx, e.cfg, errScan)
			reporter.publishFailure(ctx)
			send(cliproxyexecutor.StreamChunk{Err: errScan})
		}
	}()

	return &cliproxyexecutor.StreamResult{
		Headers: httpResp.Header.Clone(),
		Chunks:  out,
	}, nil
}
// Refresh returns auth unchanged. It contacts no upstream service; the only
// check is that auth is non-nil, otherwise an error is returned.
func (e *KiloExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	if auth != nil {
		return auth, nil
	}
	return nil, fmt.Errorf("missing auth")
}
// CountTokens is not implemented for Kilo: it always returns an empty Response
// and an error.
func (e *KiloExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	var empty cliproxyexecutor.Response
	return empty, fmt.Errorf("kilo: count tokens not supported")
}
// kiloCredentials resolves the access token and organization ID stored on auth.
//
// For each value, metadata is consulted before attributes, and the
// Kilo-specific key ("kilocodeToken", "kilocodeOrganizationId") is preferred
// over the generic one ("access_token", "organization_id"). Metadata entries
// count only when they are non-empty strings. A nil auth yields two empty
// strings.
func kiloCredentials(auth *cliproxyauth.Auth) (accessToken, orgID string) {
	if auth == nil {
		return "", ""
	}

	// fromMetadata returns the first non-empty string stored under keys.
	fromMetadata := func(keys ...string) string {
		for _, key := range keys {
			if value, isString := auth.Metadata[key].(string); isString && value != "" {
				return value
			}
		}
		return ""
	}
	// fromAttributes returns the first non-empty attribute stored under keys.
	fromAttributes := func(keys ...string) string {
		for _, key := range keys {
			if value := auth.Attributes[key]; value != "" {
				return value
			}
		}
		return ""
	}

	accessToken = fromMetadata("kilocodeToken", "access_token")
	if accessToken == "" {
		accessToken = fromAttributes("kilocodeToken", "access_token")
	}

	orgID = fromMetadata("kilocodeOrganizationId", "organization_id")
	if orgID == "" {
		orgID = fromAttributes("kilocodeOrganizationId", "organization_id")
	}

	return accessToken, orgID
}
// FetchKiloModels fetches models from Kilo API.
//
// It requests the Kilo model listing with the credentials found on auth and
// keeps only entries that have preferredIndex > 0 and are free: an ID ending
// in ":free", the ID "giga-potato", a true is_free flag, or a prompt price of
// "0" or "0.0". The result is the first static model from
// registry.GetKiloModels followed by the models that passed the filter.
//
// When there is no access token, or the request, read, status check or
// response-shape check fails, the static model list is returned unchanged.
func FetchKiloModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
	accessToken, orgID := kiloCredentials(auth)
	if accessToken == "" {
		log.Infof("kilo: no access token found, skipping dynamic model fetch (using static kilo/auto)")
		return registry.GetKiloModels()
	}

	log.Debugf("kilo: fetching dynamic models (orgID: %s)", orgID)

	httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.kilo.ai/api/openrouter/models", nil)
	if err != nil {
		log.Warnf("kilo: failed to create model fetch request: %v", err)
		return registry.GetKiloModels()
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)
	if orgID != "" {
		req.Header.Set("X-Kilocode-OrganizationID", orgID)
	}
	req.Header.Set("User-Agent", "cli-proxy-kilo")

	resp, err := httpClient.Do(req)
	if err != nil {
		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
			log.Warnf("kilo: fetch models canceled: %v", err)
		} else {
			log.Warnf("kilo: using static models (API fetch failed: %v)", err)
		}
		return registry.GetKiloModels()
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Warnf("kilo: failed to read models response: %v", err)
		return registry.GetKiloModels()
	}
	if resp.StatusCode != http.StatusOK {
		log.Warnf("kilo: fetch models failed: status %d, body: %s", resp.StatusCode, string(body))
		return registry.GetKiloModels()
	}

	result := gjson.GetBytes(body, "data")
	if !result.Exists() {
		// Try root if data field is missing
		result = gjson.ParseBytes(body)
		if !result.IsArray() {
			log.Debugf("kilo: response body: %s", string(body))
			log.Warn("kilo: invalid API response format (expected array or data field with array)")
			return registry.GetKiloModels()
		}
	}

	var dynamicModels []*registry.ModelInfo
	now := time.Now().Unix()
	totalCount := 0

	result.ForEach(func(key, value gjson.Result) bool {
		totalCount++
		id := value.Get("id").String()
		preferredIndex := value.Get("preferredIndex").Int()

		// Filter models where preferredIndex > 0 (Kilo-curated models)
		if preferredIndex <= 0 {
			return true
		}

		// Check if it's free. We look for :free suffix, is_free flag, or zero pricing.
		isFree := strings.HasSuffix(id, ":free") || id == "giga-potato" || value.Get("is_free").Bool()
		if !isFree {
			// Check pricing as fallback
			promptPricing := value.Get("pricing.prompt").String()
			if promptPricing == "0" || promptPricing == "0.0" {
				isFree = true
			}
		}
		if !isFree {
			log.Debugf("kilo: skipping curated paid model: %s", id)
			return true
		}

		log.Debugf("kilo: found curated model: %s (preferredIndex: %d)", id, preferredIndex)
		dynamicModels = append(dynamicModels, &registry.ModelInfo{
			ID:            id,
			DisplayName:   value.Get("name").String(),
			ContextLength: int(value.Get("context_length").Int()),
			OwnedBy:       "kilo",
			Type:          "kilo",
			Object:        "model",
			Created:       now,
		})
		return true
	})

	count := len(dynamicModels)
	log.Infof("kilo: fetched %d models from API, %d curated free (preferredIndex > 0)", totalCount, count)
	if count == 0 && totalCount > 0 {
		log.Warn("kilo: no curated free models found (check API response fields)")
	}

	// Always include kilo/auto (first static model). The result is built in a
	// fresh slice: appending to staticModels[:1] could overwrite later
	// elements of the slice returned by the registry when its capacity
	// exceeds one, and slicing would panic on an empty static list.
	staticModels := registry.GetKiloModels()
	allModels := make([]*registry.ModelInfo, 0, 1+len(dynamicModels))
	if len(staticModels) > 0 {
		allModels = append(allModels, staticModels[0])
	}
	allModels = append(allModels, dynamicModels...)
	return allModels
}

View File

@@ -161,12 +161,12 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming chat completion request to Kimi.
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
from := opts.SourceFormat
if from.String() == "claude" {
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
@@ -253,7 +253,6 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -285,7 +284,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens estimates token count for Kimi requests.

View File

@@ -1053,7 +1053,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
// ExecuteStream handles streaming requests to Kiro API.
// Supports automatic token refresh on 401/403 errors and quota fallback on 429.
func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
accessToken, profileArn := kiroCredentials(auth)
if accessToken == "" {
return nil, fmt.Errorf("kiro: access token not found in auth")
@@ -1110,7 +1110,11 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
// Route to MCP endpoint instead of normal Kiro API
if kiroclaude.HasWebSearchTool(req.Payload) {
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
streamWebSearch, errWebSearch := e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
if errWebSearch != nil {
return nil, errWebSearch
}
return &cliproxyexecutor.StreamResult{Chunks: streamWebSearch}, nil
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
@@ -1128,7 +1132,11 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
// Execute stream with retry on 401/403 and 429 (quota exhausted)
// Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint
return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
streamKiro, errStreamKiro := e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
if errStreamKiro != nil {
return nil, errStreamKiro
}
return &cliproxyexecutor.StreamResult{Chunks: streamKiro}, nil
}
// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors.
@@ -1709,6 +1717,7 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
// Amazon Q format (amazonq- prefix) - same API as Kiro
"amazonq-auto": "auto",
"amazonq-claude-opus-4-6": "claude-opus-4.6",
"amazonq-claude-sonnet-4-6": "claude-sonnet-4.6",
"amazonq-claude-opus-4-5": "claude-opus-4.5",
"amazonq-claude-sonnet-4-5": "claude-sonnet-4.5",
"amazonq-claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
@@ -1717,6 +1726,7 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
"amazonq-claude-haiku-4-5": "claude-haiku-4.5",
// Kiro format (kiro- prefix) - valid model names that should be preserved
"kiro-claude-opus-4-6": "claude-opus-4.6",
"kiro-claude-sonnet-4-6": "claude-sonnet-4.6",
"kiro-claude-opus-4-5": "claude-opus-4.5",
"kiro-claude-sonnet-4-5": "claude-sonnet-4.5",
"kiro-claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
@@ -1727,6 +1737,8 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
// Native format (no prefix) - used by Kiro IDE directly
"claude-opus-4-6": "claude-opus-4.6",
"claude-opus-4.6": "claude-opus-4.6",
"claude-sonnet-4-6": "claude-sonnet-4.6",
"claude-sonnet-4.6": "claude-sonnet-4.6",
"claude-opus-4-5": "claude-opus-4.5",
"claude-opus-4.5": "claude-opus-4.5",
"claude-haiku-4-5": "claude-haiku-4.5",
@@ -1739,11 +1751,13 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
"auto": "auto",
// Agentic variants (same backend model IDs, but with special system prompt)
"claude-opus-4.6-agentic": "claude-opus-4.6",
"claude-sonnet-4.6-agentic": "claude-sonnet-4.6",
"claude-opus-4.5-agentic": "claude-opus-4.5",
"claude-sonnet-4.5-agentic": "claude-sonnet-4.5",
"claude-sonnet-4-agentic": "claude-sonnet-4",
"claude-haiku-4.5-agentic": "claude-haiku-4.5",
"kiro-claude-opus-4-6-agentic": "claude-opus-4.6",
"kiro-claude-sonnet-4-6-agentic": "claude-sonnet-4.6",
"kiro-claude-opus-4-5-agentic": "claude-opus-4.5",
"kiro-claude-sonnet-4-5-agentic": "claude-sonnet-4.5",
"kiro-claude-sonnet-4-agentic": "claude-sonnet-4",
@@ -1769,6 +1783,10 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
log.Debugf("kiro: unknown Sonnet 3.7 model '%s', mapping to claude-3-7-sonnet-20250219", model)
return "claude-3-7-sonnet-20250219"
}
if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") {
log.Debugf("kiro: unknown Sonnet 4.6 model '%s', mapping to claude-sonnet-4.6", model)
return "claude-sonnet-4.6"
}
if strings.Contains(modelLower, "4-5") || strings.Contains(modelLower, "4.5") {
log.Debugf("kiro: unknown Sonnet 4.5 model '%s', mapping to claude-sonnet-4.5", model)
return "claude-sonnet-4.5"
@@ -1780,6 +1798,10 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
// Check for Opus variants
if strings.Contains(modelLower, "opus") {
if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") {
log.Debugf("kiro: unknown Opus 4.6 model '%s', mapping to claude-opus-4.6", model)
return "claude-opus-4.6"
}
log.Debugf("kiro: unknown Opus model '%s', mapping to claude-opus-4.5", model)
return "claude-opus-4.5"
}
@@ -2529,6 +2551,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
isThinkingBlockOpen := false // Track if thinking content block SSE event is open
thinkingBlockIndex := -1 // Index of the thinking content block
var accumulatedThinkingContent strings.Builder // Accumulate thinking content for token counting
hasOfficialReasoningEvent := false // Disable tag parsing after official reasoning events appear
// Buffer for handling partial tag matches at chunk boundaries
var pendingContent strings.Builder // Buffer content that might be part of a tag
@@ -2964,6 +2987,31 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
lastUsageUpdateTime = time.Now()
}
if hasOfficialReasoningEvent {
processText := strings.TrimSpace(strings.ReplaceAll(strings.ReplaceAll(contentDelta, kirocommon.ThinkingStartTag, ""), kirocommon.ThinkingEndTag, ""))
if processText != "" {
if !isTextBlockOpen {
contentBlockIndex++
isTextBlockOpen = true
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
}
}
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processText, contentBlockIndex)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
}
}
continue
}
// TAG-BASED THINKING PARSING: Parse <thinking> tags from content
// Combine pending content with new content for processing
pendingContent.WriteString(contentDelta)
@@ -3242,6 +3290,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
}
if thinkingText != "" {
hasOfficialReasoningEvent = true
// Close text block if open before starting thinking block
if isTextBlockOpen && contentBlockIndex >= 0 {
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)

View File

@@ -172,11 +172,11 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
// Translate response back to source format when needed
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -258,7 +258,6 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -298,7 +297,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
// Ensure we record the request if no usage chunk was ever seen
reporter.ensurePublished(ctx)
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {

View File

@@ -150,11 +150,11 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
@@ -236,7 +236,6 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
@@ -268,7 +267,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {

View File

@@ -0,0 +1,89 @@
package executor
import (
"crypto/sha256"
"encoding/hex"
"sync"
"time"
)
// userIDCacheEntry is a single cached user ID together with its expiry time.
type userIDCacheEntry struct {
// value is the cached user ID string.
value string
// expire is the instant after which the entry is no longer treated as valid.
expire time.Time
}
// Package-level state backing the user ID cache.
var (
// userIDCache maps a cache key (as produced by userIDCacheKey) to its entry.
userIDCache = make(map[string]userIDCacheEntry)
// userIDCacheMu guards every access to userIDCache.
userIDCacheMu sync.RWMutex
// userIDCacheCleanupOnce ensures the cleanup goroutine is started at most once.
userIDCacheCleanupOnce sync.Once
)
// Timing parameters of the user ID cache.
const (
// userIDTTL is how long an entry stays valid after it was last stored or renewed.
userIDTTL = time.Hour
// userIDCacheCleanupPeriod is the interval between background purges of expired entries.
userIDCacheCleanupPeriod = 15 * time.Minute
)
// startUserIDCacheCleanup launches a background goroutine that calls
// purgeExpiredUserIDs once every userIDCacheCleanupPeriod. The goroutine has
// no stop signal; it keeps running for as long as the ticker delivers ticks.
func startUserIDCacheCleanup() {
	sweepLoop := func() {
		tick := time.NewTicker(userIDCacheCleanupPeriod)
		defer tick.Stop()
		for range tick.C {
			purgeExpiredUserIDs()
		}
	}
	go sweepLoop()
}
// purgeExpiredUserIDs removes every cache entry whose expiry is not strictly
// after the current time. The whole sweep runs under the write lock.
func purgeExpiredUserIDs() {
	cutoff := time.Now()

	userIDCacheMu.Lock()
	defer userIDCacheMu.Unlock()

	for k, e := range userIDCache {
		if e.expire.After(cutoff) {
			continue // still live
		}
		delete(userIDCache, k)
	}
}
// userIDCacheKey derives the map key for an API key: the lowercase hex
// encoding of the SHA-256 digest of apiKey, so the raw key is never used as a
// map key.
func userIDCacheKey(apiKey string) string {
	hasher := sha256.New()
	hasher.Write([]byte(apiKey)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
// cachedUserID returns the user ID associated with apiKey, creating one when
// no usable entry exists.
//
// An empty apiKey bypasses the cache and yields a freshly generated ID. For a
// non-empty key the first call also starts the background cleanup goroutine.
// A usable entry (non-empty, unexpired, passing isValidUserID) is returned and
// its expiry pushed out by userIDTTL; otherwise a newly generated ID is stored
// unless another goroutine has stored a usable one in the meantime.
func cachedUserID(apiKey string) string {
	if apiKey == "" {
		return generateFakeUserID()
	}
	userIDCacheCleanupOnce.Do(startUserIDCacheCleanup)

	cacheKey := userIDCacheKey(apiKey)
	now := time.Now()

	// usable reports whether an entry may still be served. A zero-value
	// entry (missing key) is never usable because its value is empty.
	usable := func(e userIDCacheEntry) bool {
		return e.value != "" && e.expire.After(now) && isValidUserID(e.value)
	}

	// Cheap look under the read lock first.
	userIDCacheMu.RLock()
	peeked, found := userIDCache[cacheKey]
	hit := found && usable(peeked)
	userIDCacheMu.RUnlock()

	if hit {
		// Re-check under the write lock: the entry may have been purged or
		// replaced between the two lock acquisitions.
		userIDCacheMu.Lock()
		current := userIDCache[cacheKey]
		if usable(current) {
			current.expire = now.Add(userIDTTL)
			userIDCache[cacheKey] = current
			userIDCacheMu.Unlock()
			return current.value
		}
		userIDCacheMu.Unlock()
	}

	// Generate outside the lock, then install only if nobody beat us to it.
	candidate := generateFakeUserID()

	userIDCacheMu.Lock()
	defer userIDCacheMu.Unlock()

	stored, present := userIDCache[cacheKey]
	if !present || !usable(stored) {
		stored.value = candidate
	}
	stored.expire = now.Add(userIDTTL)
	userIDCache[cacheKey] = stored
	return stored.value
}

View File

@@ -0,0 +1,86 @@
package executor
import (
"testing"
"time"
)
// resetUserIDCache replaces the package-level cache with an empty map so each
// test starts from a clean state.
func resetUserIDCache() {
	userIDCacheMu.Lock()
	defer userIDCacheMu.Unlock()
	userIDCache = map[string]userIDCacheEntry{}
}
// TestCachedUserID_ReusesWithinTTL checks that two back-to-back lookups for
// the same API key return the same non-empty user ID.
func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
	resetUserIDCache()

	const apiKey = "api-key-1"
	first, second := cachedUserID(apiKey), cachedUserID(apiKey)

	if first == "" {
		t.Fatal("expected generated user_id to be non-empty")
	}
	if second != first {
		t.Fatalf("expected cached user_id to be reused, got %q and %q", first, second)
	}
}
// TestCachedUserID_ExpiresAfterTTL back-dates a cached entry's expiry and
// checks that the next lookup hands out a different, non-empty user ID.
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
	resetUserIDCache()

	const apiKey = "api-key-expired"
	stale := cachedUserID(apiKey)
	mapKey := userIDCacheKey(apiKey)

	// Force the entry into the past so it counts as expired.
	userIDCacheMu.Lock()
	userIDCache[mapKey] = userIDCacheEntry{value: stale, expire: time.Now().Add(-time.Minute)}
	userIDCacheMu.Unlock()

	replacement := cachedUserID(apiKey)
	if replacement == stale {
		t.Fatalf("expected expired user_id to be replaced, got %q", replacement)
	}
	if replacement == "" {
		t.Fatal("expected regenerated user_id to be non-empty")
	}
}
// TestCachedUserID_IsScopedByAPIKey checks that distinct API keys do not share
// a cached user ID.
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
	resetUserIDCache()

	idForKey1 := cachedUserID("api-key-1")
	idForKey2 := cachedUserID("api-key-2")

	if idForKey2 == idForKey1 {
		t.Fatalf("expected different API keys to have different user_ids, got %q", idForKey1)
	}
}
// TestCachedUserID_RenewsTTLOnHit shortens an entry's remaining lifetime to
// two seconds, performs a lookup, and checks that the hit both returned the
// same ID and pushed the expiry at least 30 minutes past the reference time.
func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
	resetUserIDCache()

	const apiKey = "api-key-renew"
	original := cachedUserID(apiKey)
	mapKey := userIDCacheKey(apiKey)

	reference := time.Now()
	userIDCacheMu.Lock()
	userIDCache[mapKey] = userIDCacheEntry{value: original, expire: reference.Add(2 * time.Second)}
	userIDCacheMu.Unlock()

	refreshed := cachedUserID(apiKey)
	if refreshed != original {
		t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
	}

	userIDCacheMu.RLock()
	remaining := userIDCache[mapKey].expire.Sub(reference)
	userIDCacheMu.RUnlock()

	if remaining < 30*time.Minute {
		t.Fatalf("expected TTL to renew, got %v remaining", remaining)
	}
}

View File

@@ -10,10 +10,53 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// validReasoningEffortLevels is the set of reasoning_effort values passed
// through unchanged. Provider-specific extensions (xhigh, minimal, auto) are
// deliberately absent and get clamped by clampReasoningEffort.
var validReasoningEffortLevels = map[string]struct{}{
	"none":   {},
	"low":    {},
	"medium": {},
	"high":   {},
}

// clampReasoningEffort converts an arbitrary thinking level string into one
// of the values in validReasoningEffortLevels.
//
// Mapping:
//   - none / low / medium / high → returned unchanged, nothing is logged
//   - xhigh                      → "high"
//   - minimal                    → "low"
//   - auto and anything else     → "medium"
//
// Every clamped (i.e. changed) value is reported at debug level together with
// the original input.
func clampReasoningEffort(level string) string {
	if _, standard := validReasoningEffortLevels[level]; standard {
		return level
	}

	// Medium is the fallback for auto and for any unrecognised input.
	clamped := string(thinking.LevelMedium)
	switch level {
	case string(thinking.LevelXHigh):
		clamped = string(thinking.LevelHigh)
	case string(thinking.LevelMinimal):
		clamped = string(thinking.LevelLow)
	}

	fields := log.Fields{
		"original": level,
		"clamped":  clamped,
	}
	log.WithFields(fields).Debug("openai: reasoning_effort clamped to nearest valid standard value")
	return clamped
}
// Applier implements thinking.ProviderApplier for OpenAI models.
//
// OpenAI-specific behavior:
@@ -58,7 +101,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
}
if config.Mode == thinking.ModeLevel {
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level)))
return result, nil
}
@@ -79,7 +122,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
return result, nil
}
@@ -114,7 +157,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
return result, nil
}

View File

@@ -231,8 +231,12 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else if functionResponseResult.IsObject() {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
} else {
} else if functionResponseResult.Raw != "" {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
} else {
// Content field is missing entirely — .Raw is empty which
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
}
partJSON := `{}`

View File

@@ -661,6 +661,85 @@ func TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint(t *testing.T) {
}
}
// TestConvertClaudeRequestToAntigravity_ToolResultNoContent converts a request
// whose tool_result block has no "content" field at all and checks that the
// output is valid JSON and still carries functionResponse.response.result.
func TestConvertClaudeRequestToAntigravity_ToolResultNoContent(t *testing.T) {
// Bug repro: tool_result with no content field produces invalid JSON
inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "MyTool-123-456",
"name": "MyTool",
"input": {"key": "value"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "MyTool-123-456"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
outputStr := string(output)
// Errorf (not Fatalf): the result-path check below still runs on invalid output.
if !gjson.Valid(outputStr) {
t.Errorf("Result is not valid JSON:\n%s", outputStr)
}
// Verify the functionResponse has a valid result value
fr := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse.response.result")
if !fr.Exists() {
t.Error("functionResponse.response.result should exist")
}
}
// TestConvertClaudeRequestToAntigravity_ToolResultNullContent converts a
// request whose tool_result block has an explicit `"content": null` and checks
// only that the converted output is valid JSON.
func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
// Bug repro: tool_result with null content produces invalid JSON
inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "MyTool-123-456",
"name": "MyTool",
"input": {"key": "value"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "MyTool-123-456",
"content": null
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Errorf("Result is not valid JSON:\n%s", outputStr)
}
}
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
// When tools + thinking but no system instruction, should create one with hint
inputJSON := []byte(`{

View File

@@ -22,8 +22,9 @@ var (
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
type ConvertCodexResponseToClaudeParams struct {
HasToolCall bool
BlockIndex int
HasToolCall bool
BlockIndex int
HasReceivedArgumentsDelta bool
}
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
@@ -137,6 +138,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
@@ -171,12 +173,29 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output += fmt.Sprintf("data: %s\n\n", template)
}
} else if typeStr == "response.function_call_arguments.delta" {
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.function_call_arguments.done" {
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
// in a single "done" event without preceding "delta" events.
// Emit the full arguments as a single input_json_delta so the
// downstream Claude client receives the complete tool input.
// When delta events were already received, skip to avoid duplicating arguments.
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
if args := rootResult.Get("arguments").String(); args != "" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", args)
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
}
}
}
return []string{output}

View File

@@ -20,10 +20,12 @@ var (
// ConvertCliToOpenAIParams holds parameters for response conversion.
type ConvertCliToOpenAIParams struct {
ResponseID string
CreatedAt int64
Model string
FunctionCallIndex int
ResponseID string
CreatedAt int64
Model string
FunctionCallIndex int
HasReceivedArgumentsDelta bool
HasToolCallAnnounced bool
}
// ConvertCodexResponseToOpenAI translates a single chunk of a streaming response from the
@@ -43,10 +45,12 @@ type ConvertCliToOpenAIParams struct {
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
if *param == nil {
*param = &ConvertCliToOpenAIParams{
Model: modelName,
CreatedAt: 0,
ResponseID: "",
FunctionCallIndex: -1,
Model: modelName,
CreatedAt: 0,
ResponseID: "",
FunctionCallIndex: -1,
HasReceivedArgumentsDelta: false,
HasToolCallAnnounced: false,
}
}
@@ -118,35 +122,93 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
}
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
} else if dataType == "response.output_item.done" {
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
} else if dataType == "response.output_item.added" {
itemResult := rootResult.Get("item")
if itemResult.Exists() {
if itemResult.Get("type").String() != "function_call" {
return []string{}
}
// set the index
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened
name := itemResult.Get("name").String()
// Build reverse map on demand from original request tools
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
return []string{}
}
// Increment index for this new function call item.
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = false
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = true
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened.
name := itemResult.Get("name").String()
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", "")
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.function_call_arguments.delta" {
(*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta = true
deltaValue := rootResult.Get("delta").String()
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", deltaValue)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.function_call_arguments.done" {
if (*param).(*ConvertCliToOpenAIParams).HasReceivedArgumentsDelta {
// Arguments were already streamed via delta events; nothing to emit.
return []string{}
}
// Fallback: no delta events were received, emit the full arguments as a single chunk.
fullArgs := rootResult.Get("arguments").String()
functionCallItemTemplate := `{"index":0,"function":{"arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fullArgs)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else if dataType == "response.output_item.done" {
itemResult := rootResult.Get("item")
if !itemResult.Exists() || itemResult.Get("type").String() != "function_call" {
return []string{}
}
if (*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced {
// Tool call was already announced via output_item.added; skip emission.
(*param).(*ConvertCliToOpenAIParams).HasToolCallAnnounced = false
return []string{}
}
// Fallback path: model skipped output_item.added, so emit complete tool call now.
(*param).(*ConvertCliToOpenAIParams).FunctionCallIndex++
functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", (*param).(*ConvertCliToOpenAIParams).FunctionCallIndex)
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
// Restore original tool name if it was shortened.
name := itemResult.Get("name").String()
rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
name = orig
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
} else {
return []string{}
}

View File

@@ -243,13 +243,11 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA
// Process messages and build history
history, currentUserMsg, currentToolResults := processMessages(messages, modelID, origin)
// Build content with system prompt (only on first turn to avoid re-injection)
// Build content with system prompt.
// Keep thinking tags on subsequent turns so multi-turn Claude sessions
// continue to emit reasoning events.
if currentUserMsg != nil {
effectiveSystemPrompt := systemPrompt
if len(history) > 0 {
effectiveSystemPrompt = "" // Don't re-inject on subsequent turns
}
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, effectiveSystemPrompt, currentToolResults)
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults)
// Deduplicate currentToolResults
currentToolResults = deduplicateToolResults(currentToolResults)
@@ -475,6 +473,15 @@ func IsThinkingEnabledWithHeaders(body []byte, headers http.Header) bool {
}
}
// Check model name directly for thinking hints.
// This enables thinking variants even when clients don't send explicit thinking fields.
model := strings.TrimSpace(gjson.GetBytes(body, "model").String())
modelLower := strings.ToLower(model)
if strings.Contains(modelLower, "thinking") || strings.Contains(modelLower, "-reason") {
log.Debugf("kiro: thinking mode enabled via model name hint: %s", model)
return true
}
log.Debugf("kiro: IsThinkingEnabled returning false (no thinking mode detected)")
return false
}

View File

@@ -234,16 +234,16 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s
// Kiro API supports official thinking/reasoning mode via <thinking_mode> tag.
// When set to "enabled", Kiro returns reasoning content as official reasoningContentEvent
// rather than inline <thinking> tags in assistantResponseEvent.
// We use a high max_thinking_length to allow extensive reasoning.
// Use a conservative thinking budget to reduce latency/cost spikes in long sessions.
if thinkingEnabled {
thinkingHint := `<thinking_mode>enabled</thinking_mode>
<max_thinking_length>200000</max_thinking_length>`
<max_thinking_length>16000</max_thinking_length>`
if systemPrompt != "" {
systemPrompt = thinkingHint + "\n\n" + systemPrompt
} else {
systemPrompt = thinkingHint
}
log.Debugf("kiro-openai: injected thinking prompt (official mode)")
log.Infof("kiro-openai: injected thinking prompt (official mode), has_tools: %v", len(kiroTools) > 0)
}
// Process messages and build history
@@ -831,7 +831,6 @@ func hasThinkingTagInBody(body []byte) bool {
return strings.Contains(bodyStr, "<thinking_mode>") || strings.Contains(bodyStr, "<max_thinking_length>")
}
// extractToolChoiceHint extracts tool_choice from OpenAI request and returns a system prompt hint.
// OpenAI tool_choice values:
// - "none": Don't use any tools

542
internal/tui/app.go Normal file
View File

@@ -0,0 +1,542 @@
package tui
import (
"fmt"
"io"
"os"
"strings"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// Tab identifiers. The values are consecutive starting at 0 and are used both
// as App.activeTab values and as indexes into App.initialized. tabLogs is the
// last one, and it is the entry refreshTabs drops when logs are disabled.
const (
tabDashboard = iota
tabConfig
tabAuthFiles
tabAPIKeys
tabOAuth
tabUsage
tabLogs
)
// App is the root bubbletea model that contains all tab sub-models.
type App struct {
// activeTab is the currently selected tab (one of the tab* constants).
activeTab int
// tabs holds the visible tab names; rebuilt by refreshTabs.
tabs []string
// standalone is true when NewApp was given a non-nil LogHook.
standalone bool
// logsEnabled controls whether the logs tab is listed and kept polling.
logsEnabled bool
// authenticated is false while the password gate is shown.
authenticated bool
// authInput is the masked password field of the auth gate.
authInput textinput.Model
// authError is the last auth-gate error message; empty when there is none.
authError string
// authConnecting is true between pressing enter on the gate and the
// arrival of the matching authConnectMsg.
authConnecting bool
// Per-tab sub-models.
dashboard dashboardModel
config configTabModel
auth authTabModel
keys keysTabModel
oauth oauthTabModel
usage usageTabModel
logs logsTabModel
// client is the Client created in NewApp and shared with every tab model.
client *Client
// width and height mirror the most recent tea.WindowSizeMsg.
width int
height int
// ready becomes true once the first tea.WindowSizeMsg has been handled.
ready bool
// Track which tabs have been initialized (fetched data)
initialized [7]bool
}
// authConnectMsg reports the outcome of an auth-gate connection attempt
// (presumably emitted by the command from connectWithPassword — confirm).
type authConnectMsg struct {
// cfg is the configuration map inspected by isLogsEnabledFromConfig on success.
cfg map[string]any
// err is non-nil when the attempt failed; the gate then stays up.
err error
}
// NewApp builds the root TUI model.
//
// A non-nil hook selects standalone mode: the app starts authenticated with
// the dashboard and logs tabs already marked as initialized. With a nil hook
// the password gate is shown first and no tab is marked initialized. The
// password field is pre-filled with the whitespace-trimmed secretKey, while
// the Client is created from secretKey as given.
func NewApp(port int, secretKey string, hook *LogHook) App {
	standalone := hook != nil

	// Masked password input for the auth gate.
	pwInput := textinput.New()
	pwInput.CharLimit = 512
	pwInput.EchoMode = textinput.EchoPassword
	pwInput.EchoCharacter = '*'
	pwInput.SetValue(strings.TrimSpace(secretKey))
	pwInput.Focus()

	client := NewClient(port, secretKey)

	// Only standalone mode starts with tabs already fetched.
	var fetched [7]bool
	if standalone {
		fetched[tabDashboard] = true
		fetched[tabLogs] = true
	}

	app := App{
		activeTab:     tabDashboard,
		standalone:    standalone,
		logsEnabled:   true,
		authenticated: standalone,
		authInput:     pwInput,
		dashboard:     newDashboardModel(client),
		config:        newConfigTabModel(client),
		auth:          newAuthTabModel(client),
		keys:          newKeysTabModel(client),
		oauth:         newOAuthTabModel(client),
		usage:         newUsageTabModel(client),
		logs:          newLogsTabModel(client, hook),
		client:        client,
		initialized:   fetched,
	}
	app.refreshTabs()
	app.setAuthInputPrompt()
	return app
}
// Init returns the startup command: the cursor blink for the auth gate while
// unauthenticated, otherwise a batch of the dashboard's Init command plus the
// logs tab's Init command when logs are enabled.
func (a App) Init() tea.Cmd {
	if !a.authenticated {
		return textinput.Blink
	}
	if !a.logsEnabled {
		return tea.Batch(a.dashboard.Init())
	}
	return tea.Batch(a.dashboard.Init(), a.logs.Init())
}
// Update is the root message handler.
//
// It handles, in order:
//   - tea.WindowSizeMsg: records the size and resizes every tab model;
//   - authConnectMsg: finishes (or fails) the auth gate and starts the
//     dashboard and, when enabled, the logs tab;
//   - configUpdateMsg: tracks changes to "logging-to-file" (non-standalone
//     only) and forwards the message to the config tab;
//   - tea.KeyMsg: auth-gate keys while unauthenticated, otherwise the global
//     shortcuts (quit, locale toggle, tab / shift+tab).
//
// Anything not consumed above is routed to the auth input (unauthenticated)
// or to the active tab. Log-related messages are additionally delivered to
// the logs model when another tab is active so polling keeps running.
//
// Because App is a value receiver, every pointer-receiver helper that mutates
// the model is called BEFORE the return statement that copies `a`. The Go
// spec does not order a plain variable read against a call in the same
// expression list, so `return a, a.mutate()` could return a stale copy.
func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	switch msg := msg.(type) {
	case tea.WindowSizeMsg:
		a.width = msg.Width
		a.height = msg.Height
		a.ready = true
		if a.width > 0 {
			a.authInput.Width = a.width - 6
		}
		contentH := a.height - 4 // tab bar + status bar
		if contentH < 1 {
			contentH = 1
		}
		contentW := a.width
		a.dashboard.SetSize(contentW, contentH)
		a.config.SetSize(contentW, contentH)
		a.auth.SetSize(contentW, contentH)
		a.keys.SetSize(contentW, contentH)
		a.oauth.SetSize(contentW, contentH)
		a.usage.SetSize(contentW, contentH)
		a.logs.SetSize(contentW, contentH)
		return a, nil

	case authConnectMsg:
		a.authConnecting = false
		if msg.err != nil {
			a.authError = fmt.Sprintf(T("auth_gate_connect_fail"), msg.err.Error())
			return a, nil
		}
		a.authError = ""
		a.authenticated = true
		a.logsEnabled = a.standalone || isLogsEnabledFromConfig(msg.cfg)
		a.refreshTabs()
		// Start from a clean slate: only the tabs started below count as fetched.
		a.initialized = [7]bool{}
		a.initialized[tabDashboard] = true
		cmds := []tea.Cmd{a.dashboard.Init()}
		if a.logsEnabled {
			a.initialized[tabLogs] = true
			cmds = append(cmds, a.logs.Init())
		}
		return a, tea.Batch(cmds...)

	case configUpdateMsg:
		var cmdLogs tea.Cmd
		if !a.standalone && msg.err == nil && msg.path == "logging-to-file" {
			logsEnabledConfig, okConfig := msg.value.(bool)
			if okConfig {
				logsEnabledBefore := a.logsEnabled
				a.logsEnabled = logsEnabledConfig
				if logsEnabledBefore != a.logsEnabled {
					a.refreshTabs()
				}
				if !a.logsEnabled {
					// Force a fresh Init the next time logs get enabled.
					a.initialized[tabLogs] = false
				}
				if !logsEnabledBefore && a.logsEnabled {
					a.initialized[tabLogs] = true
					cmdLogs = a.logs.Init()
				}
			}
		}
		var cmdConfig tea.Cmd
		a.config, cmdConfig = a.config.Update(msg)
		if cmdConfig != nil && cmdLogs != nil {
			return a, tea.Batch(cmdConfig, cmdLogs)
		}
		if cmdConfig != nil {
			return a, cmdConfig
		}
		return a, cmdLogs

	case tea.KeyMsg:
		if !a.authenticated {
			// NOTE(review): "q" and "L" are intercepted here before they reach
			// the password input, so it looks like those characters cannot be
			// typed one at a time into the field — confirm this is intended.
			switch msg.String() {
			case "ctrl+c", "q":
				return a, tea.Quit
			case "L":
				ToggleLocale()
				a.refreshTabs()
				a.setAuthInputPrompt()
				return a, nil
			case "enter":
				if a.authConnecting {
					// A connection attempt is already in flight.
					return a, nil
				}
				password := strings.TrimSpace(a.authInput.Value())
				if password == "" {
					a.authError = T("auth_gate_password_required")
					return a, nil
				}
				a.authError = ""
				a.authConnecting = true
				// Evaluate the call before `a` is copied into the result.
				connectCmd := a.connectWithPassword(password)
				return a, connectCmd
			default:
				var cmd tea.Cmd
				a.authInput, cmd = a.authInput.Update(msg)
				return a, cmd
			}
		}

		switch msg.String() {
		case "ctrl+c":
			return a, tea.Quit
		case "q":
			// Only quit if not in logs tab (where 'q' might be useful)
			if !a.logsEnabled || a.activeTab != tabLogs {
				return a, tea.Quit
			}
		case "L":
			ToggleLocale()
			a.refreshTabs()
			return a.broadcastToAllTabs(localeChangedMsg{})
		case "tab":
			if len(a.tabs) == 0 {
				return a, nil
			}
			prevTab := a.activeTab
			a.activeTab = (a.activeTab + 1) % len(a.tabs)
			// initTabIfNeeded mutates a.initialized through a pointer
			// receiver; run it first so the returned copy includes the flag.
			initCmd := a.initTabIfNeeded(prevTab)
			return a, initCmd
		case "shift+tab":
			if len(a.tabs) == 0 {
				return a, nil
			}
			prevTab := a.activeTab
			a.activeTab = (a.activeTab - 1 + len(a.tabs)) % len(a.tabs)
			initCmd := a.initTabIfNeeded(prevTab)
			return a, initCmd
		}
	}

	if !a.authenticated {
		var cmd tea.Cmd
		a.authInput, cmd = a.authInput.Update(msg)
		return a, cmd
	}

	// Route msg to active tab
	var cmd tea.Cmd
	switch a.activeTab {
	case tabDashboard:
		a.dashboard, cmd = a.dashboard.Update(msg)
	case tabConfig:
		a.config, cmd = a.config.Update(msg)
	case tabAuthFiles:
		a.auth, cmd = a.auth.Update(msg)
	case tabAPIKeys:
		a.keys, cmd = a.keys.Update(msg)
	case tabOAuth:
		a.oauth, cmd = a.oauth.Update(msg)
	case tabUsage:
		a.usage, cmd = a.usage.Update(msg)
	case tabLogs:
		a.logs, cmd = a.logs.Update(msg)
	}

	// Keep logs polling alive even when logs tab is not active.
	if a.logsEnabled && a.activeTab != tabLogs {
		switch msg.(type) {
		case logsPollMsg, logsTickMsg, logLineMsg:
			var logCmd tea.Cmd
			a.logs, logCmd = a.logs.Update(msg)
			if logCmd != nil {
				// The logs follow-up command replaces the active tab's command.
				cmd = logCmd
			}
		}
	}
	return a, cmd
}
// localeChangedMsg is broadcast to all tabs when the user toggles locale.
// It carries no payload; the auth and config tabs react to it by re-rendering
// their viewport content so translated strings are refreshed.
type localeChangedMsg struct{}
// refreshTabs rebuilds a.tabs from the current tab names, omitting the logs
// tab while logging is disabled, and then clamps a.activeTab so that it always
// refers to an existing entry.
func (a *App) refreshTabs() {
	names := TabNames()
	if a.logsEnabled {
		a.tabs = names
	} else {
		// Capacity is len(names) rather than len(names)-1: make panics on a
		// negative capacity, which an empty name list would otherwise cause.
		filtered := make([]string, 0, len(names))
		for idx, name := range names {
			if idx == tabLogs {
				continue
			}
			filtered = append(filtered, name)
		}
		a.tabs = filtered
	}
	if len(a.tabs) == 0 {
		a.activeTab = tabDashboard
		return
	}
	if a.activeTab >= len(a.tabs) {
		a.activeTab = len(a.tabs) - 1
	}
}
// initTabIfNeeded runs the Init command of the currently active tab the first
// time that tab is shown and records it as initialized. The argument (the
// previously active tab) is unused. It returns nil when the tab was already
// initialized or when the logs tab is selected while logging is disabled.
func (a *App) initTabIfNeeded(_ int) tea.Cmd {
	tab := a.activeTab
	if done := a.initialized[tab]; done {
		return nil
	}
	a.initialized[tab] = true

	var cmd tea.Cmd
	switch tab {
	case tabDashboard:
		cmd = a.dashboard.Init()
	case tabConfig:
		cmd = a.config.Init()
	case tabAuthFiles:
		cmd = a.auth.Init()
	case tabAPIKeys:
		cmd = a.keys.Init()
	case tabOAuth:
		cmd = a.oauth.Init()
	case tabUsage:
		cmd = a.usage.Init()
	case tabLogs:
		if a.logsEnabled {
			cmd = a.logs.Init()
		}
	}
	return cmd
}
// View renders the whole screen: the auth gate while unauthenticated, a
// placeholder until the app is ready, and otherwise the tab bar, the active
// tab's content and the status bar, each separated by a newline.
func (a App) View() string {
	switch {
	case !a.authenticated:
		return a.renderAuthView()
	case !a.ready:
		return T("initializing_tui")
	}

	var out strings.Builder

	// Tab bar
	out.WriteString(a.renderTabBar())
	out.WriteString("\n")

	// Content of the active tab (empty for an unknown tab or disabled logs).
	var body string
	switch a.activeTab {
	case tabDashboard:
		body = a.dashboard.View()
	case tabConfig:
		body = a.config.View()
	case tabAuthFiles:
		body = a.auth.View()
	case tabAPIKeys:
		body = a.keys.View()
	case tabOAuth:
		body = a.oauth.View()
	case tabUsage:
		body = a.usage.View()
	case tabLogs:
		if a.logsEnabled {
			body = a.logs.View()
		}
	}
	out.WriteString(body)

	// Status bar
	out.WriteString("\n")
	out.WriteString(a.renderStatusBar())
	return out.String()
}
// renderAuthView draws the password gate: title, help text, an optional
// "connecting" notice, an optional error line, the password input and the
// closing hint.
func (a App) renderAuthView() string {
	var out strings.Builder
	out.WriteString(titleStyle.Render(T("auth_gate_title")) + "\n")
	out.WriteString(helpStyle.Render(T("auth_gate_help")) + "\n\n")
	if a.authConnecting {
		out.WriteString(warningStyle.Render(T("auth_gate_connecting")) + "\n\n")
	}
	// Whitespace-only errors are treated as "no error".
	if msg := a.authError; strings.TrimSpace(msg) != "" {
		out.WriteString(errorStyle.Render(msg) + "\n\n")
	}
	out.WriteString(a.authInput.View() + "\n")
	out.WriteString(helpStyle.Render(T("auth_gate_enter")))
	return out.String()
}
// renderTabBar renders every tab name side by side, using the active style
// for the selected tab, and stretches the bar to the terminal width.
func (a App) renderTabBar() string {
	rendered := make([]string, len(a.tabs))
	for idx, label := range a.tabs {
		if idx == a.activeTab {
			rendered[idx] = tabActiveStyle.Render(label)
			continue
		}
		rendered[idx] = tabInactiveStyle.Render(label)
	}
	joined := lipgloss.JoinHorizontal(lipgloss.Top, rendered...)
	return tabBarStyle.Width(a.width).Render(joined)
}
// renderStatusBar lays out the left and right status texts on one line,
// padding the gap with spaces. When space is short the left text wins: it is
// truncated first and the right text is dropped or truncated to what remains.
func (a App) renderStatusBar() string {
	leftText := strings.TrimRight(T("status_left"), " ")
	rightText := strings.TrimRight(T("status_right"), " ")

	total := a.width
	if total < 1 {
		total = 1
	}
	// statusBarStyle has left/right padding(1), so content area is total-2.
	inner := total - 2
	if inner < 0 {
		inner = 0
	}

	if lipgloss.Width(leftText) > inner {
		leftText, rightText = fitStringWidth(leftText, inner), ""
	}
	leftWidth := lipgloss.Width(leftText)

	room := inner - leftWidth
	if room < 0 {
		room = 0
	}
	if lipgloss.Width(rightText) > room {
		rightText = fitStringWidth(rightText, room)
	}

	pad := inner - leftWidth - lipgloss.Width(rightText)
	if pad < 0 {
		pad = 0
	}
	return statusBarStyle.Width(total).Render(leftText + strings.Repeat(" ", pad) + rightText)
}
// fitStringWidth returns the longest prefix of text (cut on rune boundaries)
// whose rendered width, as measured by lipgloss.Width, does not exceed
// maxWidth. A non-positive maxWidth yields the empty string.
func fitStringWidth(text string, maxWidth int) string {
	switch {
	case maxWidth <= 0:
		return ""
	case lipgloss.Width(text) <= maxWidth:
		return text
	}
	fitted := ""
	for _, r := range text {
		candidate := fitted + string(r)
		if lipgloss.Width(candidate) > maxWidth {
			return fitted
		}
		fitted = candidate
	}
	return fitted
}
// isLogsEnabledFromConfig reports the boolean "logging-to-file" setting of
// cfg. It defaults to true when cfg is nil, the key is absent, or the value is
// not a bool.
func isLogsEnabledFromConfig(cfg map[string]any) bool {
	// Indexing a nil map is safe and yields nil, which fails the assertion,
	// so every "unknown" case collapses into the default below.
	if enabled, isBool := cfg["logging-to-file"].(bool); isBool {
		return enabled
	}
	return true
}
// setAuthInputPrompt sets the password input's prompt from the current
// translation of "auth_gate_password". It is a no-op on a nil receiver.
func (a *App) setAuthInputPrompt() {
	if a != nil {
		a.authInput.Prompt = fmt.Sprintf(" %s: ", T("auth_gate_password"))
	}
}
// connectWithPassword returns a command that installs password as the
// client's secret key, fetches the config with it, and reports the outcome as
// an authConnectMsg.
func (a App) connectWithPassword(password string) tea.Cmd {
	client := a.client
	return func() tea.Msg {
		client.SetSecretKey(password)
		cfg, err := client.GetConfig()
		return authConnectMsg{cfg: cfg, err: err}
	}
}
// Run starts the TUI application in the alternate screen and blocks until the
// program exits, returning the program's error if any.
// output specifies where bubbletea renders. If nil, defaults to os.Stdout.
func Run(port int, secretKey string, hook *LogHook, output io.Writer) error {
	var dest io.Writer = os.Stdout
	if output != nil {
		dest = output
	}
	program := tea.NewProgram(NewApp(port, secretKey, hook), tea.WithAltScreen(), tea.WithOutput(dest))
	if _, err := program.Run(); err != nil {
		return err
	}
	return nil
}
// broadcastToAllTabs delivers msg to every tab model in a fixed order
// (dashboard, config, auth, keys, oauth, usage, logs) and batches all non-nil
// commands the tabs return.
func (a App) broadcastToAllTabs(msg tea.Msg) (tea.Model, tea.Cmd) {
	var pending []tea.Cmd
	collect := func(c tea.Cmd) {
		if c != nil {
			pending = append(pending, c)
		}
	}

	var c tea.Cmd
	a.dashboard, c = a.dashboard.Update(msg)
	collect(c)
	a.config, c = a.config.Update(msg)
	collect(c)
	a.auth, c = a.auth.Update(msg)
	collect(c)
	a.keys, c = a.keys.Update(msg)
	collect(c)
	a.oauth, c = a.oauth.Update(msg)
	collect(c)
	a.usage, c = a.usage.Update(msg)
	collect(c)
	a.logs, c = a.logs.Update(msg)
	collect(c)

	return a, tea.Batch(pending...)
}

456
internal/tui/auth_tab.go Normal file
View File

@@ -0,0 +1,456 @@
package tui
import (
"fmt"
"strconv"
"strings"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// editableField represents an editable field on an auth file.
type editableField struct {
	label string // text shown in the edit prompt
	key   string // API field key: "prefix", "proxy_url", "priority"
}
// authEditableFields lists the auth file fields that can be edited inline.
// The order is significant: the "1", "2" and "3" keys in the auth tab start
// editing indices 0, 1 and 2 of this slice respectively.
var authEditableFields = []editableField{
	{label: "Prefix", key: "prefix"},
	{label: "Proxy URL", key: "proxy_url"},
	{label: "Priority", key: "priority"},
}
// authTabModel displays auth credential files with interactive management.
// It holds the fetched file list, the list UI state (cursor, expanded row,
// pending delete confirmation) and the state of the inline field editor.
type authTabModel struct {
	client   *Client          // management API client used for fetches and mutations
	viewport viewport.Model   // scrollable area holding the rendered content
	files    []map[string]any // auth file records as returned by the API
	err      error            // last fetch error; rendered instead of the list when non-nil
	width    int
	height   int
	ready    bool   // true once SetSize has created the viewport
	cursor   int    // index of the highlighted row in files
	expanded int    // -1 = none expanded, >=0 = expanded index
	confirm  int    // -1 = no confirmation, >=0 = confirm delete for index
	status   string // pre-styled outcome line of the last action
	// Editing state
	editing      bool            // true when editing a field
	editField    int             // index into authEditableFields
	editInput    textinput.Model // text input for editing
	editFileName string          // name of file being edited
}
// authFilesMsg carries the result of fetching the auth file list.
type authFilesMsg struct {
	files []map[string]any // fetched records; only applied when err is nil
	err   error            // fetch error, if any
}
// authActionMsg reports the outcome of a delete, toggle or field update.
type authActionMsg struct {
	action string // human-readable description of the completed action, shown in the status line on success
	err    error  // failure reason; when non-nil it is shown instead of action
}
// newAuthTabModel builds an auth tab bound to client, with no row expanded,
// no delete confirmation pending, and an edit input limited to 256 characters.
func newAuthTabModel(client *Client) authTabModel {
	input := textinput.New()
	input.CharLimit = 256

	model := authTabModel{client: client, editInput: input}
	model.expanded = -1
	model.confirm = -1
	return model
}
// Init returns the command that loads the auth file list.
func (m authTabModel) Init() tea.Cmd {
	return m.fetchFiles
}
// fetchFiles requests the auth file list from the management API and wraps
// the result (or the error) in an authFilesMsg. It is used as a tea.Cmd.
func (m authTabModel) fetchFiles() tea.Msg {
	files, err := m.client.GetAuthFiles()
	return authFilesMsg{files: files, err: err}
}
// Update handles messages for the auth tab: locale changes re-render, fetch
// results replace the file list, action results update the status line and
// trigger a refetch, and key presses are dispatched to the handler for the
// current mode. Anything else is forwarded to the viewport.
func (m authTabModel) Update(msg tea.Msg) (authTabModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		m.viewport.SetContent(m.renderContent())
		return m, nil

	case authFilesMsg:
		m.err = msg.err
		if msg.err == nil {
			m.files = msg.files
			// Keep the cursor on a valid row after the list shrinks.
			if m.cursor >= len(m.files) {
				m.cursor = max(0, len(m.files)-1)
			}
			m.status = ""
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil

	case authActionMsg:
		switch {
		case msg.err != nil:
			m.status = errorStyle.Render("✗ " + msg.err.Error())
		default:
			m.status = successStyle.Render("✓ " + msg.action)
		}
		m.confirm = -1
		m.viewport.SetContent(m.renderContent())
		return m, m.fetchFiles

	case tea.KeyMsg:
		switch {
		case m.editing:
			// ---- Editing mode ----
			return m.handleEditInput(msg)
		case m.confirm >= 0:
			// ---- Delete confirmation mode ----
			return m.handleConfirmInput(msg)
		default:
			// ---- Normal mode ----
			return m.handleNormalInput(msg)
		}
	}

	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
// startEdit activates inline editing for a field on the currently selected
// auth file. fieldIdx indexes authEditableFields. The input is pre-filled with
// the field's current value. It returns nil without changing state when the
// cursor is past the end of the list.
func (m *authTabModel) startEdit(fieldIdx int) tea.Cmd {
	if m.cursor >= len(m.files) {
		return nil
	}
	file := m.files[m.cursor]
	field := authEditableFields[fieldIdx]

	m.editFileName = getString(file, "name")
	m.editField = fieldIdx
	m.editing = true

	// Pre-populate with current value
	m.editInput.SetValue(getAnyString(file, field.key))
	m.editInput.Focus()
	m.editInput.Prompt = fmt.Sprintf(" %s: ", field.label)

	m.viewport.SetContent(m.renderContent())
	return textinput.Blink
}
// SetSize records the available width and height, resizes the edit input and
// the viewport (creating the viewport on the first call), and re-renders the
// content.
func (m *authTabModel) SetSize(w, h int) {
	m.width = w
	m.height = h
	m.editInput.Width = w - 20
	if !m.ready {
		m.viewport = viewport.New(w, h)
		m.ready = true
	} else {
		m.viewport.Width = w
		m.viewport.Height = h
	}
	// The rendered content contains a separator as wide as m.width, so it has
	// to be regenerated on every resize, not only on the first call.
	m.viewport.SetContent(m.renderContent())
}
// View returns the viewport's rendering, or a loading placeholder until
// SetSize has been called for the first time.
func (m authTabModel) View() string {
	if !m.ready {
		return T("loading")
	}
	return m.viewport.View()
}
// renderContent builds the full text of the auth tab: header and help lines,
// a separator, then either the fetch error, an "empty" notice, or one row per
// auth file. Below a row it may add the delete confirmation, the inline edit
// input, and the expanded detail box. The last action's status line is
// appended at the end.
func (m authTabModel) renderContent() string {
	// clip shortens s to at most limit runes, replacing the tail with "...".
	// It works on runes rather than bytes so multi-byte characters are never
	// split, and so the result agrees with fmt's %-Ns padding, which also
	// counts runes.
	clip := func(s string, limit int) string {
		runes := []rune(s)
		if len(runes) <= limit {
			return s
		}
		return string(runes[:limit-3]) + "..."
	}

	var sb strings.Builder
	sb.WriteString(titleStyle.Render(T("auth_title")))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("auth_help1")))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("auth_help2")))
	sb.WriteString("\n")
	// strings.Repeat panics on a negative count, so clamp the width.
	sepWidth := m.width
	if sepWidth < 0 {
		sepWidth = 0
	}
	sb.WriteString(strings.Repeat("─", sepWidth))
	sb.WriteString("\n")

	if m.err != nil {
		sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
		sb.WriteString("\n")
		return sb.String()
	}
	if len(m.files) == 0 {
		sb.WriteString(subtitleStyle.Render(T("no_auth_files")))
		sb.WriteString("\n")
		return sb.String()
	}

	for i, f := range m.files {
		name := getString(f, "name")
		channel := getString(f, "channel")
		email := getString(f, "email")
		disabled := getBool(f, "disabled")

		statusIcon := successStyle.Render("●")
		statusText := T("status_active")
		if disabled {
			statusIcon = lipgloss.NewStyle().Foreground(colorMuted).Render("○")
			statusText = T("status_disabled")
		}

		cursor := " "
		rowStyle := lipgloss.NewStyle()
		if i == m.cursor {
			cursor = "▸ "
			rowStyle = lipgloss.NewStyle().Bold(true)
		}

		displayName := clip(name, 24)
		displayEmail := clip(email, 28)

		row := fmt.Sprintf("%s%s %-24s %-12s %-28s %s",
			cursor, statusIcon, displayName, channel, displayEmail, statusText)
		sb.WriteString(rowStyle.Render(row))
		sb.WriteString("\n")

		// Delete confirmation
		if m.confirm == i {
			sb.WriteString(warningStyle.Render(fmt.Sprintf(" "+T("confirm_delete"), name)))
			sb.WriteString("\n")
		}

		// Inline edit input
		if m.editing && i == m.cursor {
			sb.WriteString(m.editInput.View())
			sb.WriteString("\n")
			sb.WriteString(helpStyle.Render(" " + T("enter_save") + " • " + T("esc_cancel")))
			sb.WriteString("\n")
		}

		// Expanded detail view
		if m.expanded == i {
			sb.WriteString(m.renderDetail(f))
		}
	}

	if m.status != "" {
		sb.WriteString("\n")
		sb.WriteString(m.status)
		sb.WriteString("\n")
	}
	return sb.String()
}
// renderDetail renders the expanded detail box for one auth file record f.
// Fields are listed in a fixed order. A read-only field with no value is
// omitted, while an editable field with no value is shown with the "not_set"
// placeholder. Editable fields are suffixed with a pencil marker.
func (m authTabModel) renderDetail(f map[string]any) string {
	var sb strings.Builder
	labelStyle := lipgloss.NewStyle().
		Foreground(lipgloss.Color("111")).
		Bold(true)
	valueStyle := lipgloss.NewStyle().
		Foreground(lipgloss.Color("252"))
	editableMarker := lipgloss.NewStyle().
		Foreground(lipgloss.Color("214")).
		Render(" ✎")
	sb.WriteString(" ┌─────────────────────────────────────────────\n")
	// Display order of the detail rows; editable marks the fields that the
	// inline editor can change.
	fields := []struct {
		label    string
		key      string
		editable bool
	}{
		{"Name", "name", false},
		{"Channel", "channel", false},
		{"Email", "email", false},
		{"Status", "status", false},
		{"Status Msg", "status_message", false},
		{"File Name", "file_name", false},
		{"Auth Type", "auth_type", false},
		{"Prefix", "prefix", true},
		{"Proxy URL", "proxy_url", true},
		{"Priority", "priority", true},
		{"Project ID", "project_id", false},
		{"Disabled", "disabled", false},
		{"Created", "created_at", false},
		{"Updated", "updated_at", false},
	}
	for _, field := range fields {
		val := getAnyString(f, field.key)
		// "<nil>" is what %v prints for a nil value wrapped in a typed
		// container; treat it like an empty value.
		if val == "" || val == "<nil>" {
			if field.editable {
				val = T("not_set")
			} else {
				continue
			}
		}
		editMark := ""
		if field.editable {
			editMark = editableMarker
		}
		line := fmt.Sprintf(" │ %s %s%s",
			labelStyle.Render(fmt.Sprintf("%-12s:", field.label)),
			valueStyle.Render(val),
			editMark)
		sb.WriteString(line)
		sb.WriteString("\n")
	}
	sb.WriteString(" └─────────────────────────────────────────────\n")
	return sb.String()
}
// getAnyString converts the value stored under key to its string
// representation. A missing key or a nil value yields "".
//
// Whole-number float64 values (the type encoding/json produces for every JSON
// number decoded into an any) are printed as plain integers. With %v alone a
// value such as 1000000 would render as "1e+06", which is hard to read and
// cannot be parsed back with strconv.Atoi when the text is used to pre-fill
// an integer editor.
func getAnyString(m map[string]any, key string) string {
	v, ok := m[key]
	if !ok || v == nil {
		return ""
	}
	// The range guard keeps the int64 conversion exact and well-defined.
	if f, isFloat := v.(float64); isFloat && f >= -1e15 && f <= 1e15 && f == float64(int64(f)) {
		return strconv.FormatInt(int64(f), 10)
	}
	return fmt.Sprintf("%v", v)
}
// max returns the larger of a and b.
func max(a, b int) int {
	if b >= a {
		return b
	}
	return a
}
// handleEditInput processes a key press while the inline editor is active.
// Enter leaves edit mode and returns a command that PATCHes the edited field
// (the "priority" field must parse as an integer, otherwise an error message
// is returned instead). Esc cancels the edit. Any other key is forwarded to
// the text input.
func (m authTabModel) handleEditInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
	pressed := msg.String()

	if pressed == "esc" {
		m.editing = false
		m.editInput.Blur()
		m.viewport.SetContent(m.renderContent())
		return m, nil
	}
	if pressed != "enter" {
		var cmd tea.Cmd
		m.editInput, cmd = m.editInput.Update(msg)
		m.viewport.SetContent(m.renderContent())
		return m, cmd
	}

	// Enter: snapshot what the command needs, then leave edit mode.
	raw := m.editInput.Value()
	fieldKey := authEditableFields[m.editField].key
	target := m.editFileName
	m.editing = false
	m.editInput.Blur()

	payload := map[string]any{}
	if fieldKey != "priority" {
		payload[fieldKey] = raw
	} else {
		number, err := strconv.Atoi(raw)
		if err != nil {
			return m, func() tea.Msg {
				return authActionMsg{err: fmt.Errorf("%s: %s", T("invalid_int"), raw)}
			}
		}
		payload[fieldKey] = number
	}

	client := m.client
	return m, func() tea.Msg {
		if err := client.PatchAuthFileFields(target, payload); err != nil {
			return authActionMsg{err: err}
		}
		return authActionMsg{action: fmt.Sprintf(T("updated_field"), fieldKey, target)}
	}
}
// handleConfirmInput processes a key press while a delete confirmation is
// pending. "y"/"Y" clears the confirmation and, if the index is still within
// the list, returns a command that deletes the file. "n"/"N"/Esc cancels.
// Other keys are ignored.
func (m authTabModel) handleConfirmInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
	switch msg.String() {
	case "y", "Y":
		target := m.confirm
		m.confirm = -1
		if target >= len(m.files) {
			m.viewport.SetContent(m.renderContent())
			return m, nil
		}
		name := getString(m.files[target], "name")
		client := m.client
		return m, func() tea.Msg {
			if err := client.DeleteAuthFile(name); err != nil {
				return authActionMsg{err: err}
			}
			return authActionMsg{action: fmt.Sprintf(T("deleted"), name)}
		}
	case "n", "N", "esc":
		m.confirm = -1
		m.viewport.SetContent(m.renderContent())
		return m, nil
	}
	return m, nil
}
// handleNormalInput processes a key press when neither the editor nor a
// delete confirmation is active:
//
//	j/down, k/up  move the cursor (wrapping around)
//	enter/space   expand or collapse the detail box of the selected row
//	d/D           ask for delete confirmation
//	e/E           toggle the selected file's disabled flag
//	1, 2, 3       edit prefix, proxy_url, priority of the selected file
//	r             clear the status line and refetch the list
//
// Any other key is forwarded to the viewport.
func (m authTabModel) handleNormalInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
	switch msg.String() {
	case "j", "down":
		if len(m.files) > 0 {
			m.cursor = (m.cursor + 1) % len(m.files)
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "k", "up":
		if len(m.files) > 0 {
			m.cursor = (m.cursor - 1 + len(m.files)) % len(m.files)
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "enter", " ":
		if m.expanded == m.cursor {
			m.expanded = -1
		} else {
			m.expanded = m.cursor
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case "d", "D":
		if m.cursor < len(m.files) {
			m.confirm = m.cursor
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "e", "E":
		if m.cursor < len(m.files) {
			f := m.files[m.cursor]
			name := getString(f, "name")
			disabled := getBool(f, "disabled")
			newDisabled := !disabled
			return m, func() tea.Msg {
				err := m.client.ToggleAuthFile(name, newDisabled)
				if err != nil {
					return authActionMsg{err: err}
				}
				action := T("enabled")
				if newDisabled {
					action = T("disabled")
				}
				return authActionMsg{action: fmt.Sprintf("%s %s", action, name)}
			}
		}
		return m, nil
	// startEdit has a pointer receiver and mutates m. It must finish before m
	// is read for the return value: within a single return statement the
	// order between reading m and calling m.startEdit is unspecified, so the
	// call is hoisted into its own statement.
	case "1":
		cmd := m.startEdit(0) // prefix
		return m, cmd
	case "2":
		cmd := m.startEdit(1) // proxy_url
		return m, cmd
	case "3":
		cmd := m.startEdit(2) // priority
		return m, cmd
	case "r":
		m.status = ""
		return m, m.fetchFiles
	default:
		var cmd tea.Cmd
		m.viewport, cmd = m.viewport.Update(msg)
		return m, cmd
	}
}

20
internal/tui/browser.go Normal file
View File

@@ -0,0 +1,20 @@
package tui
import (
"os/exec"
"runtime"
)
// openBrowser opens the specified URL in the user's default browser.
// It starts the platform's opener ("open" on macOS, rundll32 on Windows,
// "xdg-open" on Linux and every other OS) without waiting for it to finish,
// and returns the error from starting the process.
func openBrowser(url string) error {
	var opener *exec.Cmd
	switch runtime.GOOS {
	case "darwin":
		opener = exec.Command("open", url)
	case "windows":
		opener = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
	default: // "linux" and anything unrecognized
		opener = exec.Command("xdg-open", url)
	}
	return opener.Start()
}

400
internal/tui/client.go Normal file
View File

@@ -0,0 +1,400 @@
package tui
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
// Client wraps HTTP calls to the management API.
type Client struct {
	baseURL   string       // scheme, host and port; request paths are appended to it
	secretKey string       // sent as a Bearer token when non-empty
	http      *http.Client // underlying HTTP client used for every request
}
// NewClient creates a new management API client that talks to 127.0.0.1 on
// the given port, authenticates with the (whitespace-trimmed) secretKey, and
// applies a 10 second timeout to every request.
func NewClient(port int, secretKey string) *Client {
	client := &Client{
		baseURL: fmt.Sprintf("http://127.0.0.1:%d", port),
		http:    &http.Client{Timeout: 10 * time.Second},
	}
	client.secretKey = strings.TrimSpace(secretKey)
	return client
}
// SetSecretKey updates management API bearer token used by this client.
// Surrounding whitespace is trimmed; an empty key disables the Authorization
// header on subsequent requests.
func (c *Client) SetSecretKey(secretKey string) {
	c.secretKey = strings.TrimSpace(secretKey)
}
// doRequest sends one request to the management API and returns the response
// body and HTTP status code. The bearer token is attached when a secret key
// is set, and a JSON content type is declared whenever a body is supplied.
// The status code is 0 when the request could not be built or sent. HTTP
// error statuses are not turned into errors here; callers check the code.
func (c *Client) doRequest(method, path string, body io.Reader) ([]byte, int, error) {
	req, err := http.NewRequest(method, c.baseURL+path, body)
	if err != nil {
		return nil, 0, err
	}
	if token := c.secretKey; token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}
	if body != nil {
		req.Header.Set("Content-Type", "application/json")
	}

	resp, err := c.http.Do(req)
	if err != nil {
		return nil, 0, err
	}
	defer resp.Body.Close()

	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, resp.StatusCode, err
	}
	return payload, resp.StatusCode, nil
}
// get issues a GET request and returns the response body. Any status of 400
// or above is reported as an error containing the status and trimmed body.
func (c *Client) get(path string) ([]byte, error) {
	payload, status, err := c.doRequest(http.MethodGet, path, nil)
	switch {
	case err != nil:
		return nil, err
	case status >= 400:
		return nil, fmt.Errorf("HTTP %d: %s", status, strings.TrimSpace(string(payload)))
	}
	return payload, nil
}
// put issues a PUT request with the given body and returns the response body.
// Any status of 400 or above is reported as an error containing the status
// and trimmed body.
func (c *Client) put(path string, body io.Reader) ([]byte, error) {
	payload, status, err := c.doRequest(http.MethodPut, path, body)
	switch {
	case err != nil:
		return nil, err
	case status >= 400:
		return nil, fmt.Errorf("HTTP %d: %s", status, strings.TrimSpace(string(payload)))
	}
	return payload, nil
}
// patch issues a PATCH request with the given body and returns the response
// body. Any status of 400 or above is reported as an error containing the
// status and trimmed body.
func (c *Client) patch(path string, body io.Reader) ([]byte, error) {
	payload, status, err := c.doRequest(http.MethodPatch, path, body)
	switch {
	case err != nil:
		return nil, err
	case status >= 400:
		return nil, fmt.Errorf("HTTP %d: %s", status, strings.TrimSpace(string(payload)))
	}
	return payload, nil
}
// getJSON fetches a path and unmarshals JSON into a generic map.
// It fails when the request fails or the body is not a JSON object.
func (c *Client) getJSON(path string) (map[string]any, error) {
	raw, err := c.get(path)
	if err != nil {
		return nil, err
	}
	var decoded map[string]any
	if err = json.Unmarshal(raw, &decoded); err != nil {
		return nil, err
	}
	return decoded, nil
}
// postJSON sends a JSON body via POST and checks for errors.
// A status of 400 or above is reported as an error that includes the status
// code and the trimmed response body, in the same format the get, put and
// patch helpers use, so the server's explanation is not lost.
func (c *Client) postJSON(path string, body any) error {
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return err
	}
	data, code, err := c.doRequest("POST", path, strings.NewReader(string(jsonBody)))
	if err != nil {
		return err
	}
	if code >= 400 {
		return fmt.Errorf("HTTP %d: %s", code, strings.TrimSpace(string(data)))
	}
	return nil
}
// GetConfig fetches the parsed config as a generic JSON object from
// /v0/management/config.
func (c *Client) GetConfig() (map[string]any, error) {
	return c.getJSON("/v0/management/config")
}
// GetConfigYAML fetches the raw config.yaml content and returns it verbatim
// as a string.
func (c *Client) GetConfigYAML() (string, error) {
	data, err := c.get("/v0/management/config.yaml")
	if err != nil {
		return "", err
	}
	return string(data), nil
}
// PutConfigYAML uploads new config.yaml content via PUT. The response body is
// discarded; only the error is returned.
func (c *Client) PutConfigYAML(yamlContent string) error {
	_, err := c.put("/v0/management/config.yaml", strings.NewReader(yamlContent))
	return err
}
// GetUsage fetches usage statistics as a generic JSON object from
// /v0/management/usage.
func (c *Client) GetUsage() (map[string]any, error) {
	return c.getJSON("/v0/management/usage")
}
// GetAuthFiles lists auth credential files.
// API returns {"files": [...]}; the list under "files" is returned, or nil
// when the key is absent or null.
func (c *Client) GetAuthFiles() ([]map[string]any, error) {
	wrapper, err := c.getJSON("/v0/management/auth-files")
	if err != nil {
		return nil, err
	}
	return extractList(wrapper, "files")
}
// DeleteAuthFile deletes a single auth file by name. The name is passed as a
// URL-encoded query parameter. A status of 400 or above is returned as an
// error.
func (c *Client) DeleteAuthFile(name string) error {
	params := url.Values{"name": {name}}
	_, status, err := c.doRequest("DELETE", "/v0/management/auth-files?"+params.Encode(), nil)
	switch {
	case err != nil:
		return err
	case status >= 400:
		return fmt.Errorf("delete failed (HTTP %d)", status)
	}
	return nil
}
// ToggleAuthFile enables or disables an auth file by PATCHing
// {"name": name, "disabled": disabled} to the status endpoint.
func (c *Client) ToggleAuthFile(name string, disabled bool) error {
	// The Marshal error is ignored: the payload holds only a string and a bool.
	body, _ := json.Marshal(map[string]any{"name": name, "disabled": disabled})
	_, err := c.patch("/v0/management/auth-files/status", strings.NewReader(string(body)))
	return err
}
// PatchAuthFileFields updates editable fields on an auth file by PATCHing
// fields plus a "name" entry identifying the file.
//
// The caller's map is not modified: the payload is built in a copy, so a nil
// fields map is also accepted. Because fields may hold arbitrary values, a
// JSON encoding failure is returned instead of being ignored.
func (c *Client) PatchAuthFileFields(name string, fields map[string]any) error {
	payload := make(map[string]any, len(fields)+1)
	for key, value := range fields {
		payload[key] = value
	}
	payload["name"] = name

	body, err := json.Marshal(payload)
	if err != nil {
		return err
	}
	_, err = c.patch("/v0/management/auth-files/fields", strings.NewReader(string(body)))
	return err
}
// GetLogs fetches log lines from the server.
//
// after and limit are sent as query parameters only when positive. It returns
// the lines found under "lines" (an empty, non-nil slice when there are none)
// and the value of "latest-timestamp", which is never reported lower than
// after. On error the returned timestamp is after, unchanged.
func (c *Client) GetLogs(after int64, limit int) ([]string, int64, error) {
	query := url.Values{}
	if limit > 0 {
		query.Set("limit", strconv.Itoa(limit))
	}
	if after > 0 {
		query.Set("after", strconv.FormatInt(after, 10))
	}
	path := "/v0/management/logs"
	encodedQuery := query.Encode()
	if encodedQuery != "" {
		path += "?" + encodedQuery
	}
	wrapper, err := c.getJSON(path)
	if err != nil {
		return nil, after, err
	}
	lines := []string{}
	if rawLines, ok := wrapper["lines"]; ok && rawLines != nil {
		// Re-encode and decode to convert the generic value into []string;
		// this fails if any element is not a string.
		rawJSON, errMarshal := json.Marshal(rawLines)
		if errMarshal != nil {
			return nil, after, errMarshal
		}
		if errUnmarshal := json.Unmarshal(rawJSON, &lines); errUnmarshal != nil {
			return nil, after, errUnmarshal
		}
	}
	latest := after
	if rawLatest, ok := wrapper["latest-timestamp"]; ok {
		// getJSON decodes with plain json.Unmarshal, which yields float64 for
		// JSON numbers; the other cases are defensive.
		switch value := rawLatest.(type) {
		case float64:
			latest = int64(value)
		case json.Number:
			if parsed, errParse := value.Int64(); errParse == nil {
				latest = parsed
			}
		case int64:
			latest = value
		case int:
			latest = int64(value)
		}
	}
	// Never move the caller's position backwards.
	if latest < after {
		latest = after
	}
	return lines, latest, nil
}
// GetAPIKeys fetches the list of API keys.
// API returns {"api-keys": [...]}. A response without that key yields a nil
// slice and no error.
func (c *Client) GetAPIKeys() ([]string, error) {
	wrapper, err := c.getJSON("/v0/management/api-keys")
	if err != nil {
		return nil, err
	}
	value, present := wrapper["api-keys"]
	if !present {
		return nil, nil
	}
	// Round-trip through JSON to convert the generic value into []string.
	encoded, err := json.Marshal(value)
	if err != nil {
		return nil, err
	}
	var keys []string
	if err = json.Unmarshal(encoded, &keys); err != nil {
		return nil, err
	}
	return keys, nil
}
// AddAPIKey adds a new API key by sending old=nil, new=key which appends.
func (c *Client) AddAPIKey(key string) error {
	body := map[string]any{"old": nil, "new": key}
	// The Marshal error is ignored: the payload holds only nil and a string.
	jsonBody, _ := json.Marshal(body)
	_, err := c.patch("/v0/management/api-keys", strings.NewReader(string(jsonBody)))
	return err
}
// EditAPIKey replaces an API key at the given index by PATCHing
// {"index": index, "value": newValue}.
func (c *Client) EditAPIKey(index int, newValue string) error {
	body := map[string]any{"index": index, "value": newValue}
	// The Marshal error is ignored: the payload holds only an int and a string.
	jsonBody, _ := json.Marshal(body)
	_, err := c.patch("/v0/management/api-keys", strings.NewReader(string(jsonBody)))
	return err
}
// DeleteAPIKey deletes an API key by index. A status of 400 or above is
// returned as an error.
func (c *Client) DeleteAPIKey(index int) error {
	target := fmt.Sprintf("/v0/management/api-keys?index=%d", index)
	_, status, err := c.doRequest("DELETE", target, nil)
	switch {
	case err != nil:
		return err
	case status >= 400:
		return fmt.Errorf("delete failed (HTTP %d)", status)
	}
	return nil
}
// GetGeminiKeys fetches Gemini API keys.
// API returns {"gemini-api-key": [...]}; the entries under that key are
// returned.
func (c *Client) GetGeminiKeys() ([]map[string]any, error) {
	return c.getWrappedKeyList("/v0/management/gemini-api-key", "gemini-api-key")
}
// GetClaudeKeys fetches Claude API keys: the entries found under the
// "claude-api-key" key of the response.
func (c *Client) GetClaudeKeys() ([]map[string]any, error) {
	return c.getWrappedKeyList("/v0/management/claude-api-key", "claude-api-key")
}
// GetCodexKeys fetches Codex API keys: the entries found under the
// "codex-api-key" key of the response.
func (c *Client) GetCodexKeys() ([]map[string]any, error) {
	return c.getWrappedKeyList("/v0/management/codex-api-key", "codex-api-key")
}
// GetVertexKeys fetches Vertex API keys: the entries found under the
// "vertex-api-key" key of the response.
func (c *Client) GetVertexKeys() ([]map[string]any, error) {
	return c.getWrappedKeyList("/v0/management/vertex-api-key", "vertex-api-key")
}
// GetOpenAICompat fetches OpenAI compatibility entries: the list found under
// the "openai-compatibility" key of the response.
func (c *Client) GetOpenAICompat() ([]map[string]any, error) {
	return c.getWrappedKeyList("/v0/management/openai-compatibility", "openai-compatibility")
}
// getWrappedKeyList fetches a wrapped list from the API: it GETs path, expects
// a JSON object, and returns the array of objects stored under key.
func (c *Client) getWrappedKeyList(path, key string) ([]map[string]any, error) {
	wrapper, err := c.getJSON(path)
	if err != nil {
		return nil, err
	}
	return extractList(wrapper, key)
}
// extractList pulls an array of maps from a wrapper object by key.
// A missing or null entry yields (nil, nil). An entry that is not an array of
// JSON objects yields an error.
func extractList(wrapper map[string]any, key string) ([]map[string]any, error) {
	// A missing key and an explicit nil both read back as nil.
	value := wrapper[key]
	if value == nil {
		return nil, nil
	}
	// Round-trip through JSON to convert the generic value into typed maps.
	encoded, err := json.Marshal(value)
	if err != nil {
		return nil, err
	}
	var items []map[string]any
	err = json.Unmarshal(encoded, &items)
	if err != nil {
		return nil, err
	}
	return items, nil
}
// GetDebug fetches the current debug setting. It returns false without an
// error when the response has no boolean "debug" entry.
func (c *Client) GetDebug() (bool, error) {
	wrapper, err := c.getJSON("/v0/management/debug")
	if err != nil {
		return false, err
	}
	// A missing or non-bool entry fails the assertion and leaves false.
	enabled, _ := wrapper["debug"].(bool)
	return enabled, nil
}
// GetAuthStatus polls the OAuth session status for the given state token.
// Returns status ("wait", "ok", "error") and optional error message, both
// read from the "status" and "error" entries of the response.
func (c *Client) GetAuthStatus(state string) (string, string, error) {
	params := url.Values{"state": {state}}
	wrapper, err := c.getJSON("/v0/management/get-auth-status?" + params.Encode())
	if err != nil {
		return "", "", err
	}
	return getString(wrapper, "status"), getString(wrapper, "error"), nil
}
// ----- Config field update methods -----
// PutBoolField updates a boolean config field by PUTting {"value": value} to
// /v0/management/<path>.
func (c *Client) PutBoolField(path string, value bool) error {
	// The Marshal error is ignored: the payload holds only a bool.
	body, _ := json.Marshal(map[string]any{"value": value})
	_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
	return err
}
// PutIntField updates an integer config field by PUTting {"value": value} to
// /v0/management/<path>.
func (c *Client) PutIntField(path string, value int) error {
	// The Marshal error is ignored: the payload holds only an int.
	body, _ := json.Marshal(map[string]any{"value": value})
	_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
	return err
}
// PutStringField updates a string config field by PUTting {"value": value} to
// /v0/management/<path>.
func (c *Client) PutStringField(path string, value string) error {
	// The Marshal error is ignored: the payload holds only a string.
	body, _ := json.Marshal(map[string]any{"value": value})
	_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
	return err
}
// DeleteField sends a DELETE request for a config field at
// /v0/management/<path>. Transport errors are returned as-is, and a status of
// 400 or above is reported as an error (in the same form the other delete
// helpers use) instead of being treated as success.
func (c *Client) DeleteField(path string) error {
	_, code, err := c.doRequest("DELETE", "/v0/management/"+path, nil)
	if err != nil {
		return err
	}
	if code >= 400 {
		return fmt.Errorf("delete failed (HTTP %d)", code)
	}
	return nil
}

413
internal/tui/config_tab.go Normal file
View File

@@ -0,0 +1,413 @@
package tui
import (
"fmt"
"strconv"
"strings"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// configField represents a single editable config field.
type configField struct {
	label    string // text shown in the left column
	apiPath  string // management API path (e.g. "debug", "proxy-url")
	kind     string // "bool", "int", "string", "readonly"
	value    string // current display value
	rawValue any    // raw value from API; when it is a string it is what the editor is pre-filled with
}
// configTabModel displays parsed config with interactive editing.
type configTabModel struct {
	client    *Client         // management API client used for fetches and updates
	viewport  viewport.Model  // scrollable area holding the rendered content
	fields    []configField   // fields parsed from the last successful fetch
	cursor    int             // index of the selected field
	editing   bool            // true while the text input is active for the selected field
	textInput textinput.Model // editor for int/string fields
	err       error           // last fetch error; rendered instead of the fields when non-nil
	message   string          // status message (success/error)
	width     int
	height    int
	ready     bool // true once SetSize has created the viewport
}
// configDataMsg carries the result of fetching the config.
type configDataMsg struct {
	config map[string]any // parsed config; only used when err is nil
	err    error          // fetch error, if any
}
// configUpdateMsg reports the outcome of writing one config field.
type configUpdateMsg struct {
	path  string // management API path of the field that was written
	value any    // value that was sent; unset when validation failed before sending
	err   error  // validation or request error, if any
}
// newConfigTabModel builds a config tab bound to client, with an edit input
// limited to 256 characters.
func newConfigTabModel(client *Client) configTabModel {
	input := textinput.New()
	input.CharLimit = 256

	var model configTabModel
	model.client = client
	model.textInput = input
	return model
}
// Init returns the command that loads the config.
func (m configTabModel) Init() tea.Cmd {
	return m.fetchConfig
}
// fetchConfig requests the parsed config from the management API and wraps
// the result (or the error) in a configDataMsg. It is used as a tea.Cmd.
func (m configTabModel) fetchConfig() tea.Msg {
	cfg, err := m.client.GetConfig()
	return configDataMsg{config: cfg, err: err}
}
// Update handles messages for the config tab: locale changes re-render, fetch
// results replace the field list, update results set the status message and
// trigger a refetch, and key presses go to the editing or normal key handler.
// Anything else is forwarded to the viewport.
func (m configTabModel) Update(msg tea.Msg) (configTabModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		m.viewport.SetContent(m.renderContent())
		return m, nil

	case configDataMsg:
		if msg.err == nil {
			m.err = nil
			m.fields = m.parseConfig(msg.config)
		} else {
			m.err = msg.err
			m.fields = nil
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil

	case configUpdateMsg:
		if msg.err == nil {
			m.message = successStyle.Render(T("updated_ok"))
		} else {
			m.message = errorStyle.Render("✗ " + msg.err.Error())
		}
		m.viewport.SetContent(m.renderContent())
		// Refresh config from server
		return m, m.fetchConfig

	case tea.KeyMsg:
		if !m.editing {
			return m.handleNormalKey(msg)
		}
		return m.handleEditingKey(msg)
	}

	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
// handleNormalKey processes a key press while no field is being edited:
// "r" refetches, up/k and down/j move the cursor (without wrapping), and
// enter/space toggles a bool field or opens the editor for an int/string
// field (read-only fields are ignored). Other keys go to the viewport.
func (m configTabModel) handleNormalKey(msg tea.KeyMsg) (configTabModel, tea.Cmd) {
	switch msg.String() {
	case "r":
		m.message = ""
		return m, m.fetchConfig

	case "up", "k":
		if m.cursor > 0 {
			m.cursor--
			m.viewport.SetContent(m.renderContent())
			// Ensure cursor is visible
			m.ensureCursorVisible()
		}
		return m, nil

	case "down", "j":
		if m.cursor < len(m.fields)-1 {
			m.cursor++
			m.viewport.SetContent(m.renderContent())
			m.ensureCursorVisible()
		}
		return m, nil

	case "enter", " ":
		if m.cursor < 0 || m.cursor >= len(m.fields) {
			return m, nil
		}
		selected := m.fields[m.cursor]
		switch selected.kind {
		case "readonly":
			return m, nil
		case "bool":
			// Toggle directly
			return m, m.toggleBool(m.cursor)
		}
		// Start editing for int/string
		m.editing = true
		m.textInput.SetValue(configFieldEditValue(selected))
		m.textInput.Focus()
		m.viewport.SetContent(m.renderContent())
		return m, textinput.Blink
	}

	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
// handleEditingKey processes a key press while a field is being edited.
// Enter leaves edit mode and submits the typed value for the selected field,
// Esc cancels, and any other key is forwarded to the text input.
func (m configTabModel) handleEditingKey(msg tea.KeyMsg) (configTabModel, tea.Cmd) {
	switch msg.String() {
	case "enter":
		entered := m.textInput.Value()
		m.editing = false
		m.textInput.Blur()
		return m, m.submitEdit(m.cursor, entered)
	case "esc":
		m.editing = false
		m.textInput.Blur()
		m.viewport.SetContent(m.renderContent())
		return m, nil
	}

	var cmd tea.Cmd
	m.textInput, cmd = m.textInput.Update(msg)
	m.viewport.SetContent(m.renderContent())
	return m, cmd
}
// toggleBool returns a command that flips the boolean field at idx (the
// current state is taken from its display value) and reports the outcome as a
// configUpdateMsg.
func (m configTabModel) toggleBool(idx int) tea.Cmd {
	return func() tea.Msg {
		field := m.fields[idx]
		// Anything other than "true" counts as currently off.
		flipped := field.value != "true"
		putErr := m.client.PutBoolField(field.apiPath, flipped)
		return configUpdateMsg{path: field.apiPath, value: flipped, err: putErr}
	}
}
// submitEdit returns a command that writes newValue to the field at idx and
// reports the outcome as a configUpdateMsg. For "int" fields the text must
// parse as an integer; otherwise an error is reported and no request is sent.
//
// The field is looked up when the command is created, not when it runs. The
// field list can be replaced while the editor is open (a failed refresh
// clears it), so idx may no longer be valid. In that case an error message is
// reported rather than panicking inside the command goroutine.
func (m configTabModel) submitEdit(idx int, newValue string) tea.Cmd {
	if idx < 0 || idx >= len(m.fields) {
		return func() tea.Msg {
			return configUpdateMsg{err: fmt.Errorf("config field index %d out of range", idx)}
		}
	}
	f := m.fields[idx]
	client := m.client
	return func() tea.Msg {
		var err error
		var value any
		switch f.kind {
		case "int":
			valueInt, errAtoi := strconv.Atoi(newValue)
			if errAtoi != nil {
				return configUpdateMsg{
					path: f.apiPath,
					err:  fmt.Errorf("%s: %s", T("invalid_int"), newValue),
				}
			}
			value = valueInt
			err = client.PutIntField(f.apiPath, valueInt)
		case "string":
			value = newValue
			err = client.PutStringField(f.apiPath, newValue)
		}
		return configUpdateMsg{
			path:  f.apiPath,
			value: value,
			err:   err,
		}
	}
}
// configFieldEditValue returns the text the editor should start with for f:
// the raw API value when it is a string, otherwise the display value.
func configFieldEditValue(f configField) string {
	switch raw := f.rawValue.(type) {
	case string:
		return raw
	default:
		return f.value
	}
}
// SetSize records the available width and height. The first call creates the
// viewport and fills it with the rendered content; later calls only resize
// the existing viewport.
func (m *configTabModel) SetSize(w, h int) {
	m.width, m.height = w, h
	if m.ready {
		m.viewport.Width, m.viewport.Height = w, h
		return
	}
	m.viewport = viewport.New(w, h)
	m.viewport.SetContent(m.renderContent())
	m.ready = true
}
// ensureCursorVisible scrolls the viewport so that the selected field's
// estimated line is on screen. The estimate is cursor+5: one line per field
// plus a fixed allowance for the header, so it is approximate.
func (m *configTabModel) ensureCursorVisible() {
	line := m.cursor + 5
	// Above the window: scroll up so the line becomes the first visible one.
	if top := m.viewport.YOffset; line < top {
		m.viewport.SetYOffset(line)
	}
	// Below the window: scroll down so the line becomes the last visible one.
	if bottom := m.viewport.YOffset + m.viewport.Height; line >= bottom {
		m.viewport.SetYOffset(line - m.viewport.Height + 1)
	}
}
// View returns the viewport's rendering, or a loading placeholder until
// SetSize has been called for the first time.
func (m configTabModel) View() string {
	if !m.ready {
		return T("loading")
	}
	return m.viewport.View()
}
// renderContent builds the full text of the config tab: title, optional
// status message, help lines, then either the fetch error, an "empty" notice,
// or the list of fields grouped under section headers. The selected field is
// marked with an arrow and highlighted; while editing, its value is replaced
// by the text input.
func (m configTabModel) renderContent() string {
	var sb strings.Builder
	sb.WriteString(titleStyle.Render(T("config_title")))
	sb.WriteString("\n")
	if m.message != "" {
		sb.WriteString(" " + m.message)
		sb.WriteString("\n")
	}
	sb.WriteString(helpStyle.Render(T("config_help1")))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("config_help2")))
	sb.WriteString("\n\n")
	if m.err != nil {
		sb.WriteString(errorStyle.Render(" ⚠ Error: " + m.err.Error()))
		return sb.String()
	}
	if len(m.fields) == 0 {
		sb.WriteString(subtitleStyle.Render(T("no_config")))
		return sb.String()
	}
	currentSection := ""
	for i, f := range m.fields {
		// Section headers: emitted whenever the section derived from the
		// field's API path differs from the previous field's section.
		section := fieldSection(f.apiPath)
		if section != currentSection {
			currentSection = section
			sb.WriteString("\n")
			sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(" ── " + section + " "))
			sb.WriteString("\n")
		}
		isSelected := i == m.cursor
		prefix := " "
		if isSelected {
			prefix = "▸ "
		}
		labelStr := lipgloss.NewStyle().
			Foreground(colorInfo).
			Bold(isSelected).
			Width(32).
			Render(f.label)
		var valueStr string
		if m.editing && isSelected {
			// Show the live text input in place of the stored value.
			valueStr = m.textInput.View()
		} else {
			switch f.kind {
			case "bool":
				if f.value == "true" {
					valueStr = successStyle.Render("● ON")
				} else {
					valueStr = lipgloss.NewStyle().Foreground(colorMuted).Render("○ OFF")
				}
			case "readonly":
				valueStr = lipgloss.NewStyle().Foreground(colorSubtext).Render(f.value)
			default:
				valueStr = valueStyle.Render(f.value)
			}
		}
		line := prefix + labelStr + " " + valueStr
		// The row highlight is skipped while editing.
		if isSelected && !m.editing {
			line = lipgloss.NewStyle().Background(colorSurface).Render(line)
		}
		sb.WriteString(line + "\n")
	}
	return sb.String()
}
// parseConfig turns the decoded configuration map into the ordered list of
// fields shown in the config tab. Each entry holds a label, the API path,
// the field kind ("readonly", "bool", "int" or "string"), the displayed
// value, and an optional raw value. AMP fields are included only when the
// "ampcode" object is present; the API key is displayed masked while its raw
// value is kept alongside.
func (m configTabModel) parseConfig(cfg map[string]any) []configField {
	// number and flag format top-level values the way they are displayed.
	number := func(key string) string { return fmt.Sprintf("%.0f", getFloat(cfg, key)) }
	flag := func(on bool) string { return fmt.Sprintf("%v", on) }

	fields := []configField{
		// Server settings
		{"Port", "port", "readonly", number("port"), nil},
		{"Host", "host", "readonly", getString(cfg, "host"), nil},
		{"Debug", "debug", "bool", flag(getBool(cfg, "debug")), nil},
		{"Proxy URL", "proxy-url", "string", getString(cfg, "proxy-url"), nil},
		{"Request Retry", "request-retry", "int", number("request-retry"), nil},
		{"Max Retry Interval (s)", "max-retry-interval", "int", number("max-retry-interval"), nil},
		{"Force Model Prefix", "force-model-prefix", "string", getString(cfg, "force-model-prefix"), nil},
		// Logging
		{"Logging to File", "logging-to-file", "bool", flag(getBool(cfg, "logging-to-file")), nil},
		{"Logs Max Total Size (MB)", "logs-max-total-size-mb", "int", number("logs-max-total-size-mb"), nil},
		{"Error Logs Max Files", "error-logs-max-files", "int", number("error-logs-max-files"), nil},
		{"Usage Stats Enabled", "usage-statistics-enabled", "bool", flag(getBool(cfg, "usage-statistics-enabled")), nil},
		{"Request Log", "request-log", "bool", flag(getBool(cfg, "request-log")), nil},
		// Quota exceeded
		{"Switch Project on Quota", "quota-exceeded/switch-project", "bool", flag(getBoolNested(cfg, "quota-exceeded", "switch-project")), nil},
		{"Switch Preview Model", "quota-exceeded/switch-preview-model", "bool", flag(getBoolNested(cfg, "quota-exceeded", "switch-preview-model")), nil},
	}

	// Routing: the field is always listed, empty when no routing object exists.
	strategy := ""
	if routing, ok := cfg["routing"].(map[string]any); ok {
		strategy = getString(routing, "strategy")
	}
	fields = append(fields, configField{"Routing Strategy", "routing/strategy", "string", strategy, nil})

	// WebSocket auth
	fields = append(fields, configField{"WebSocket Auth", "ws-auth", "bool", flag(getBool(cfg, "ws-auth")), nil})

	// AMP settings
	amp, hasAmp := cfg["ampcode"].(map[string]any)
	if !hasAmp {
		return fields
	}
	upstreamURL := getString(amp, "upstream-url")
	upstreamAPIKey := getString(amp, "upstream-api-key")
	return append(fields,
		configField{"AMP Upstream URL", "ampcode/upstream-url", "string", upstreamURL, upstreamURL},
		configField{"AMP Upstream API Key", "ampcode/upstream-api-key", "string", maskIfNotEmpty(upstreamAPIKey), upstreamAPIKey},
		configField{"AMP Restrict Mgmt Localhost", "ampcode/restrict-management-to-localhost", "bool", flag(getBool(amp, "restrict-management-to-localhost")), nil},
	)
}
// fieldSection returns the localized section title for a config field,
// chosen from its API path: nested paths are matched by prefix, top-level
// paths by exact name, and anything else falls into the "other" section.
func fieldSection(apiPath string) string {
	key := "section_other"
	switch {
	case strings.HasPrefix(apiPath, "ampcode/"):
		key = "section_ampcode"
	case strings.HasPrefix(apiPath, "quota-exceeded/"):
		key = "section_quota"
	case strings.HasPrefix(apiPath, "routing/"):
		key = "section_routing"
	default:
		switch apiPath {
		case "port", "host", "debug", "proxy-url", "request-retry", "max-retry-interval", "force-model-prefix":
			key = "section_server"
		case "logging-to-file", "logs-max-total-size-mb", "error-logs-max-files", "usage-statistics-enabled", "request-log":
			key = "section_logging"
		case "ws-auth":
			key = "section_websocket"
		}
	}
	return T(key)
}
// getBoolNested walks m through the nested objects named by all but the last
// key and returns the boolean stored under the last key. It returns false
// when no keys are given, when an intermediate value is not an object, or
// when the final value is missing or not a bool.
func getBoolNested(m map[string]any, keys ...string) bool {
	if len(keys) == 0 {
		return false
	}
	last := len(keys) - 1
	node := m
	for _, key := range keys[:last] {
		child, ok := node[key].(map[string]any)
		if !ok {
			return false
		}
		node = child
	}
	return getBool(node, keys[last])
}
func maskIfNotEmpty(s string) string {
if s == "" {
return T("not_set")
}
return maskKey(s)
}

360
internal/tui/dashboard.go Normal file
View File

@@ -0,0 +1,360 @@
package tui
import (
	"encoding/json"
	"fmt"
	"sort"
	"strings"

	"github.com/charmbracelet/bubbles/viewport"
	tea "github.com/charmbracelet/bubbletea"
	"github.com/charmbracelet/lipgloss"
)
// dashboardModel displays server info, stats cards, and config overview.
type dashboardModel struct {
// client is the management API client used by fetchData.
client *Client
// viewport is the scrollable view that shows content.
viewport viewport.Model
// content is the most recently rendered dashboard (or error) text.
content string
// err is the error from the latest fetch; nil after a successful one.
err error
// width and height are the dimensions last passed to SetSize.
width int
height int
// ready becomes true once SetSize has created the viewport.
ready bool
// Cached data for re-rendering on locale change
lastConfig map[string]any
lastUsage map[string]any
lastAuthFiles []map[string]any
lastAPIKeys []string
}
// dashboardDataMsg carries the result of one dashboard fetch: the decoded
// config, usage data, auth file entries and API keys, plus the first error
// encountered among those requests (nil when all succeeded).
type dashboardDataMsg struct {
config map[string]any
usage map[string]any
authFiles []map[string]any
apiKeys []string
err error
}
// newDashboardModel returns a dashboard model bound to client. The viewport
// is created later, on the first SetSize call.
func newDashboardModel(client *Client) dashboardModel {
	var m dashboardModel
	m.client = client
	return m
}
// Init returns fetchData as the initial command, so the dashboard requests
// its data as soon as the command is run.
func (m dashboardModel) Init() tea.Cmd {
return m.fetchData
}
// fetchData requests config, usage, auth files and API keys from the
// management API, one after another, and packs the results into a
// dashboardDataMsg. All four requests are always made; the message's err is
// the first failure in that order, or nil if none failed.
func (m dashboardModel) fetchData() tea.Msg {
	var (
		result                             dashboardDataMsg
		cfgErr, usageErr, authErr, keysErr error
	)
	result.config, cfgErr = m.client.GetConfig()
	result.usage, usageErr = m.client.GetUsage()
	result.authFiles, authErr = m.client.GetAuthFiles()
	result.apiKeys, keysErr = m.client.GetAPIKeys()

	switch {
	case cfgErr != nil:
		result.err = cfgErr
	case usageErr != nil:
		result.err = usageErr
	case authErr != nil:
		result.err = authErr
	case keysErr != nil:
		result.err = keysErr
	}
	return result
}
// Update handles messages for the dashboard tab.
//
//   - localeChangedMsg: re-renders immediately in the new locale (the error
//     banner if the last fetch failed, otherwise the dashboard from cached
//     data) and starts a fresh fetch.
//   - dashboardDataMsg: stores the fetched data, or the error, and re-renders.
//   - "r" key: starts a fresh fetch.
//
// Every other message is forwarded to the viewport (scrolling etc.).
func (m dashboardModel) Update(msg tea.Msg) (dashboardModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		if m.err != nil {
			// Keep showing the failure; switching language must not replace
			// it with a dashboard built from stale or missing data.
			m.content = errorStyle.Render(T("error_prefix") + m.err.Error())
		} else {
			m.content = m.renderDashboard(m.lastConfig, m.lastUsage, m.lastAuthFiles, m.lastAPIKeys)
		}
		m.viewport.SetContent(m.content)
		// Also fetch fresh data in background
		return m, m.fetchData

	case dashboardDataMsg:
		if msg.err != nil {
			m.err = msg.err
			// Localized prefix, matching the keys tab's error banner.
			m.content = errorStyle.Render(T("error_prefix") + msg.err.Error())
		} else {
			m.err = nil
			// Cache data for locale switching
			m.lastConfig = msg.config
			m.lastUsage = msg.usage
			m.lastAuthFiles = msg.authFiles
			m.lastAPIKeys = msg.apiKeys
			m.content = m.renderDashboard(msg.config, msg.usage, msg.authFiles, msg.apiKeys)
		}
		m.viewport.SetContent(m.content)
		return m, nil

	case tea.KeyMsg:
		if msg.String() == "r" {
			return m, m.fetchData
		}
	}

	// Anything not consumed above goes to the viewport.
	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
// SetSize records the available width and height. The first call creates the
// viewport and loads the current content into it; later calls only resize
// the existing viewport.
func (m *dashboardModel) SetSize(w, h int) {
	m.width, m.height = w, h
	if m.ready {
		m.viewport.Width = w
		m.viewport.Height = h
		return
	}
	m.viewport = viewport.New(w, h)
	m.viewport.SetContent(m.content)
	m.ready = true
}
// View returns the viewport's rendering, or the localized loading text until
// SetSize has initialized the viewport.
func (m dashboardModel) View() string {
	if m.ready {
		return m.viewport.View()
	}
	return T("loading")
}
// renderDashboard builds the dashboard text from the given data: title and
// help line, connection line with the client's base URL, four stats cards
// (API keys, auth files, requests, tokens), a summary of selected config
// values, and a per-model usage table. Nil cfg or usage maps are tolerated;
// the corresponding sections are then empty or show zeros.
func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []map[string]any, apiKeys []string) string {
	var sb strings.Builder

	sb.WriteString(titleStyle.Render(T("dashboard_title")))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("dashboard_help")))
	sb.WriteString("\n\n")

	// ━━━ Connection Status ━━━
	connStyle := lipgloss.NewStyle().Bold(true).Foreground(colorSuccess)
	sb.WriteString(connStyle.Render(T("connected")))
	sb.WriteString(fmt.Sprintf(" %s", m.client.baseURL))
	sb.WriteString("\n\n")

	// ━━━ Stats Cards ━━━
	// Four cards share the width; never shrink a card below 18 columns.
	cardWidth := 25
	if m.width > 0 {
		cardWidth = (m.width - 6) / 4
		if cardWidth < 18 {
			cardWidth = 18
		}
	}
	cardStyle := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(lipgloss.Color("240")).
		Padding(0, 1).
		Width(cardWidth).
		Height(2)

	// Card 1: API Keys
	keyCount := len(apiKeys)
	card1 := cardStyle.Render(fmt.Sprintf(
		"%s\n%s",
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("🔑 %d", keyCount)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(T("mgmt_keys")),
	))

	// Card 2: Auth Files
	authCount := len(authFiles)
	activeAuth := 0
	for _, f := range authFiles {
		if !getBool(f, "disabled") {
			activeAuth++
		}
	}
	card2 := cardStyle.Render(fmt.Sprintf(
		"%s\n%s",
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("📄 %d", authCount)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (%d %s)", T("auth_files_label"), activeAuth, T("active_suffix"))),
	))

	// Card 3: Total Requests
	totalReqs := int64(0)
	successReqs := int64(0)
	failedReqs := int64(0)
	totalTokens := int64(0)
	if usage != nil {
		if usageMap, ok := usage["usage"].(map[string]any); ok {
			totalReqs = int64(getFloat(usageMap, "total_requests"))
			successReqs = int64(getFloat(usageMap, "success_count"))
			failedReqs = int64(getFloat(usageMap, "failure_count"))
			totalTokens = int64(getFloat(usageMap, "total_tokens"))
		}
	}
	card3 := cardStyle.Render(fmt.Sprintf(
		"%s\n%s",
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(fmt.Sprintf("📈 %d", totalReqs)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (✓%d ✗%d)", T("total_requests"), successReqs, failedReqs)),
	))

	// Card 4: Total Tokens
	tokenStr := formatLargeNumber(totalTokens)
	card4 := cardStyle.Render(fmt.Sprintf(
		"%s\n%s",
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("🔤 %s", tokenStr)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(T("total_tokens")),
	))

	sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
	sb.WriteString("\n\n")

	// ━━━ Current Config ━━━
	sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("current_config")))
	sb.WriteString("\n")
	sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
	sb.WriteString("\n")

	if cfg != nil {
		debug := getBool(cfg, "debug")
		retry := getFloat(cfg, "request-retry")
		proxyURL := getString(cfg, "proxy-url")
		loggingToFile := getBool(cfg, "logging-to-file")
		// Usage statistics are shown as enabled unless the config holds an
		// explicit boolean saying otherwise.
		usageEnabled := true
		if v, ok := cfg["usage-statistics-enabled"]; ok {
			if b, ok2 := v.(bool); ok2 {
				usageEnabled = b
			}
		}

		type configItem struct {
			label string
			value string
		}
		configItems := []configItem{
			{T("debug_mode"), boolEmoji(debug)},
			{T("usage_stats"), boolEmoji(usageEnabled)},
			{T("log_to_file"), boolEmoji(loggingToFile)},
			{T("retry_count"), fmt.Sprintf("%.0f", retry)},
		}
		// The proxy URL line is shown only when a proxy is configured.
		if proxyURL != "" {
			configItems = append(configItems, configItem{T("proxy_url"), proxyURL})
		}

		// Render config items as a compact row
		for _, item := range configItems {
			sb.WriteString(fmt.Sprintf(" %s %s\n",
				labelStyle.Render(item.label+":"),
				valueStyle.Render(item.value)))
		}

		// Routing strategy
		strategy := "round-robin"
		if routing, ok := cfg["routing"].(map[string]any); ok {
			if s := getString(routing, "strategy"); s != "" {
				strategy = s
			}
		}
		sb.WriteString(fmt.Sprintf(" %s %s\n",
			labelStyle.Render(T("routing_strategy")+":"),
			valueStyle.Render(strategy)))
	}
	sb.WriteString("\n")

	// ━━━ Per-Model Usage ━━━
	m.renderModelStats(&sb, usage)

	return sb.String()
}

// renderModelStats appends the per-model usage table (model, requests,
// tokens) to sb. Nothing is written when usage has no non-empty "apis"
// object. APIs and models are visited in sorted key order so the rows keep
// the same order from one render to the next.
func (m dashboardModel) renderModelStats(sb *strings.Builder, usage map[string]any) {
	usageMap, ok := usage["usage"].(map[string]any)
	if !ok {
		return
	}
	apis, ok := usageMap["apis"].(map[string]any)
	if !ok || len(apis) == 0 {
		return
	}

	// sortedKeys returns the keys of src in ascending order.
	sortedKeys := func(src map[string]any) []string {
		keys := make([]string, 0, len(src))
		for k := range src {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		return keys
	}

	sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("model_stats")))
	sb.WriteString("\n")
	sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
	sb.WriteString("\n")
	header := fmt.Sprintf(" %-40s %10s %12s", T("model"), T("requests"), T("tokens"))
	sb.WriteString(tableHeaderStyle.Render(header))
	sb.WriteString("\n")

	for _, apiName := range sortedKeys(apis) {
		apiMap, ok := apis[apiName].(map[string]any)
		if !ok {
			continue
		}
		models, ok := apiMap["models"].(map[string]any)
		if !ok {
			continue
		}
		for _, model := range sortedKeys(models) {
			stats, ok := models[model].(map[string]any)
			if !ok {
				continue
			}
			reqs := int64(getFloat(stats, "total_requests"))
			toks := int64(getFloat(stats, "total_tokens"))
			row := fmt.Sprintf(" %-40s %10d %12s", truncate(model, 40), reqs, formatLargeNumber(toks))
			sb.WriteString(tableCellStyle.Render(row))
			sb.WriteString("\n")
		}
	}
}
func formatKV(key, value string) string {
return fmt.Sprintf(" %s %s\n", labelStyle.Render(key+":"), valueStyle.Render(value))
}
// getString returns the string stored in m under key, or "" when the key is
// missing or holds a non-string value.
func getString(m map[string]any, key string) string {
	s, _ := m[key].(string)
	return s
}
// getFloat returns the number stored in m under key as a float64. It accepts
// float64 and json.Number values (a json.Number conversion error is
// ignored); a missing key or any other type yields 0.
func getFloat(m map[string]any, key string) float64 {
	switch n := m[key].(type) {
	case float64:
		return n
	case json.Number:
		f, _ := n.Float64()
		return f
	default:
		return 0
	}
}
// getBool returns the boolean stored in m under key, or false when the key
// is missing or holds a non-bool value.
func getBool(m map[string]any, key string) bool {
	b, ok := m[key].(bool)
	return ok && b
}
// boolEmoji returns the localized yes/no text for b.
func boolEmoji(b bool) string {
	key := "bool_no"
	if b {
		key = "bool_yes"
	}
	return T(key)
}
// formatLargeNumber formats n compactly: values of a million or more as
// "x.yM", values of a thousand or more as "x.yK", anything smaller as a
// plain integer. Values so close to one million that the one-decimal
// rounding reaches 1000.0 thousands are shown as "1.0M", not "1000.0K".
func formatLargeNumber(n int64) string {
	if n >= 1_000_000 {
		return fmt.Sprintf("%.1fM", float64(n)/1_000_000)
	}
	if n >= 1_000 {
		thousands := fmt.Sprintf("%.1f", float64(n)/1_000)
		if thousands == "1000.0" {
			// Rounding carried the value over into the next unit.
			return "1.0M"
		}
		return thousands + "K"
	}
	return fmt.Sprintf("%d", n)
}
// truncate shortens s to at most maxLen characters, replacing the tail with
// "..." when it has to cut. Length is counted in runes, not bytes, so a
// multi-byte UTF-8 character is never split. When maxLen is below 3 there is
// no room for the ellipsis and s is simply cut to maxLen runes; a
// non-positive maxLen yields "".
//
// Note: runes are not display columns; wide characters may still take more
// than maxLen terminal cells.
func truncate(s string, maxLen int) string {
	if maxLen <= 0 {
		return ""
	}
	runes := []rune(s)
	if len(runes) <= maxLen {
		return s
	}
	if maxLen < 3 {
		return string(runes[:maxLen])
	}
	return string(runes[:maxLen-3]) + "..."
}
// minInt returns the smaller of a and b.
func minInt(a, b int) int {
	if b < a {
		return b
	}
	return a
}

364
internal/tui/i18n.go Normal file
View File

@@ -0,0 +1,364 @@
package tui
// i18n provides a simple internationalization system for the TUI.
// Supported locales: "en" (English, default), "zh" (Chinese).
//
// currentLocale is the active locale code. It starts as "en" and is changed
// by SetLocale and ToggleLocale.
var currentLocale = "en"
// SetLocale makes locale the active locale. Codes that have no entry in the
// locales table are ignored and the current locale stays unchanged.
func SetLocale(locale string) {
	_, known := locales[locale]
	if !known {
		return
	}
	currentLocale = locale
}
// CurrentLocale returns the active locale code, i.e. the current value of
// currentLocale ("en" unless it has been changed).
func CurrentLocale() string {
return currentLocale
}
// ToggleLocale flips the active locale: "zh" becomes "en", and any other
// value becomes "zh".
func ToggleLocale() {
	switch currentLocale {
	case "zh":
		currentLocale = "en"
	default:
		currentLocale = "zh"
	}
}
// T returns the translated string for the given key.
func T(key string) string {
if m, ok := locales[currentLocale]; ok {
if v, ok := m[key]; ok {
return v
}
}
// Fallback to English
if m, ok := locales["en"]; ok {
if v, ok := m[key]; ok {
return v
}
}
return key
}
// locales maps each supported locale code to its translation table. SetLocale
// accepts only codes present here, and T looks translations up through it.
var locales = map[string]map[string]string{
"zh": zhStrings,
"en": enStrings,
}
// ──────────────────────────────────────────
// Tab names
// ──────────────────────────────────────────
// zhTabNames and enTabNames list the tab titles in Chinese and English.
// Both slices have seven entries in the same order, so a given index refers
// to the same tab in either language.
var zhTabNames = []string{"仪表盘", "配置", "认证文件", "API 密钥", "OAuth", "使用统计", "日志"}
var enTabNames = []string{"Dashboard", "Config", "Auth Files", "API Keys", "OAuth", "Usage", "Logs"}
// TabNames returns the tab titles for the active locale: the Chinese list
// when the locale is "zh", the English list otherwise. The returned slice is
// the shared package-level slice, not a copy.
func TabNames() []string {
	if currentLocale != "zh" {
		return enTabNames
	}
	return zhTabNames
}
// zhStrings is the Chinese ("zh") translation table, keyed by the message
// keys passed to T. Entries containing %s or %d are format strings for
// callers to fill in. A key missing here falls back to enStrings in T.
var zhStrings = map[string]string{
// ── Common ──
"loading": "加载中...",
"refresh": "刷新",
"save": "保存",
"cancel": "取消",
"confirm": "确认",
"yes": "是",
"no": "否",
"error": "错误",
"success": "成功",
"navigate": "导航",
"scroll": "滚动",
"enter_save": "Enter: 保存",
"esc_cancel": "Esc: 取消",
"enter_submit": "Enter: 提交",
"press_r": "[r] 刷新",
"press_scroll": "[↑↓] 滚动",
"not_set": "(未设置)",
"error_prefix": "⚠ 错误: ",
// ── Status bar ──
"status_left": " CLIProxyAPI 管理终端",
"status_right": "Tab/Shift+Tab: 切换 • L: 语言 • q/Ctrl+C: 退出 ",
"initializing_tui": "正在初始化...",
"auth_gate_title": "🔐 连接管理 API",
"auth_gate_help": " 请输入管理密码并按 Enter 连接",
"auth_gate_password": "密码",
"auth_gate_enter": " Enter: 连接 • q/Ctrl+C: 退出 • L: 语言",
"auth_gate_connecting": "正在连接...",
"auth_gate_connect_fail": "连接失败:%s",
"auth_gate_password_required": "请输入密码",
// ── Dashboard ──
"dashboard_title": "📊 仪表盘",
"dashboard_help": " [r] 刷新 • [↑↓] 滚动",
"connected": "● 已连接",
"mgmt_keys": "管理密钥",
"auth_files_label": "认证文件",
"active_suffix": "活跃",
"total_requests": "请求",
"success_label": "成功",
"failure_label": "失败",
"total_tokens": "总 Tokens",
"current_config": "当前配置",
"debug_mode": "启用调试模式",
"usage_stats": "启用使用统计",
"log_to_file": "启用日志记录到文件",
"retry_count": "重试次数",
"proxy_url": "代理 URL",
"routing_strategy": "路由策略",
"model_stats": "模型统计",
"model": "模型",
"requests": "请求数",
"tokens": "Tokens",
"bool_yes": "是 ✓",
"bool_no": "否",
// ── Config ──
"config_title": "⚙ 配置",
"config_help1": " [↑↓/jk] 导航 • [Enter/Space] 编辑 • [r] 刷新",
"config_help2": " 布尔: Enter 切换 • 文本/数字: Enter 输入, Enter 确认, Esc 取消",
"updated_ok": "✓ 更新成功",
"no_config": " 未加载配置",
"invalid_int": "无效整数",
"section_server": "服务器",
"section_logging": "日志与统计",
"section_quota": "配额超限处理",
"section_routing": "路由",
"section_websocket": "WebSocket",
"section_ampcode": "AMP Code",
"section_other": "其他",
// ── Auth Files ──
"auth_title": "🔑 认证文件",
"auth_help1": " [↑↓/jk] 导航 • [Enter] 展开 • [e] 启用/停用 • [d] 删除 • [r] 刷新",
"auth_help2": " [1] 编辑 prefix • [2] 编辑 proxy_url • [3] 编辑 priority",
"no_auth_files": " 无认证文件",
"confirm_delete": "⚠ 删除 %s? [y/n]",
"deleted": "已删除 %s",
"enabled": "已启用",
"disabled": "已停用",
"updated_field": "已更新 %s 的 %s",
"status_active": "活跃",
"status_disabled": "已停用",
// ── API Keys ──
"keys_title": "🔐 API 密钥",
"keys_help": " [↑↓/jk] 导航 • [a] 添加 • [e] 编辑 • [d] 删除 • [c] 复制 • [r] 刷新",
"no_keys": " 无 API Key按 [a] 添加",
"access_keys": "Access API Keys",
"confirm_delete_key": "⚠ 确认删除 %s? [y/n]",
"key_added": "已添加 API Key",
"key_updated": "已更新 API Key",
"key_deleted": "已删除 API Key",
"copied": "✓ 已复制到剪贴板",
"copy_failed": "✗ 复制失败",
"new_key_prompt": " New Key: ",
"edit_key_prompt": " Edit Key: ",
"enter_add": " Enter: 添加 • Esc: 取消",
"enter_save_esc": " Enter: 保存 • Esc: 取消",
// ── OAuth ──
"oauth_title": "🔐 OAuth 登录",
"oauth_select": " 选择提供商并按 [Enter] 开始 OAuth 登录:",
"oauth_help": " [↑↓/jk] 导航 • [Enter] 登录 • [Esc] 清除状态",
"oauth_initiating": "⏳ 正在初始化 %s 登录...",
"oauth_success": "认证成功! 请刷新 Auth Files 标签查看新凭证。",
"oauth_completed": "认证流程已完成。",
"oauth_failed": "认证失败",
"oauth_timeout": "OAuth 流程超时 (5 分钟)",
"oauth_press_esc": " 按 [Esc] 取消",
"oauth_auth_url": " 授权链接:",
"oauth_remote_hint": " 远程浏览器模式:在浏览器中打开上述链接完成授权后,将回调 URL 粘贴到下方。",
"oauth_callback_url": " 回调 URL:",
"oauth_press_c": " 按 [c] 输入回调 URL • [Esc] 返回",
"oauth_submitting": "⏳ 提交回调中...",
"oauth_submit_ok": "✓ 回调已提交,等待处理...",
"oauth_submit_fail": "✗ 提交回调失败",
"oauth_waiting": " 等待认证中...",
// ── Usage ──
"usage_title": "📈 使用统计",
"usage_help": " [r] 刷新 • [↑↓] 滚动",
"usage_no_data": " 使用数据不可用",
"usage_total_reqs": "总请求数",
"usage_total_tokens": "总 Token 数",
"usage_success": "成功",
"usage_failure": "失败",
"usage_total_token_l": "总Token",
"usage_rpm": "RPM",
"usage_tpm": "TPM",
"usage_req_by_hour": "请求趋势 (按小时)",
"usage_tok_by_hour": "Token 使用趋势 (按小时)",
"usage_req_by_day": "请求趋势 (按天)",
"usage_api_detail": "API 详细统计",
"usage_input": "输入",
"usage_output": "输出",
"usage_cached": "缓存",
"usage_reasoning": "思考",
// ── Logs ──
"logs_title": "📋 日志",
"logs_auto_scroll": "● 自动滚动",
"logs_paused": "○ 已暂停",
"logs_filter": "过滤",
"logs_lines": "行数",
"logs_help": " [a] 自动滚动 • [c] 清除 • [1] 全部 [2] info+ [3] warn+ [4] error • [↑↓] 滚动",
"logs_waiting": " 等待日志输出...",
}
// enStrings is the English ("en") translation table, keyed by the message
// keys passed to T. Entries containing %s or %d are format strings for
// callers to fill in. T also uses this table as the fallback when the active
// locale has no entry for a key.
var enStrings = map[string]string{
// ── Common ──
"loading": "Loading...",
"refresh": "Refresh",
"save": "Save",
"cancel": "Cancel",
"confirm": "Confirm",
"yes": "Yes",
"no": "No",
"error": "Error",
"success": "Success",
"navigate": "Navigate",
"scroll": "Scroll",
"enter_save": "Enter: Save",
"esc_cancel": "Esc: Cancel",
"enter_submit": "Enter: Submit",
"press_r": "[r] Refresh",
"press_scroll": "[↑↓] Scroll",
"not_set": "(not set)",
"error_prefix": "⚠ Error: ",
// ── Status bar ──
"status_left": " CLIProxyAPI Management TUI",
"status_right": "Tab/Shift+Tab: switch • L: lang • q/Ctrl+C: quit ",
"initializing_tui": "Initializing...",
"auth_gate_title": "🔐 Connect Management API",
"auth_gate_help": " Enter management password and press Enter to connect",
"auth_gate_password": "Password",
"auth_gate_enter": " Enter: connect • q/Ctrl+C: quit • L: lang",
"auth_gate_connecting": "Connecting...",
"auth_gate_connect_fail": "Connection failed: %s",
"auth_gate_password_required": "password is required",
// ── Dashboard ──
"dashboard_title": "📊 Dashboard",
"dashboard_help": " [r] Refresh • [↑↓] Scroll",
"connected": "● Connected",
"mgmt_keys": "Mgmt Keys",
"auth_files_label": "Auth Files",
"active_suffix": "active",
"total_requests": "Requests",
"success_label": "Success",
"failure_label": "Failed",
"total_tokens": "Total Tokens",
"current_config": "Current Config",
"debug_mode": "Debug Mode",
"usage_stats": "Usage Statistics",
"log_to_file": "Log to File",
"retry_count": "Retry Count",
"proxy_url": "Proxy URL",
"routing_strategy": "Routing Strategy",
"model_stats": "Model Stats",
"model": "Model",
"requests": "Requests",
"tokens": "Tokens",
"bool_yes": "Yes ✓",
"bool_no": "No",
// ── Config ──
"config_title": "⚙ Configuration",
"config_help1": " [↑↓/jk] Navigate • [Enter/Space] Edit • [r] Refresh",
"config_help2": " Bool: Enter to toggle • String/Int: Enter to type, Enter to confirm, Esc to cancel",
"updated_ok": "✓ Updated successfully",
"no_config": " No configuration loaded",
"invalid_int": "invalid integer",
"section_server": "Server",
"section_logging": "Logging & Stats",
"section_quota": "Quota Exceeded Handling",
"section_routing": "Routing",
"section_websocket": "WebSocket",
"section_ampcode": "AMP Code",
"section_other": "Other",
// ── Auth Files ──
"auth_title": "🔑 Auth Files",
"auth_help1": " [↑↓/jk] Navigate • [Enter] Expand • [e] Enable/Disable • [d] Delete • [r] Refresh",
"auth_help2": " [1] Edit prefix • [2] Edit proxy_url • [3] Edit priority",
"no_auth_files": " No auth files found",
"confirm_delete": "⚠ Delete %s? [y/n]",
"deleted": "Deleted %s",
"enabled": "Enabled",
"disabled": "Disabled",
"updated_field": "Updated %s on %s",
"status_active": "active",
"status_disabled": "disabled",
// ── API Keys ──
"keys_title": "🔐 API Keys",
"keys_help": " [↑↓/jk] Navigate • [a] Add • [e] Edit • [d] Delete • [c] Copy • [r] Refresh",
"no_keys": " No API Keys. Press [a] to add",
"access_keys": "Access API Keys",
"confirm_delete_key": "⚠ Delete %s? [y/n]",
"key_added": "API Key added",
"key_updated": "API Key updated",
"key_deleted": "API Key deleted",
"copied": "✓ Copied to clipboard",
"copy_failed": "✗ Copy failed",
"new_key_prompt": " New Key: ",
"edit_key_prompt": " Edit Key: ",
"enter_add": " Enter: Add • Esc: Cancel",
"enter_save_esc": " Enter: Save • Esc: Cancel",
// ── OAuth ──
"oauth_title": "🔐 OAuth Login",
"oauth_select": " Select a provider and press [Enter] to start OAuth login:",
"oauth_help": " [↑↓/jk] Navigate • [Enter] Login • [Esc] Clear status",
"oauth_initiating": "⏳ Initiating %s login...",
"oauth_success": "Authentication successful! Refresh Auth Files tab to see the new credential.",
"oauth_completed": "Authentication flow completed.",
"oauth_failed": "Authentication failed",
"oauth_timeout": "OAuth flow timed out (5 minutes)",
"oauth_press_esc": " Press [Esc] to cancel",
"oauth_auth_url": " Authorization URL:",
"oauth_remote_hint": " Remote browser mode: Open the URL above in browser, paste the callback URL below after authorization.",
"oauth_callback_url": " Callback URL:",
"oauth_press_c": " Press [c] to enter callback URL • [Esc] to go back",
"oauth_submitting": "⏳ Submitting callback...",
"oauth_submit_ok": "✓ Callback submitted, waiting...",
"oauth_submit_fail": "✗ Callback submission failed",
"oauth_waiting": " Waiting for authentication...",
// ── Usage ──
"usage_title": "📈 Usage Statistics",
"usage_help": " [r] Refresh • [↑↓] Scroll",
"usage_no_data": " Usage data not available",
"usage_total_reqs": "Total Requests",
"usage_total_tokens": "Total Tokens",
"usage_success": "Success",
"usage_failure": "Failed",
"usage_total_token_l": "Total Tokens",
"usage_rpm": "RPM",
"usage_tpm": "TPM",
"usage_req_by_hour": "Requests by Hour",
"usage_tok_by_hour": "Token Usage by Hour",
"usage_req_by_day": "Requests by Day",
"usage_api_detail": "API Detail Statistics",
"usage_input": "Input",
"usage_output": "Output",
"usage_cached": "Cached",
"usage_reasoning": "Reasoning",
// ── Logs ──
"logs_title": "📋 Logs",
"logs_auto_scroll": "● AUTO-SCROLL",
"logs_paused": "○ PAUSED",
"logs_filter": "Filter",
"logs_lines": "Lines",
"logs_help": " [a] Auto-scroll • [c] Clear • [1] All [2] info+ [3] warn+ [4] error • [↑↓] Scroll",
"logs_waiting": " Waiting for log output...",
}

405
internal/tui/keys_tab.go Normal file
View File

@@ -0,0 +1,405 @@
package tui
import (
"fmt"
"strings"
"github.com/atotto/clipboard"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// keysTabModel displays and manages API keys.
type keysTabModel struct {
// client is the management API client used for all key operations.
client *Client
// viewport is the scrollable view that shows the rendered content.
viewport viewport.Model
// keys holds the access API keys; these are the entries the cursor moves
// over and that can be added, edited, deleted or copied.
keys []string
// Provider key entries, rendered as read-only lists below the access keys.
gemini []map[string]any
claude []map[string]any
codex []map[string]any
vertex []map[string]any
openai []map[string]any
// err is the error from the latest key fetch; nil after a successful one.
err error
// width and height are the dimensions last passed to SetSize.
width int
height int
// ready becomes true once SetSize has created the viewport.
ready bool
// cursor is the index of the selected entry in keys.
cursor int
confirm int // index of the key awaiting delete confirmation; -1 = no deletion pending
// status is the already-styled result line of the last action.
status string
// Editing / Adding
// editing and adding select the input mode; editIdx is the index of the
// key being edited, and editInput is the text input shared by both modes.
editing bool
adding bool
editIdx int
editInput textinput.Model
}
// keysDataMsg carries the result of one key fetch: the access API keys, the
// per-provider key entries, and the error from loading the access keys
// (nil on success).
type keysDataMsg struct {
apiKeys []string
gemini []map[string]any
claude []map[string]any
codex []map[string]any
vertex []map[string]any
openai []map[string]any
err error
}
// keyActionMsg reports the outcome of an add, edit or delete request.
// On success action holds the localized confirmation text and err is nil;
// on failure err is set.
type keyActionMsg struct {
action string
err error
}
// newKeysTabModel returns a keys tab model bound to client, with no delete
// confirmation pending and a text input limited to 512 characters.
func newKeysTabModel(client *Client) keysTabModel {
	input := textinput.New()
	input.Prompt = " Key: "
	input.CharLimit = 512
	return keysTabModel{
		editInput: input,
		confirm:   -1,
		client:    client,
	}
}
// Init returns fetchKeys as the initial command, so the tab requests its
// key data as soon as the command is run.
func (m keysTabModel) Init() tea.Cmd {
return m.fetchKeys
}
// fetchKeys loads the access API keys and the provider key lists and returns
// them as a keysDataMsg. A failure to load the access keys aborts the fetch
// and is reported in the message. Provider lists are best-effort: their
// errors are ignored and the corresponding list is left as returned.
func (m keysTabModel) fetchKeys() tea.Msg {
	apiKeys, err := m.client.GetAPIKeys()
	if err != nil {
		return keysDataMsg{err: err}
	}

	// Errors are deliberately dropped here; see the doc comment.
	gemini, _ := m.client.GetGeminiKeys()
	claude, _ := m.client.GetClaudeKeys()
	codex, _ := m.client.GetCodexKeys()
	vertex, _ := m.client.GetVertexKeys()
	openai, _ := m.client.GetOpenAICompat()

	return keysDataMsg{
		apiKeys: apiKeys,
		gemini:  gemini,
		claude:  claude,
		codex:   codex,
		vertex:  vertex,
		openai:  openai,
	}
}
// Update handles messages for the keys tab.
//
// Non-key messages: a locale change re-renders; keysDataMsg stores fetched
// data (or the error) and clamps the cursor; keyActionMsg shows the action
// result and triggers a re-fetch.
//
// Key presses are handled in one of three modes, checked in this order:
// editing/adding (the text input owns the keyboard), delete confirmation
// (only y/n/esc are accepted), and normal mode (navigation and actions).
// Anything not handled is forwarded to the viewport.
func (m keysTabModel) Update(msg tea.Msg) (keysTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderContent())
return m, nil
case keysDataMsg:
if msg.err != nil {
m.err = msg.err
} else {
m.err = nil
m.keys = msg.apiKeys
m.gemini = msg.gemini
m.claude = msg.claude
m.codex = msg.codex
m.vertex = msg.vertex
m.openai = msg.openai
// Keep the cursor on a valid row if the list got shorter.
if m.cursor >= len(m.keys) {
m.cursor = max(0, len(m.keys)-1)
}
}
m.viewport.SetContent(m.renderContent())
return m, nil
case keyActionMsg:
if msg.err != nil {
m.status = errorStyle.Render("✗ " + msg.err.Error())
} else {
m.status = successStyle.Render("✓ " + msg.action)
}
m.confirm = -1
m.viewport.SetContent(m.renderContent())
// Reload the list so it reflects the change made on the server.
return m, m.fetchKeys
case tea.KeyMsg:
// ---- Editing / Adding mode ----
if m.editing || m.adding {
switch msg.String() {
case "enter":
value := strings.TrimSpace(m.editInput.Value())
// An empty submission is treated as a cancel.
if value == "" {
m.editing = false
m.adding = false
m.editInput.Blur()
m.viewport.SetContent(m.renderContent())
return m, nil
}
// Capture the mode and index before resetting them; the commands
// below run later and must not depend on the cleared state.
isAdding := m.adding
editIdx := m.editIdx
m.editing = false
m.adding = false
m.editInput.Blur()
if isAdding {
return m, func() tea.Msg {
err := m.client.AddAPIKey(value)
if err != nil {
return keyActionMsg{err: err}
}
return keyActionMsg{action: T("key_added")}
}
}
return m, func() tea.Msg {
err := m.client.EditAPIKey(editIdx, value)
if err != nil {
return keyActionMsg{err: err}
}
return keyActionMsg{action: T("key_updated")}
}
case "esc":
m.editing = false
m.adding = false
m.editInput.Blur()
m.viewport.SetContent(m.renderContent())
return m, nil
default:
// Every other key is typed into the text input.
var cmd tea.Cmd
m.editInput, cmd = m.editInput.Update(msg)
m.viewport.SetContent(m.renderContent())
return m, cmd
}
}
// ---- Delete confirmation ----
if m.confirm >= 0 {
switch msg.String() {
case "y", "Y":
idx := m.confirm
m.confirm = -1
return m, func() tea.Msg {
err := m.client.DeleteAPIKey(idx)
if err != nil {
return keyActionMsg{err: err}
}
return keyActionMsg{action: T("key_deleted")}
}
case "n", "N", "esc":
m.confirm = -1
m.viewport.SetContent(m.renderContent())
return m, nil
}
// Any other key is swallowed while a confirmation is pending.
return m, nil
}
// ---- Normal mode ----
switch msg.String() {
case "j", "down":
// Cursor movement wraps around at both ends of the list.
if len(m.keys) > 0 {
m.cursor = (m.cursor + 1) % len(m.keys)
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "k", "up":
if len(m.keys) > 0 {
m.cursor = (m.cursor - 1 + len(m.keys)) % len(m.keys)
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "a":
// Add new key
m.adding = true
m.editing = false
m.editInput.SetValue("")
m.editInput.Prompt = T("new_key_prompt")
m.editInput.Focus()
m.viewport.SetContent(m.renderContent())
return m, textinput.Blink
case "e":
// Edit selected key
if m.cursor < len(m.keys) {
m.editing = true
m.adding = false
m.editIdx = m.cursor
// Pre-fill the input with the full, unmasked key.
m.editInput.SetValue(m.keys[m.cursor])
m.editInput.Prompt = T("edit_key_prompt")
m.editInput.Focus()
m.viewport.SetContent(m.renderContent())
return m, textinput.Blink
}
return m, nil
case "d":
// Delete selected key
if m.cursor < len(m.keys) {
m.confirm = m.cursor
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "c":
// Copy selected key to clipboard
if m.cursor < len(m.keys) {
key := m.keys[m.cursor]
if err := clipboard.WriteAll(key); err != nil {
m.status = errorStyle.Render(T("copy_failed") + ": " + err.Error())
} else {
m.status = successStyle.Render(T("copied"))
}
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "r":
m.status = ""
return m, m.fetchKeys
default:
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
}
// Messages of any other type go to the viewport.
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
// SetSize records the available width and height and sizes the text input
// to the width minus 16 columns. The first call creates the viewport and
// fills it with the rendered content; later calls only resize it.
func (m *keysTabModel) SetSize(w, h int) {
	m.width, m.height = w, h
	m.editInput.Width = w - 16
	if m.ready {
		m.viewport.Width = w
		m.viewport.Height = h
		return
	}
	m.viewport = viewport.New(w, h)
	m.viewport.SetContent(m.renderContent())
	m.ready = true
}
// View returns the viewport's rendering, or the localized loading text until
// SetSize has initialized the viewport.
func (m keysTabModel) View() string {
	if m.ready {
		return m.viewport.View()
	}
	return T("loading")
}
// renderContent builds the text shown in the keys tab: title, help line and
// divider, then either an error banner or the interactive list of access
// keys (masked, with inline delete confirmation and edit input), the
// optional add input, the read-only provider key sections, and finally the
// status line of the last action.
func (m keysTabModel) renderContent() string {
	var out strings.Builder
	// line writes s followed by a newline.
	line := func(s string) {
		out.WriteString(s)
		out.WriteString("\n")
	}

	line(titleStyle.Render(T("keys_title")))
	line(helpStyle.Render(T("keys_help")))
	line(strings.Repeat("─", m.width))

	if m.err != nil {
		line(errorStyle.Render(T("error_prefix") + m.err.Error()))
		return out.String()
	}

	// ━━━ Access API Keys (interactive) ━━━
	line(tableHeaderStyle.Render(fmt.Sprintf(" %s (%d)", T("access_keys"), len(m.keys))))
	if len(m.keys) == 0 {
		line(subtitleStyle.Render(T("no_keys")))
	}
	for i, key := range m.keys {
		marker := " "
		style := lipgloss.NewStyle()
		if i == m.cursor {
			marker = "▸ "
			style = lipgloss.NewStyle().Bold(true)
		}
		line(style.Render(fmt.Sprintf("%s%d. %s", marker, i+1, maskKey(key))))

		// Delete confirmation, shown directly under the affected key.
		if m.confirm == i {
			line(warningStyle.Render(fmt.Sprintf(" "+T("confirm_delete_key"), maskKey(key))))
		}
		// Edit input, shown directly under the key being edited.
		if m.editing && m.editIdx == i {
			line(m.editInput.View())
			line(helpStyle.Render(T("enter_save_esc")))
		}
	}

	// Add input, shown after the list.
	if m.adding {
		out.WriteString("\n")
		line(m.editInput.View())
		line(helpStyle.Render(T("enter_add")))
	}
	out.WriteString("\n")

	// ━━━ Provider Keys (read-only display) ━━━
	renderProviderKeys(&out, "Gemini API Keys", m.gemini)
	renderProviderKeys(&out, "Claude API Keys", m.claude)
	renderProviderKeys(&out, "Codex API Keys", m.codex)
	renderProviderKeys(&out, "Vertex API Keys", m.vertex)

	if len(m.openai) > 0 {
		renderSection(&out, "OpenAI Compatibility", len(m.openai))
		for i, entry := range m.openai {
			info := getString(entry, "name")
			if prefix := getString(entry, "prefix"); prefix != "" {
				info += " (prefix: " + prefix + ")"
			}
			if baseURL := getString(entry, "base-url"); baseURL != "" {
				info += " → " + baseURL
			}
			fmt.Fprintf(&out, " %d. %s\n", i+1, info)
		}
		out.WriteString("\n")
	}

	if m.status != "" {
		line(m.status)
	}
	return out.String()
}
// renderSection writes a styled "<title> (<count>)" header line to sb,
// followed by a newline.
func renderSection(sb *strings.Builder, title string, count int) {
	styled := tableHeaderStyle.Render(" " + fmt.Sprintf("%s (%d)", title, count))
	sb.WriteString(styled)
	sb.WriteString("\n")
}
// renderProviderKeys writes a read-only section for one provider: a header
// with the entry count, then one numbered line per entry showing the masked
// "api-key" plus the optional "prefix" and "base-url" values. Nothing is
// written when keys is empty.
func renderProviderKeys(sb *strings.Builder, title string, keys []map[string]any) {
	if len(keys) == 0 {
		return
	}

	renderSection(sb, title, len(keys))
	for idx, entry := range keys {
		line := maskKey(getString(entry, "api-key"))
		if prefix := getString(entry, "prefix"); prefix != "" {
			line += " (prefix: " + prefix + ")"
		}
		if endpoint := getString(entry, "base-url"); endpoint != "" {
			line += " → " + endpoint
		}
		fmt.Fprintf(sb, " %d. %s\n", idx+1, line)
	}
	sb.WriteString("\n")
}
// maskKey hides the middle of key for display. Keys of up to 8 characters are
// replaced entirely by asterisks; longer keys keep their first and last 4
// characters with one asterisk per hidden character in between.
//
// Lengths are counted in runes rather than bytes so that a key (or any other
// label passed here) containing multi-byte characters is never cut in the
// middle of a character. ASCII keys are masked exactly as before.
func maskKey(key string) string {
	runes := []rune(key)
	n := len(runes)
	if n <= 8 {
		return strings.Repeat("*", n)
	}
	return string(runes[:4]) + strings.Repeat("*", n-8) + string(runes[n-4:])
}

78
internal/tui/loghook.go Normal file
View File

@@ -0,0 +1,78 @@
package tui
import (
"fmt"
"strings"
"sync"
log "github.com/sirupsen/logrus"
)
// LogHook is a logrus hook that formats each fired entry into a single line
// and queues it on a buffered channel, which consumers read through Chan.
type LogHook struct {
// ch carries the formatted log lines handed out by Chan.
ch chan string
// formatter renders entries to text; access is guarded by mu.
formatter log.Formatter
// mu protects formatter, which SetFormatter may replace while Fire reads it.
mu sync.Mutex
// levels is the set of logrus levels reported by Levels.
levels []log.Level
}
// NewLogHook returns a LogHook whose channel buffers up to bufSize lines. It
// starts with a plain-text formatter (colors disabled, full timestamps) and
// is registered for every logrus level.
func NewLogHook(bufSize int) *LogHook {
	hook := new(LogHook)
	hook.ch = make(chan string, bufSize)
	hook.formatter = &log.TextFormatter{DisableColors: true, FullTimestamp: true}
	hook.levels = log.AllLevels
	return hook
}
// SetFormatter replaces the formatter used by Fire. The swap happens under
// the hook's mutex.
func (h *LogHook) SetFormatter(f log.Formatter) {
	h.mu.Lock()
	h.formatter = f
	h.mu.Unlock()
}
// Levels returns the log levels this hook should fire on. The slice is the
// one stored at construction time (log.AllLevels in NewLogHook).
func (h *LogHook) Levels() []log.Level {
return h.levels
}
// Fire is called by logrus for each log entry. It renders the entry to a
// single line and queues it on the hook's channel without ever blocking:
// when the buffer is full, one queued line is discarded to make room and the
// send is retried once. Fire always returns nil.
func (h *LogHook) Fire(entry *log.Entry) error {
	// Snapshot the formatter under the lock; format outside of it.
	h.mu.Lock()
	formatter := h.formatter
	h.mu.Unlock()

	var text string
	formatted := false
	if formatter != nil {
		if raw, err := formatter.Format(entry); err == nil {
			text = strings.TrimRight(string(raw), "\n\r")
			formatted = true
		}
	}
	if !formatted {
		// No formatter, or formatting failed: fall back to "[level] message".
		text = fmt.Sprintf("[%s] %s", entry.Level, entry.Message)
	}

	// Fast path: there is room in the buffer.
	select {
	case h.ch <- text:
		return nil
	default:
	}

	// Buffer full: drop one queued line if it is still there...
	select {
	case <-h.ch:
	default:
	}
	// ...and retry once, giving up silently if the slot was taken meanwhile.
	select {
	case h.ch <- text:
	default:
	}
	return nil
}
// Chan returns the receive-only view of the channel that Fire queues
// formatted log lines on.
func (h *LogHook) Chan() <-chan string {
return h.ch
}

261
internal/tui/logs_tab.go Normal file
View File

@@ -0,0 +1,261 @@
package tui
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
)
// logsTabModel displays real-time log lines from hook/API source.
// When hook is non-nil lines are read from the hook's channel; otherwise
// they are polled through client.
type logsTabModel struct {
// client is used by fetchLogs to poll log lines.
client *Client
// hook, when set, is the in-process source of log lines (see waitForLog).
hook *LogHook
viewport viewport.Model
// lines is the buffered log output, trimmed to the newest maxLines entries.
lines []string
maxLines int
// autoScroll makes the viewport jump to the bottom when new lines arrive.
autoScroll bool
width int
height int
// ready is set once SetSize has created the viewport.
ready bool
filter string // "", "debug", "info", "warn", "error"
// after is the cursor passed to GetLogs; it is updated from each poll's
// latest value.
after int64
// lastErr is the error from the most recent poll, shown by renderLogs.
lastErr error
}
// logsPollMsg carries the result of one fetchLogs call: the returned lines,
// the new cursor value, and any error.
type logsPollMsg struct {
lines []string
latest int64
err error
}
// logsTickMsg is emitted by waitForNextPoll to trigger the next poll.
type logsTickMsg struct{}
// logLineMsg is a single log line received from the hook channel.
type logLineMsg string
// newLogsTabModel returns a logs tab bound to client and hook, keeping at
// most 5000 lines and starting with auto-scroll enabled.
func newLogsTabModel(client *Client, hook *LogHook) logsTabModel {
	var model logsTabModel
	model.client = client
	model.hook = hook
	model.maxLines = 5000
	model.autoScroll = true
	return model
}
// Init returns the first command for the tab: an initial poll when no hook
// is attached, otherwise a wait on the hook channel.
func (m logsTabModel) Init() tea.Cmd {
	if m.hook == nil {
		return m.fetchLogs
	}
	return m.waitForLog
}
// fetchLogs asks the client for up to 200 lines after the current cursor and
// wraps the outcome in a logsPollMsg.
func (m logsTabModel) fetchLogs() tea.Msg {
	var result logsPollMsg
	result.lines, result.latest, result.err = m.client.GetLogs(m.after, 200)
	return result
}
// waitForNextPoll returns a command that emits a logsTickMsg after two
// seconds.
func (m logsTabModel) waitForNextPoll() tea.Cmd {
	const pollInterval = 2 * time.Second
	emitTick := func(time.Time) tea.Msg { return logsTickMsg{} }
	return tea.Tick(pollInterval, emitTick)
}
// waitForLog blocks until the hook channel delivers a line and returns it as
// a logLineMsg. It returns nil when there is no hook or the channel is
// closed.
func (m logsTabModel) waitForLog() tea.Msg {
	if m.hook == nil {
		return nil
	}
	if line, open := <-m.hook.Chan(); open {
		return logLineMsg(line)
	}
	return nil
}
// Update advances the logs tab in response to msg and returns the updated
// model together with the next command to run.
//
// Messages handled here:
//   - localeChangedMsg: re-render the content.
//   - logsTickMsg / logsPollMsg: the polling path; both are ignored when a
//     hook is attached.
//   - logLineMsg: the hook path; appends one line and waits for the next.
//   - tea.KeyMsg: "a" toggles auto-scroll, "c" clears the buffer, "1"-"4"
//     select the level filter; any other key is forwarded to the viewport.
//
// Every other message is forwarded to the viewport.
func (m logsTabModel) Update(msg tea.Msg) (logsTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderLogs())
return m, nil
case logsTickMsg:
// Polling is only used when no hook feeds the tab.
if m.hook != nil {
return m, nil
}
return m, m.fetchLogs
case logsPollMsg:
if m.hook != nil {
return m, nil
}
if msg.err != nil {
// Keep the existing lines and cursor; only remember the error.
m.lastErr = msg.err
} else {
m.lastErr = nil
m.after = msg.latest
if len(msg.lines) > 0 {
m.lines = append(m.lines, msg.lines...)
// Trim to the newest maxLines entries.
if len(m.lines) > m.maxLines {
m.lines = m.lines[len(m.lines)-m.maxLines:]
}
}
}
m.viewport.SetContent(m.renderLogs())
if m.autoScroll {
m.viewport.GotoBottom()
}
// Schedule the next poll regardless of success or failure.
return m, m.waitForNextPoll()
case logLineMsg:
m.lines = append(m.lines, string(msg))
if len(m.lines) > m.maxLines {
m.lines = m.lines[len(m.lines)-m.maxLines:]
}
m.viewport.SetContent(m.renderLogs())
if m.autoScroll {
m.viewport.GotoBottom()
}
// Keep listening for the next line from the hook.
return m, m.waitForLog
case tea.KeyMsg:
switch msg.String() {
case "a":
m.autoScroll = !m.autoScroll
if m.autoScroll {
m.viewport.GotoBottom()
}
return m, nil
case "c":
// Clear the buffered lines and any displayed error.
m.lines = nil
m.lastErr = nil
m.viewport.SetContent(m.renderLogs())
return m, nil
case "1":
m.filter = ""
m.viewport.SetContent(m.renderLogs())
return m, nil
case "2":
m.filter = "info"
m.viewport.SetContent(m.renderLogs())
return m, nil
case "3":
m.filter = "warn"
m.viewport.SetContent(m.renderLogs())
return m, nil
case "4":
m.filter = "error"
m.viewport.SetContent(m.renderLogs())
return m, nil
default:
wasAtBottom := m.viewport.AtBottom()
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
// If user scrolls up, disable auto-scroll
if !m.viewport.AtBottom() && wasAtBottom {
m.autoScroll = false
}
// If user scrolls to bottom, re-enable auto-scroll
// NOTE(review): this also re-enables auto-scroll for any unhandled key
// pressed while already at the bottom, even after pausing with "a" -
// confirm that is intended.
if m.viewport.AtBottom() {
m.autoScroll = true
}
return m, cmd
}
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
// SetSize records the new dimensions and either creates the viewport (first
// call) or resizes the existing one.
func (m *logsTabModel) SetSize(w, h int) {
	m.width, m.height = w, h

	if m.ready {
		m.viewport.Width, m.viewport.Height = w, h
		return
	}

	// First call: build the viewport and seed it with the rendered logs.
	m.viewport = viewport.New(w, h)
	m.viewport.SetContent(m.renderLogs())
	m.ready = true
}
// View returns the viewport rendering, or the localized loading text until
// SetSize has created the viewport.
func (m logsTabModel) View() string {
	if m.ready {
		return m.viewport.View()
	}
	return T("loading")
}
// renderLogs produces the viewport content for the logs tab: a status header
// (scroll state, active filter, buffered line count), the help line, a
// separator, the last polling error if any, and then every buffered line
// that passes the current level filter.
func (m logsTabModel) renderLogs() string {
	var scrollStatus string
	if m.autoScroll {
		scrollStatus = successStyle.Render(T("logs_auto_scroll"))
	} else {
		scrollStatus = warningStyle.Render(T("logs_paused"))
	}

	filterLabel := "ALL"
	if m.filter != "" {
		filterLabel = strings.ToUpper(m.filter) + "+"
	}

	var out strings.Builder
	header := fmt.Sprintf(" %s %s %s: %s %s: %d",
		T("logs_title"), scrollStatus, T("logs_filter"), filterLabel, T("logs_lines"), len(m.lines))
	out.WriteString(titleStyle.Render(header) + "\n")
	out.WriteString(helpStyle.Render(T("logs_help")) + "\n")
	out.WriteString(strings.Repeat("─", m.width) + "\n")

	if m.lastErr != nil {
		out.WriteString(errorStyle.Render("⚠ Error: "+m.lastErr.Error()) + "\n")
	}

	// Nothing buffered yet: show the placeholder (no trailing newline).
	if len(m.lines) == 0 {
		out.WriteString(subtitleStyle.Render(T("logs_waiting")))
		return out.String()
	}

	filtering := m.filter != ""
	for _, entry := range m.lines {
		if filtering && !m.matchLevel(entry) {
			continue
		}
		out.WriteString(m.styleLine(entry) + "\n")
	}
	return out.String()
}
// matchLevel reports whether line should be shown under the current filter.
// Levels are detected by looking for bracketed tags such as "[error]" in the
// line text.
//
//   - "error": error, fatal and panic lines.
//   - "warn":  warn/warning lines plus everything shown by "error". Panic
//     lines are included here so that the WARN+ view is a superset of the
//     ERROR+ view.
//   - "info":  everything except debug lines.
//   - any other value: every line.
func (m logsTabModel) matchLevel(line string) bool {
	hasAny := func(tags ...string) bool {
		for _, tag := range tags {
			if strings.Contains(line, tag) {
				return true
			}
		}
		return false
	}

	switch m.filter {
	case "error":
		return hasAny("[error]", "[fatal]", "[panic]")
	case "warn":
		// "[warn" (no closing bracket) matches both "[warn]" and "[warning]".
		return hasAny("[warn", "[error]", "[fatal]", "[panic]")
	case "info":
		return !strings.Contains(line, "[debug]")
	default:
		return true
	}
}
// styleLine colors line according to the first level tag found in it,
// checked in order of severity: error/fatal/panic, then warn, info and
// debug. Panic lines use the error style, matching how matchLevel and
// logLevelStyle group them. Lines without a recognized tag are returned
// unchanged.
func (m logsTabModel) styleLine(line string) string {
	switch {
	case strings.Contains(line, "[error]"),
		strings.Contains(line, "[fatal]"),
		strings.Contains(line, "[panic]"):
		return logErrorStyle.Render(line)
	case strings.Contains(line, "[warn"):
		// Prefix match covers both "[warn]" and "[warning]".
		return logWarnStyle.Render(line)
	case strings.Contains(line, "[info"):
		return logInfoStyle.Render(line)
	case strings.Contains(line, "[debug]"):
		return logDebugStyle.Render(line)
	default:
		return line
	}
}

473
internal/tui/oauth_tab.go Normal file
View File

@@ -0,0 +1,473 @@
package tui
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// oauthProvider represents an OAuth provider option shown in the OAuth tab.
type oauthProvider struct {
// name is the display name; it is also used to look the provider up again
// when a callback URL is submitted.
name string
apiPath string // management API path
// emoji is shown in front of the name in the provider list.
emoji string
}
// oauthProviders lists the selectable providers in display order. Entries
// are positional literals: display name, management API path, emoji.
var oauthProviders = []oauthProvider{
{"Gemini CLI", "gemini-cli-auth-url", "🟦"},
{"Claude (Anthropic)", "anthropic-auth-url", "🟧"},
{"Codex (OpenAI)", "codex-auth-url", "🟩"},
{"Antigravity", "antigravity-auth-url", "🟪"},
{"Qwen", "qwen-auth-url", "🟨"},
{"Kimi", "kimi-auth-url", "🟫"},
{"IFlow", "iflow-auth-url", "⬜"},
}
// oauthTabModel handles OAuth login flows.
type oauthTabModel struct {
// client talks to the management API (auth URL, callback, status).
client *Client
viewport viewport.Model
// cursor is the index of the highlighted entry in oauthProviders.
cursor int
// state is the current step of the login flow.
state oauthState
// message is an already-styled status line rendered under the title.
message string
err error
width int
height int
// ready is set once SetSize has created the viewport.
ready bool
// Remote browser mode
authURL string // auth URL to display
authState string // OAuth state parameter
providerName string // current provider name
// callbackInput is the text input for pasting the callback URL.
callbackInput textinput.Model
inputActive bool // true when user is typing callback URL
}
// oauthState enumerates the steps of the OAuth login flow.
type oauthState int
const (
// oauthIdle: provider list is shown and nothing is in progress.
oauthIdle oauthState = iota
// oauthPending: a provider was chosen and the auth URL is being requested.
oauthPending
oauthRemote // remote browser mode: waiting for manual callback
// oauthSuccess: polling reported a finished login.
oauthSuccess
// oauthError: starting or polling the login failed.
oauthError
)
// Messages

// oauthStartMsg is the result of startOAuth: the auth URL and state returned
// by the management API for the named provider, or an error.
type oauthStartMsg struct {
url string
state string
providerName string
err error
}
// oauthPollMsg is the result of pollOAuthStatus. done reports a finished
// login, message is the text to display, and err is set on failure/timeout.
type oauthPollMsg struct {
done bool
message string
err error
}
// oauthCallbackSubmitMsg is the result of submitCallback; err is nil when
// the callback URL was posted successfully.
type oauthCallbackSubmitMsg struct {
err error
}
// newOAuthTabModel returns an OAuth tab bound to client, with the callback
// URL text input preconfigured (prompt, placeholder and a 2048-character
// limit).
func newOAuthTabModel(client *Client) oauthTabModel {
	input := textinput.New()
	input.Prompt = " 回调 URL: "
	input.Placeholder = "http://localhost:.../auth/callback?code=...&state=..."
	input.CharLimit = 2048

	var model oauthTabModel
	model.client = client
	model.callbackInput = input
	return model
}
// Init returns no initial command; the OAuth tab does nothing until the user
// selects a provider.
func (m oauthTabModel) Init() tea.Cmd {
return nil
}
// Update advances the OAuth tab in response to msg and returns the updated
// model together with the next command to run.
//
// Non-key messages report the outcome of the asynchronous steps:
//   - oauthStartMsg: auth URL obtained (enter remote mode and start polling)
//     or failed.
//   - oauthPollMsg: polling finished, failed, or produced a progress message.
//   - oauthCallbackSubmitMsg: result of posting the pasted callback URL.
//
// Key handling depends on the current mode, checked in this order: callback
// input active, remote mode without input focus, pending, then the
// provider list.
func (m oauthTabModel) Update(msg tea.Msg) (oauthTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderContent())
return m, nil
case oauthStartMsg:
if msg.err != nil {
m.state = oauthError
m.err = msg.err
m.message = errorStyle.Render("✗ " + msg.err.Error())
m.viewport.SetContent(m.renderContent())
return m, nil
}
// Auth URL received: switch to remote mode with the callback input focused.
m.authURL = msg.url
m.authState = msg.state
m.providerName = msg.providerName
m.state = oauthRemote
m.callbackInput.SetValue("")
m.callbackInput.Focus()
m.inputActive = true
m.message = ""
m.viewport.SetContent(m.renderContent())
// Also start polling in the background
return m, tea.Batch(textinput.Blink, m.pollOAuthStatus(msg.state))
case oauthPollMsg:
// NOTE(review): the poll result is applied whatever the current state is,
// so a result arriving after esc reset the tab to idle still switches it
// to success/error - confirm that is intended.
if msg.err != nil {
m.state = oauthError
m.err = msg.err
m.message = errorStyle.Render("✗ " + msg.err.Error())
m.inputActive = false
m.callbackInput.Blur()
} else if msg.done {
m.state = oauthSuccess
m.message = successStyle.Render("✓ " + msg.message)
m.inputActive = false
m.callbackInput.Blur()
} else {
m.message = warningStyle.Render("⏳ " + msg.message)
}
m.viewport.SetContent(m.renderContent())
return m, nil
case oauthCallbackSubmitMsg:
// Only the message changes here; the state is left as it was.
if msg.err != nil {
m.message = errorStyle.Render(T("oauth_submit_fail") + ": " + msg.err.Error())
} else {
m.message = successStyle.Render(T("oauth_submit_ok"))
}
m.viewport.SetContent(m.renderContent())
return m, nil
case tea.KeyMsg:
// ---- Input active: typing callback URL ----
if m.inputActive {
switch msg.String() {
case "enter":
callbackURL := m.callbackInput.Value()
// An empty input is ignored and the input stays active.
if callbackURL == "" {
return m, nil
}
m.inputActive = false
m.callbackInput.Blur()
m.message = warningStyle.Render(T("oauth_submitting"))
m.viewport.SetContent(m.renderContent())
return m, m.submitCallback(callbackURL)
case "esc":
// Leave the input but stay in remote mode.
m.inputActive = false
m.callbackInput.Blur()
m.viewport.SetContent(m.renderContent())
return m, nil
default:
var cmd tea.Cmd
m.callbackInput, cmd = m.callbackInput.Update(msg)
m.viewport.SetContent(m.renderContent())
return m, cmd
}
}
// ---- Remote mode but not typing ----
if m.state == oauthRemote {
switch msg.String() {
case "c", "C":
// Re-activate input
m.inputActive = true
m.callbackInput.Focus()
m.viewport.SetContent(m.renderContent())
return m, textinput.Blink
case "esc":
// Abandon the flow and go back to the provider list.
m.state = oauthIdle
m.message = ""
m.authURL = ""
m.authState = ""
m.viewport.SetContent(m.renderContent())
return m, nil
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
// ---- Pending (auto polling) ----
if m.state == oauthPending {
// Only esc is handled while pending; every other key is swallowed.
if msg.String() == "esc" {
m.state = oauthIdle
m.message = ""
m.viewport.SetContent(m.renderContent())
}
return m, nil
}
// ---- Idle ----
switch msg.String() {
case "up", "k":
if m.cursor > 0 {
m.cursor--
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "down", "j":
if m.cursor < len(oauthProviders)-1 {
m.cursor++
m.viewport.SetContent(m.renderContent())
}
return m, nil
case "enter":
if m.cursor >= 0 && m.cursor < len(oauthProviders) {
provider := oauthProviders[m.cursor]
m.state = oauthPending
m.message = warningStyle.Render(fmt.Sprintf(T("oauth_initiating"), provider.name))
m.viewport.SetContent(m.renderContent())
return m, m.startOAuth(provider)
}
return m, nil
case "esc":
// Clears a leftover success/error message and returns to idle.
m.state = oauthIdle
m.message = ""
m.err = nil
m.viewport.SetContent(m.renderContent())
return m, nil
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
// startOAuth returns a command that requests the provider's auth URL from
// the management API (with is_webui=true), makes a best-effort attempt to
// open it in a browser, and reports the URL and state as an oauthStartMsg.
// A request failure or an empty URL is reported through the message's err.
func (m oauthTabModel) startOAuth(provider oauthProvider) tea.Cmd {
	endpoint := "/v0/management/" + provider.apiPath + "?is_webui=true"

	return func() tea.Msg {
		resp, err := m.client.getJSON(endpoint)
		if err != nil {
			return oauthStartMsg{err: fmt.Errorf("failed to start %s login: %w", provider.name, err)}
		}

		link, sessionState := getString(resp, "url"), getString(resp, "state")
		if link == "" {
			return oauthStartMsg{err: fmt.Errorf("no auth URL returned for %s", provider.name)}
		}

		// Opening the browser is best effort; the URL is shown in the tab
		// regardless, so a failure here is deliberately ignored.
		_ = openBrowser(link)

		return oauthStartMsg{
			url:          link,
			state:        sessionState,
			providerName: provider.name,
		}
	}
}
// submitCallback returns a command that posts callbackURL to the management
// API's oauth-callback endpoint, together with the current OAuth state and
// the provider key derived from the current provider name. The outcome is
// reported as an oauthCallbackSubmitMsg.
func (m oauthTabModel) submitCallback(callbackURL string) tea.Cmd {
	// Management API path -> canonical provider key the callback API expects.
	keyByAPIPath := map[string]string{
		"gemini-cli-auth-url":  "gemini",
		"anthropic-auth-url":   "anthropic",
		"codex-auth-url":       "codex",
		"antigravity-auth-url": "antigravity",
		"qwen-auth-url":        "qwen",
		"kimi-auth-url":        "kimi",
		"iflow-auth-url":       "iflow",
	}

	return func() tea.Msg {
		// Resolve the key from the first provider whose name matches; it
		// stays empty when nothing matches or the path is unknown.
		providerKey := ""
		for idx := range oauthProviders {
			if oauthProviders[idx].name != m.providerName {
				continue
			}
			providerKey = keyByAPIPath[oauthProviders[idx].apiPath]
			break
		}

		payload := map[string]string{
			"provider":     providerKey,
			"redirect_url": callbackURL,
			"state":        m.authState,
		}
		if err := m.client.postJSON("/v0/management/oauth-callback", payload); err != nil {
			return oauthCallbackSubmitMsg{err: err}
		}
		return oauthCallbackSubmitMsg{}
	}
}
// pollOAuthStatus returns a command that checks the auth status for state
// every two seconds for up to five minutes. It finishes with an
// oauthPollMsg: success on "ok", an error on "error", a timeout error when
// the deadline passes, and a generic completion for any other status.
// Request errors and the "wait" status just trigger another round.
func (m oauthTabModel) pollOAuthStatus(state string) tea.Cmd {
	const (
		pollWindow   = 5 * time.Minute
		pollInterval = 2 * time.Second
	)

	return func() tea.Msg {
		giveUpAt := time.Now().Add(pollWindow)
		for !time.Now().After(giveUpAt) {
			time.Sleep(pollInterval)

			status, detail, err := m.client.GetAuthStatus(state)
			if err != nil || status == "wait" {
				// Transient error or still waiting: try again.
				continue
			}
			if status == "ok" {
				return oauthPollMsg{done: true, message: T("oauth_success")}
			}
			if status == "error" {
				return oauthPollMsg{
					done: false,
					err:  fmt.Errorf("%s: %s", T("oauth_failed"), detail),
				}
			}
			// Any other status is treated as a finished flow.
			return oauthPollMsg{done: true, message: T("oauth_completed")}
		}
		return oauthPollMsg{done: false, err: fmt.Errorf("%s", T("oauth_timeout"))}
	}
}
// SetSize records the new dimensions, resizes the callback input, and either
// creates the viewport (first call) or resizes the existing one.
func (m *oauthTabModel) SetSize(w, h int) {
	m.width, m.height = w, h
	m.callbackInput.Width = w - 16

	if m.ready {
		m.viewport.Width, m.viewport.Height = w, h
		return
	}

	// First call: build the viewport and seed it with the rendered content.
	m.viewport = viewport.New(w, h)
	m.viewport.SetContent(m.renderContent())
	m.ready = true
}
// View returns the viewport rendering, or the localized loading text until
// SetSize has created the viewport.
func (m oauthTabModel) View() string {
	if m.ready {
		return m.viewport.View()
	}
	return T("loading")
}
// renderContent builds the OAuth tab body: the title, the current status
// message if any, and then one of three views depending on the state - the
// remote-mode screen, the pending hint, or the selectable provider list.
func (m oauthTabModel) renderContent() string {
	var out strings.Builder

	out.WriteString(titleStyle.Render(T("oauth_title")) + "\n\n")
	if m.message != "" {
		out.WriteString(" " + m.message + "\n\n")
	}

	switch m.state {
	case oauthRemote:
		// Remote browser mode has its own screen.
		out.WriteString(m.renderRemoteMode())
		return out.String()
	case oauthPending:
		out.WriteString(helpStyle.Render(T("oauth_press_esc")))
		return out.String()
	}

	// Provider list.
	out.WriteString(helpStyle.Render(T("oauth_select")) + "\n\n")
	for idx, provider := range oauthProviders {
		label := fmt.Sprintf("%s %s", provider.emoji, provider.name)
		marker := " "
		if idx == m.cursor {
			marker = "▸ "
			label = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#FFFFFF")).Background(colorPrimary).Padding(0, 1).Render(label)
		} else {
			label = lipgloss.NewStyle().Foreground(colorText).Padding(0, 1).Render(label)
		}
		out.WriteString(marker + label + "\n")
	}
	out.WriteString("\n")
	out.WriteString(helpStyle.Render(T("oauth_help")))
	return out.String()
}
// renderRemoteMode builds the remote-browser screen: the provider heading,
// the auth URL wrapped to the terminal width, a usage hint, the callback URL
// input (or the hint for re-opening it), and the waiting notice.
func (m oauthTabModel) renderRemoteMode() string {
	heading := lipgloss.NewStyle().Bold(true).Foreground(colorHighlight)
	sectionTitle := lipgloss.NewStyle().Bold(true).Foreground(colorInfo)
	urlStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("252"))

	var out strings.Builder
	out.WriteString(heading.Render(fmt.Sprintf(" ✦ %s OAuth", m.providerName)) + "\n\n")

	// Auth URL, wrapped to fit the terminal but never narrower than 40.
	out.WriteString(sectionTitle.Render(T("oauth_auth_url")) + "\n")
	wrapWidth := m.width - 6
	if wrapWidth < 40 {
		wrapWidth = 40
	}
	for _, segment := range wrapText(m.authURL, wrapWidth) {
		out.WriteString(" " + urlStyle.Render(segment) + "\n")
	}
	out.WriteString("\n")
	out.WriteString(helpStyle.Render(T("oauth_remote_hint")) + "\n\n")

	// Callback URL input, or the hint for re-activating it.
	out.WriteString(sectionTitle.Render(T("oauth_callback_url")) + "\n")
	if m.inputActive {
		out.WriteString(m.callbackInput.View() + "\n")
		out.WriteString(helpStyle.Render(" " + T("enter_submit") + " • " + T("esc_cancel")))
	} else {
		out.WriteString(helpStyle.Render(T("oauth_press_c")))
	}
	out.WriteString("\n\n")
	out.WriteString(warningStyle.Render(T("oauth_waiting")))
	return out.String()
}
// wrapText splits s into consecutive lines of at most maxWidth characters.
//
// Width is counted in runes, not bytes, so a multi-byte character is never
// split across two lines; ASCII input wraps exactly as before. A
// non-positive maxWidth disables wrapping and returns s as the only line.
// An empty s yields no lines.
func wrapText(s string, maxWidth int) []string {
	if maxWidth <= 0 {
		return []string{s}
	}

	var lines []string
	remaining := []rune(s)
	for len(remaining) > maxWidth {
		lines = append(lines, string(remaining[:maxWidth]))
		remaining = remaining[maxWidth:]
	}
	if len(remaining) > 0 {
		lines = append(lines, string(remaining))
	}
	return lines
}

126
internal/tui/styles.go Normal file
View File

@@ -0,0 +1,126 @@
// Package tui provides a terminal-based management interface for CLIProxyAPI.
package tui
import "github.com/charmbracelet/lipgloss"
// Color palette shared by the TUI styles. Each entry is a hex color wrapped
// in lipgloss.Color; the trailing comment names the hue or role.
var (
colorPrimary = lipgloss.Color("#7C3AED") // violet
colorSecondary = lipgloss.Color("#6366F1") // indigo
colorSuccess = lipgloss.Color("#22C55E") // green
colorWarning = lipgloss.Color("#EAB308") // yellow
colorError = lipgloss.Color("#EF4444") // red
colorInfo = lipgloss.Color("#3B82F6") // blue
colorMuted = lipgloss.Color("#6B7280") // gray
colorBg = lipgloss.Color("#1E1E2E") // dark bg
colorSurface = lipgloss.Color("#313244") // slightly lighter
colorText = lipgloss.Color("#CDD6F4") // light text
colorSubtext = lipgloss.Color("#A6ADC8") // dimmer text
colorBorder = lipgloss.Color("#45475A") // border
colorHighlight = lipgloss.Color("#F5C2E7") // pink highlight
)
// Tab bar styles: the active tab (bold white on the primary color), inactive
// tabs (subtext on the surface color), and the bar that contains them.
var (
tabActiveStyle = lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#FFFFFF")).
Background(colorPrimary).
Padding(0, 2)
tabInactiveStyle = lipgloss.NewStyle().
Foreground(colorSubtext).
Background(colorSurface).
Padding(0, 2)
tabBarStyle = lipgloss.NewStyle().
Background(colorSurface).
PaddingLeft(1).
PaddingBottom(0)
)
// Content styles used inside the tabs: headings, label/value pairs, bordered
// sections, error/success/warning messages, the status bar and help text.
var (
// titleStyle is the bold highlighted heading with one line of bottom margin.
titleStyle = lipgloss.NewStyle().
Bold(true).
Foreground(colorHighlight).
MarginBottom(1)
subtitleStyle = lipgloss.NewStyle().
Foreground(colorSubtext).
Italic(true)
// labelStyle has a fixed width of 24 so values line up in a column.
labelStyle = lipgloss.NewStyle().
Foreground(colorInfo).
Bold(true).
Width(24)
valueStyle = lipgloss.NewStyle().
Foreground(colorText)
sectionStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(colorBorder).
Padding(1, 2)
errorStyle = lipgloss.NewStyle().
Foreground(colorError).
Bold(true)
successStyle = lipgloss.NewStyle().
Foreground(colorSuccess)
warningStyle = lipgloss.NewStyle().
Foreground(colorWarning)
statusBarStyle = lipgloss.NewStyle().
Foreground(colorSubtext).
Background(colorSurface).
PaddingLeft(1).
PaddingRight(1)
helpStyle = lipgloss.NewStyle().
Foreground(colorMuted)
)
// Log level styles: one foreground color per severity, used to color log
// lines (debug is muted, info blue, warn yellow, error red).
var (
logDebugStyle = lipgloss.NewStyle().Foreground(colorMuted)
logInfoStyle = lipgloss.NewStyle().Foreground(colorInfo)
logWarnStyle = lipgloss.NewStyle().Foreground(colorWarning)
logErrorStyle = lipgloss.NewStyle().Foreground(colorError)
)
// Table styles: the header (bold, highlighted, with a bottom border), regular
// cells (with right padding), and the selected row (bold white on primary).
var (
tableHeaderStyle = lipgloss.NewStyle().
Bold(true).
Foreground(colorHighlight).
BorderBottom(true).
BorderStyle(lipgloss.NormalBorder()).
BorderForeground(colorBorder)
tableCellStyle = lipgloss.NewStyle().
Foreground(colorText).
PaddingRight(2)
tableSelectedStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("#FFFFFF")).
Background(colorPrimary).
Bold(true)
)
// logLevelStyle maps a level name to its log style: "debug" to the debug
// style, "warn"/"warning" to the warn style, "error"/"fatal"/"panic" to the
// error style, and "info" or any unrecognized name to the info style.
func logLevelStyle(level string) lipgloss.Style {
	if level == "debug" {
		return logDebugStyle
	}
	if level == "warn" || level == "warning" {
		return logWarnStyle
	}
	if level == "error" || level == "fatal" || level == "panic" {
		return logErrorStyle
	}
	// "info" and anything unknown share the info style.
	return logInfoStyle
}

364
internal/tui/usage_tab.go Normal file
View File

@@ -0,0 +1,364 @@
package tui
import (
"fmt"
"sort"
"strings"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// usageTabModel displays usage statistics with charts and breakdowns.
type usageTabModel struct {
// client is used by fetchData to load the usage payload.
client *Client
viewport viewport.Model
// usage is the last payload received; renderContent reads its "usage"
// entry.
usage map[string]any
// err is the error from the most recent fetch, if any.
err error
width int
height int
// ready is set once SetSize has created the viewport.
ready bool
}
// usageDataMsg carries the result of one fetchData call: the usage payload
// or the error returned by the client.
type usageDataMsg struct {
usage map[string]any
err error
}
// newUsageTabModel returns a usage tab bound to client with no data loaded.
func newUsageTabModel(client *Client) usageTabModel {
	var model usageTabModel
	model.client = client
	return model
}
// Init returns the command that loads the usage data for the first time.
func (m usageTabModel) Init() tea.Cmd {
return m.fetchData
}
// fetchData loads the usage payload through the client and wraps the outcome
// in a usageDataMsg.
func (m usageTabModel) fetchData() tea.Msg {
	var result usageDataMsg
	result.usage, result.err = m.client.GetUsage()
	return result
}
// Update advances the usage tab in response to msg. A locale change
// re-renders the content, a usageDataMsg stores the new data (or the error)
// and re-renders, and the "r" key triggers a reload. Every other message,
// including all other keys, is forwarded to the viewport.
func (m usageTabModel) Update(msg tea.Msg) (usageTabModel, tea.Cmd) {
	switch typed := msg.(type) {
	case localeChangedMsg:
		m.viewport.SetContent(m.renderContent())
		return m, nil

	case usageDataMsg:
		// The previous data is kept when the fetch failed.
		m.err = typed.err
		if typed.err == nil {
			m.usage = typed.usage
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil

	case tea.KeyMsg:
		if typed.String() == "r" {
			return m, m.fetchData
		}
	}

	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
// SetSize records the new dimensions and either creates the viewport (first
// call) or resizes the existing one.
func (m *usageTabModel) SetSize(w, h int) {
	m.width, m.height = w, h

	if m.ready {
		m.viewport.Width, m.viewport.Height = w, h
		return
	}

	// First call: build the viewport and seed it with the rendered content.
	m.viewport = viewport.New(w, h)
	m.viewport.SetContent(m.renderContent())
	m.ready = true
}
// View returns the viewport rendering, or the localized loading text until
// SetSize has created the viewport.
func (m usageTabModel) View() string {
	if m.ready {
		return m.viewport.View()
	}
	return T("loading")
}
// renderContent builds the usage tab body: the title and help line, four
// overview cards (requests, tokens, RPM, TPM), bar charts for requests by
// hour, tokens by hour and requests by day, and a per-API table with a
// per-model breakdown.
//
// API and model rows are emitted in sorted name order. Go randomizes map
// iteration order, so ranging over the maps directly would reshuffle the
// table on every re-render of the same data.
//
// When m.err is set only the error is shown; when no usage data is available
// a placeholder is shown instead.
func (m usageTabModel) renderContent() string {
	var sb strings.Builder
	sb.WriteString(titleStyle.Render(T("usage_title")))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("usage_help")))
	sb.WriteString("\n\n")

	if m.err != nil {
		sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
		sb.WriteString("\n")
		return sb.String()
	}

	// Indexing a nil map is safe, so this covers both "nothing fetched yet"
	// and "payload without a usage object".
	usageMap, _ := m.usage["usage"].(map[string]any)
	if usageMap == nil {
		sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
		sb.WriteString("\n")
		return sb.String()
	}

	totalReqs := int64(getFloat(usageMap, "total_requests"))
	successCnt := int64(getFloat(usageMap, "success_count"))
	failureCnt := int64(getFloat(usageMap, "failure_count"))
	totalTokens := int64(getFloat(usageMap, "total_tokens"))

	// ━━━ Overview Cards ━━━
	// Four cards share the width; each is at least 16 columns wide.
	cardWidth := 20
	if m.width > 0 {
		cardWidth = (m.width - 6) / 4
		if cardWidth < 16 {
			cardWidth = 16
		}
	}
	cardStyle := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(lipgloss.Color("240")).
		Padding(0, 1).
		Width(cardWidth).
		Height(3)
	muted := lipgloss.NewStyle().Foreground(colorMuted)

	// Total Requests
	card1 := cardStyle.Copy().BorderForeground(lipgloss.Color("111")).Render(fmt.Sprintf(
		"%s\n%s\n%s",
		muted.Render(T("usage_total_reqs")),
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("%d", totalReqs)),
		muted.Render(fmt.Sprintf("● %s: %d ● %s: %d", T("usage_success"), successCnt, T("usage_failure"), failureCnt)),
	))

	// Total Tokens
	card2 := cardStyle.Copy().BorderForeground(lipgloss.Color("214")).Render(fmt.Sprintf(
		"%s\n%s\n%s",
		muted.Render(T("usage_total_tokens")),
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(formatLargeNumber(totalTokens)),
		muted.Render(fmt.Sprintf("%s: %s", T("usage_total_token_l"), formatLargeNumber(totalTokens))),
	))

	// RPM: total requests averaged over the number of hour buckets present.
	rpm := float64(0)
	if totalReqs > 0 {
		if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
			rpm = float64(totalReqs) / float64(len(rByH)) / 60.0
		}
	}
	card3 := cardStyle.Copy().BorderForeground(lipgloss.Color("76")).Render(fmt.Sprintf(
		"%s\n%s\n%s",
		muted.Render(T("usage_rpm")),
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("%.2f", rpm)),
		muted.Render(fmt.Sprintf("%s: %d", T("usage_total_reqs"), totalReqs)),
	))

	// TPM: total tokens averaged over the number of hour buckets present.
	tpm := float64(0)
	if totalTokens > 0 {
		if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
			tpm = float64(totalTokens) / float64(len(tByH)) / 60.0
		}
	}
	card4 := cardStyle.Copy().BorderForeground(lipgloss.Color("170")).Render(fmt.Sprintf(
		"%s\n%s\n%s",
		muted.Render(T("usage_tpm")),
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("%.2f", tpm)),
		muted.Render(fmt.Sprintf("%s: %s", T("usage_total_tokens"), formatLargeNumber(totalTokens))),
	))

	sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
	sb.WriteString("\n\n")

	// ━━━ Bar charts: requests by hour, tokens by hour, requests by day ━━━
	sectionTitle := lipgloss.NewStyle().Bold(true).Foreground(colorHighlight)
	charts := []struct {
		dataKey  string
		titleKey string
		color    lipgloss.Color
	}{
		{"requests_by_hour", "usage_req_by_hour", lipgloss.Color("111")},
		{"tokens_by_hour", "usage_tok_by_hour", lipgloss.Color("214")},
		{"requests_by_day", "usage_req_by_day", lipgloss.Color("76")},
	}
	for _, chart := range charts {
		series, ok := usageMap[chart.dataKey].(map[string]any)
		if !ok || len(series) == 0 {
			continue
		}
		sb.WriteString(sectionTitle.Render(T(chart.titleKey)))
		sb.WriteString("\n")
		sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
		sb.WriteString("\n")
		sb.WriteString(renderBarChart(series, m.width-6, chart.color))
		sb.WriteString("\n")
	}

	// ━━━ API Detail Stats ━━━
	if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
		sb.WriteString(sectionTitle.Render(T("usage_api_detail")))
		sb.WriteString("\n")
		sb.WriteString(strings.Repeat("─", minInt(m.width, 80)))
		sb.WriteString("\n")
		header := fmt.Sprintf(" %-30s %10s %12s", "API", T("requests"), T("tokens"))
		sb.WriteString(tableHeaderStyle.Render(header))
		sb.WriteString("\n")

		// Sort API names for a stable row order across renders.
		apiNames := make([]string, 0, len(apis))
		for name := range apis {
			apiNames = append(apiNames, name)
		}
		sort.Strings(apiNames)

		for _, apiName := range apiNames {
			apiMap, ok := apis[apiName].(map[string]any)
			if !ok {
				continue
			}
			apiReqs := int64(getFloat(apiMap, "total_requests"))
			apiToks := int64(getFloat(apiMap, "total_tokens"))
			row := fmt.Sprintf(" %-30s %10d %12s",
				truncate(maskKey(apiName), 30), apiReqs, formatLargeNumber(apiToks))
			sb.WriteString(lipgloss.NewStyle().Bold(true).Render(row))
			sb.WriteString("\n")

			// Per-model breakdown, also in sorted order.
			models, ok := apiMap["models"].(map[string]any)
			if !ok {
				continue
			}
			modelNames := make([]string, 0, len(models))
			for name := range models {
				modelNames = append(modelNames, name)
			}
			sort.Strings(modelNames)

			for _, model := range modelNames {
				stats, ok := models[model].(map[string]any)
				if !ok {
					continue
				}
				mReqs := int64(getFloat(stats, "total_requests"))
				mToks := int64(getFloat(stats, "total_tokens"))
				mRow := fmt.Sprintf(" ├─ %-28s %10d %12s",
					truncate(model, 28), mReqs, formatLargeNumber(mToks))
				sb.WriteString(tableCellStyle.Render(mRow))
				sb.WriteString("\n")
				// Token type breakdown from details
				sb.WriteString(m.renderTokenBreakdown(stats))
			}
		}
	}

	sb.WriteString("\n")
	return sb.String()
}
// renderTokenBreakdown sums the input, output, cached and reasoning token
// counts over every entry in modelStats["details"] and returns them as one
// indented line listing only the positive totals. It returns "" when there
// are no details or when every total is zero.
func (m usageTabModel) renderTokenBreakdown(modelStats map[string]any) string {
	// A missing key or a value of another type both yield an empty list.
	detailList, _ := modelStats["details"].([]any)
	if len(detailList) == 0 {
		return ""
	}

	// Token fields and their label keys, in display order.
	tokenFields := [4]string{"input_tokens", "output_tokens", "cached_tokens", "reasoning_tokens"}
	labelKeys := [4]string{"usage_input", "usage_output", "usage_cached", "usage_reasoning"}

	var totals [4]int64
	for _, raw := range detailList {
		detail, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		tokens, ok := detail["tokens"].(map[string]any)
		if !ok {
			continue
		}
		for idx, field := range tokenFields {
			totals[idx] += int64(getFloat(tokens, field))
		}
	}
	if totals == [4]int64{} {
		return ""
	}

	parts := []string{}
	for idx, total := range totals {
		if total > 0 {
			parts = append(parts, fmt.Sprintf("%s:%s", T(labelKeys[idx]), formatLargeNumber(total)))
		}
	}
	summary := lipgloss.NewStyle().Foreground(colorMuted).Render(strings.Join(parts, " "))
	return fmt.Sprintf(" │ %s\n", summary)
}
// renderBarChart renders data as a horizontal bar chart, one row per key in
// sorted key order: a left-aligned label (cut to 12 bytes), a bar scaled
// against the largest value, and the numeric value.
//
// maxBarWidth is the total width available to a row and is raised to at
// least 10; the bar itself gets what is left after the label and value
// columns, but never fewer than 5 cells. Any positive value draws at least
// one cell. Zero and negative values draw no bar: a negative bar length
// would make strings.Repeat panic. The function returns "" when no value is
// greater than zero.
func renderBarChart(data map[string]any, maxBarWidth int, barColor lipgloss.Color) string {
	if maxBarWidth < 10 {
		maxBarWidth = 10
	}

	// Sort keys so rows appear in a stable order.
	keys := make([]string, 0, len(data))
	for k := range data {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	// Find the value that full-width bars are scaled against.
	maxVal := float64(0)
	for _, k := range keys {
		if v := getFloat(data, k); v > maxVal {
			maxVal = v
		}
	}
	if maxVal == 0 {
		return ""
	}

	barStyle := lipgloss.NewStyle().Foreground(barColor)
	valueStyle := lipgloss.NewStyle().Foreground(colorMuted)

	const labelWidth = 12
	barAvail := maxBarWidth - labelWidth - 12
	if barAvail < 5 {
		barAvail = 5
	}

	var sb strings.Builder
	for _, k := range keys {
		v := getFloat(data, k)

		barLen := int(v / maxVal * float64(barAvail))
		switch {
		case v <= 0:
			// Nothing to draw; also keeps the repeat count non-negative.
			barLen = 0
		case barLen < 1:
			// Keep small but non-zero values visible.
			barLen = 1
		}

		label := k
		if len(label) > labelWidth {
			label = label[:labelWidth]
		}
		sb.WriteString(fmt.Sprintf(" %-*s %s %s\n",
			labelWidth, label,
			barStyle.Render(strings.Repeat("█", barLen)),
			valueStyle.Render(fmt.Sprintf("%.0f", v)),
		))
	}
	return sb.String()
}

View File

@@ -184,6 +184,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
}
if o.Websockets != n.Websockets {
changes = append(changes, fmt.Sprintf("codex[%d].websockets: %t -> %t", i, o.Websockets, n.Websockets))
}
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
}

View File

@@ -164,6 +164,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
if ck.BaseURL != "" {
attrs["base_url"] = ck.BaseURL
}
if ck.Websockets {
attrs["websockets"] = "true"
}
if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
attrs["models_hash"] = hash
}

View File

@@ -231,10 +231,11 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) {
Config: &config.Config{
CodexKey: []config.CodexKey{
{
APIKey: "codex-key-123",
Prefix: "dev",
BaseURL: "https://api.openai.com",
ProxyURL: "http://proxy.local",
APIKey: "codex-key-123",
Prefix: "dev",
BaseURL: "https://api.openai.com",
ProxyURL: "http://proxy.local",
Websockets: true,
},
},
},
@@ -259,6 +260,9 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) {
if auths[0].ProxyURL != "http://proxy.local" {
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
}
if auths[0].Attributes["websockets"] != "true" {
t.Errorf("expected websockets=true, got %s", auths[0].Attributes["websockets"])
}
}
func TestConfigSynthesizer_CodexKeys_SkipsEmptyAndHeaders(t *testing.T) {

View File

@@ -112,12 +112,13 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) {
modelName := gjson.GetBytes(rawJSON, "model").String()
resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}
@@ -165,7 +166,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
modelName := gjson.GetBytes(rawJSON, "model").String()
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
@@ -194,6 +195,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
}
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}
@@ -225,7 +227,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
// This allows proper cleanup and cancellation of ongoing requests
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
@@ -257,6 +259,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
if !ok {
// Stream closed without data? Send DONE or just headers.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
flusher.Flush()
cliCancel(nil)
return
@@ -264,6 +267,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
// Success! Set headers now.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
// Write the first chunk
if len(chunk) > 0 {

View File

@@ -159,7 +159,8 @@ func (h *GeminiCLIAPIHandler) handleInternalStreamGenerateContent(c *gin.Context
modelName := modelResult.String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
h.forwardCLIStream(c, flusher, "", func(err error) { cliCancel(err) }, dataChan, errChan)
return
}
@@ -172,12 +173,13 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ
modelName := modelResult.String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}

View File

@@ -188,7 +188,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
}
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
@@ -223,6 +223,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
if alt == "" {
setSSEHeaders()
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
flusher.Flush()
cliCancel(nil)
return
@@ -232,6 +233,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
if alt == "" {
setSSEHeaders()
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
// Write first chunk
if alt == "" {
@@ -262,12 +264,13 @@ func (h *GeminiAPIHandler) handleCountTokens(c *gin.Context, modelName string, r
c.Header("Content-Type", "application/json")
alt := h.GetAlt(c)
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}
@@ -286,13 +289,14 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
alt := h.GetAlt(c)
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}

View File

@@ -52,6 +52,45 @@ const (
defaultStreamingBootstrapRetries = 0
)
// Unexported context key types. Each is a distinct empty struct so values
// stored under them cannot collide with keys from other packages.
type (
	pinnedAuthContextKey           struct{}
	selectedAuthCallbackContextKey struct{}
	executionSessionContextKey     struct{}
)

// WithPinnedAuthID returns a child context that requests execution on a specific auth ID.
// The ID is whitespace-trimmed; if nothing remains, ctx is returned unchanged.
// A nil ctx is replaced by context.Background() before the value is attached.
func WithPinnedAuthID(ctx context.Context, authID string) context.Context {
	trimmed := strings.TrimSpace(authID)
	switch {
	case trimmed == "":
		return ctx
	case ctx == nil:
		return context.WithValue(context.Background(), pinnedAuthContextKey{}, trimmed)
	default:
		return context.WithValue(ctx, pinnedAuthContextKey{}, trimmed)
	}
}

// WithSelectedAuthIDCallback returns a child context carrying callback, which is
// meant to receive the selected auth ID. A nil callback leaves ctx unchanged.
// A nil ctx is replaced by context.Background() before the value is attached.
func WithSelectedAuthIDCallback(ctx context.Context, callback func(string)) context.Context {
	switch {
	case callback == nil:
		return ctx
	case ctx == nil:
		return context.WithValue(context.Background(), selectedAuthCallbackContextKey{}, callback)
	default:
		return context.WithValue(ctx, selectedAuthCallbackContextKey{}, callback)
	}
}

// WithExecutionSessionID returns a child context tagged with a long-lived execution session ID.
// The ID is whitespace-trimmed; if nothing remains, ctx is returned unchanged.
// A nil ctx is replaced by context.Background() before the value is attached.
func WithExecutionSessionID(ctx context.Context, sessionID string) context.Context {
	trimmed := strings.TrimSpace(sessionID)
	switch {
	case trimmed == "":
		return ctx
	case ctx == nil:
		return context.WithValue(context.Background(), executionSessionContextKey{}, trimmed)
	default:
		return context.WithValue(ctx, executionSessionContextKey{}, trimmed)
	}
}
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
func BuildErrorResponseBody(status int, errText string) []byte {
@@ -140,6 +179,12 @@ func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
return retries
}
// PassthroughHeadersEnabled reports whether upstream response headers should be
// forwarded to clients. It is false for a nil config; otherwise it mirrors
// cfg.PassthroughHeaders.
func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool {
	if cfg == nil {
		return false
	}
	return cfg.PassthroughHeaders
}
func requestExecutionMetadata(ctx context.Context) map[string]any {
// Idempotency-Key is an optional client-supplied header used to correlate retries.
// It is forwarded as execution metadata; when absent we generate a UUID.
@@ -152,7 +197,59 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
if key == "" {
key = uuid.NewString()
}
return map[string]any{idempotencyKeyMetadataKey: key}
meta := map[string]any{idempotencyKeyMetadataKey: key}
if pinnedAuthID := pinnedAuthIDFromContext(ctx); pinnedAuthID != "" {
meta[coreexecutor.PinnedAuthMetadataKey] = pinnedAuthID
}
if selectedCallback := selectedAuthIDCallbackFromContext(ctx); selectedCallback != nil {
meta[coreexecutor.SelectedAuthCallbackMetadataKey] = selectedCallback
}
if executionSessionID := executionSessionIDFromContext(ctx); executionSessionID != "" {
meta[coreexecutor.ExecutionSessionMetadataKey] = executionSessionID
}
return meta
}
// pinnedAuthIDFromContext returns the whitespace-trimmed auth ID stored under
// pinnedAuthContextKey. It returns "" for a nil ctx, a missing value, or a
// value that is neither a string nor a []byte.
func pinnedAuthIDFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	stored := ctx.Value(pinnedAuthContextKey{})
	if text, isString := stored.(string); isString {
		return strings.TrimSpace(text)
	}
	if data, isBytes := stored.([]byte); isBytes {
		return strings.TrimSpace(string(data))
	}
	return ""
}
// selectedAuthIDCallbackFromContext returns the func(string) stored under
// selectedAuthCallbackContextKey, or nil when ctx is nil, nothing is stored,
// or the stored value has a different type.
func selectedAuthIDCallbackFromContext(ctx context.Context) func(string) {
	if ctx == nil {
		return nil
	}
	// A failed assertion leaves callback as the zero value (nil), which is
	// exactly what should be returned in that case.
	callback, _ := ctx.Value(selectedAuthCallbackContextKey{}).(func(string))
	return callback
}
// executionSessionIDFromContext returns the whitespace-trimmed session ID stored
// under executionSessionContextKey. It returns "" for a nil ctx, a missing
// value, or a value that is neither a string nor a []byte.
func executionSessionIDFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	stored := ctx.Value(executionSessionContextKey{})
	if text, isString := stored.(string); isString {
		return strings.TrimSpace(text)
	}
	if data, isBytes := stored.([]byte); isBytes {
		return strings.TrimSpace(string(data))
	}
	return ""
}
// BaseAPIHandler contains the handlers for API endpoints.
@@ -371,10 +468,10 @@ func appendAPIResponse(c *gin.Context, data []byte) {
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
if errMsg != nil {
return nil, errMsg
return nil, nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
@@ -407,17 +504,20 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
addon = hdr.Clone()
}
}
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
}
return resp.Payload, nil
if !PassthroughHeadersEnabled(h.Cfg) {
return resp.Payload, nil, nil
}
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
}
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
if errMsg != nil {
return nil, errMsg
return nil, nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
@@ -450,20 +550,24 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
addon = hdr.Clone()
}
}
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
}
return resp.Payload, nil
if !PassthroughHeadersEnabled(h.Cfg) {
return resp.Payload, nil, nil
}
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
}
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
// This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
// The returned http.Header carries upstream response headers captured before streaming begins.
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
if errMsg != nil {
errChan := make(chan *interfaces.ErrorMessage, 1)
errChan <- errMsg
close(errChan)
return nil, errChan
return nil, nil, errChan
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
@@ -482,7 +586,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
SourceFormat: sdktranslator.FromString(handlerType),
}
opts.Metadata = reqMeta
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if err != nil {
errChan := make(chan *interfaces.ErrorMessage, 1)
status := http.StatusInternalServerError
@@ -499,8 +603,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
}
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
close(errChan)
return nil, errChan
return nil, nil, errChan
}
passthroughHeadersEnabled := PassthroughHeadersEnabled(h.Cfg)
// Capture upstream headers from the initial connection synchronously before the goroutine starts.
// Keep a mutable map so bootstrap retries can replace it before first payload is sent.
var upstreamHeaders http.Header
if passthroughHeadersEnabled {
upstreamHeaders = cloneHeader(FilterUpstreamHeaders(streamResult.Headers))
if upstreamHeaders == nil {
upstreamHeaders = make(http.Header)
}
}
chunks := streamResult.Chunks
dataChan := make(chan []byte)
errChan := make(chan *interfaces.ErrorMessage, 1)
go func() {
@@ -574,9 +689,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
if !sentPayload {
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
bootstrapRetries++
retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if retryErr == nil {
chunks = retryChunks
if passthroughHeadersEnabled {
replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers))
}
chunks = retryResult.Chunks
continue outer
}
streamErr = retryErr
@@ -607,7 +725,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
}
}
}()
return dataChan, errChan
return dataChan, upstreamHeaders, errChan
}
func statusFromError(err error) int {
@@ -667,13 +785,33 @@ func cloneBytes(src []byte) []byte {
return dst
}
// cloneHeader returns a deep copy of src: the map and every value slice are
// freshly allocated, so later changes to either side do not affect the other.
// Keys are copied verbatim. A nil src yields nil.
func cloneHeader(src http.Header) http.Header {
	if src == nil {
		return nil
	}
	clone := make(http.Header, len(src))
	for name := range src {
		clone[name] = append([]string(nil), src[name]...)
	}
	return clone
}
// replaceHeader overwrites the contents of dst in place with a deep copy of
// src, so that any holder of the dst map observes the new headers.
//
// All existing keys are removed from dst first; each value slice from src is
// then copied so dst never aliases src's slices. A nil src simply empties dst.
// A nil dst cannot be written to (assigning into a nil map panics), so the
// call is a no-op in that case.
func replaceHeader(dst http.Header, src http.Header) {
	if dst == nil {
		return
	}
	for key := range dst {
		delete(dst, key)
	}
	for key, values := range src {
		dst[key] = append([]string(nil), values...)
	}
}
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
status := http.StatusInternalServerError
if msg != nil && msg.StatusCode > 0 {
status = msg.StatusCode
}
if msg != nil && msg.Addon != nil {
if msg != nil && msg.Addon != nil && PassthroughHeadersEnabled(h.Cfg) {
for key, values := range msg.Addon {
if len(values) == 0 {
continue

View File

@@ -0,0 +1,68 @@
package handlers
import (
"errors"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// TestWriteErrorResponse_AddonHeadersDisabledByDefault checks that a handler
// built with a nil SDK config still writes the error status but copies none of
// the ErrorMessage.Addon headers onto the response.
func TestWriteErrorResponse_AddonHeadersDisabledByDefault(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(rec)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/", nil)

	errMsg := &interfaces.ErrorMessage{
		StatusCode: http.StatusTooManyRequests,
		Error:      errors.New("rate limit"),
		Addon: http.Header{
			"Retry-After":  {"30"},
			"X-Request-Id": {"req-1"},
		},
	}
	handler := NewBaseAPIHandlers(nil, nil)
	handler.WriteErrorResponse(ginCtx, errMsg)

	if want := http.StatusTooManyRequests; rec.Code != want {
		t.Fatalf("status = %d, want %d", rec.Code, want)
	}
	for _, name := range []string{"Retry-After", "X-Request-Id"} {
		if got := rec.Header().Get(name); got != "" {
			t.Fatalf("%s should be empty when passthrough is disabled, got %q", name, got)
		}
	}
}
// TestWriteErrorResponse_AddonHeadersEnabled checks that, with
// PassthroughHeaders set, the Addon headers are written to the response and
// replace a value previously set for the same header name.
func TestWriteErrorResponse_AddonHeadersEnabled(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(rec)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/", nil)
	// Pre-existing value that the addon header is expected to replace.
	ginCtx.Writer.Header().Set("X-Request-Id", "old-value")

	errMsg := &interfaces.ErrorMessage{
		StatusCode: http.StatusTooManyRequests,
		Error:      errors.New("rate limit"),
		Addon: http.Header{
			"Retry-After":  {"30"},
			"X-Request-Id": {"new-1", "new-2"},
		},
	}
	handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{PassthroughHeaders: true}, nil)
	handler.WriteErrorResponse(ginCtx, errMsg)

	if want := http.StatusTooManyRequests; rec.Code != want {
		t.Fatalf("status = %d, want %d", rec.Code, want)
	}
	if got, want := rec.Header().Get("Retry-After"), "30"; got != want {
		t.Fatalf("Retry-After = %q, want %q", got, want)
	}
	wantIDs := []string{"new-1", "new-2"}
	if got := rec.Header().Values("X-Request-Id"); !reflect.DeepEqual(got, wantIDs) {
		t.Fatalf("X-Request-Id = %#v, want %#v", got, wantIDs)
	}
}

View File

@@ -23,7 +23,7 @@ func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreex
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
e.mu.Lock()
e.calls++
call := e.calls
@@ -40,12 +40,18 @@ func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth,
},
}
close(ch)
return ch, nil
return &coreexecutor.StreamResult{
Headers: http.Header{"X-Upstream-Attempt": {"1"}},
Chunks: ch,
}, nil
}
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
close(ch)
return ch, nil
return &coreexecutor.StreamResult{
Headers: http.Header{"X-Upstream-Attempt": {"2"}},
Chunks: ch,
}, nil
}
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
@@ -81,7 +87,7 @@ func (e *payloadThenErrorStreamExecutor) Execute(context.Context, *coreauth.Auth
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
e.mu.Lock()
e.calls++
e.mu.Unlock()
@@ -97,7 +103,7 @@ func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreaut
},
}
close(ch)
return ch, nil
return &coreexecutor.StreamResult{Chunks: ch}, nil
}
func (e *payloadThenErrorStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
@@ -122,6 +128,82 @@ func (e *payloadThenErrorStreamExecutor) Calls() int {
return e.calls
}
// authAwareStreamExecutor is a test executor whose streaming result depends on
// which auth it is invoked with. It records the auth ID of every ExecuteStream
// call so tests can inspect which auths were tried.
type authAwareStreamExecutor struct {
	mu      sync.Mutex
	calls   int
	authIDs []string
}

// Identifier returns the provider name this executor registers under.
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }

// Execute is not supported by this test executor.
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}

// ExecuteStream records the auth ID (empty for a nil auth) and returns a closed
// one-chunk stream: a non-retryable 401 error chunk for "auth1", otherwise a
// payload chunk containing "ok".
func (e *authAwareStreamExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, _ coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
	var authID string
	if auth != nil {
		authID = auth.ID
	}

	e.mu.Lock()
	e.calls++
	e.authIDs = append(e.authIDs, authID)
	e.mu.Unlock()

	chunk := coreexecutor.StreamChunk{Payload: []byte("ok")}
	if authID == "auth1" {
		chunk = coreexecutor.StreamChunk{
			Err: &coreauth.Error{
				Code:       "unauthorized",
				Message:    "unauthorized",
				Retryable:  false,
				HTTPStatus: http.StatusUnauthorized,
			},
		}
	}

	// Buffer of one lets the single chunk be queued before the channel is closed.
	out := make(chan coreexecutor.StreamChunk, 1)
	out <- chunk
	close(out)
	return &coreexecutor.StreamResult{Chunks: out}, nil
}

// Refresh returns auth unchanged.
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
	return auth, nil
}

// CountTokens is not supported by this test executor.
func (e *authAwareStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
}

// HttpRequest is not supported by this test executor.
func (e *authAwareStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
	notImplemented := &coreauth.Error{
		Code:       "not_implemented",
		Message:    "HttpRequest not implemented",
		HTTPStatus: http.StatusNotImplemented,
	}
	return nil, notImplemented
}

// Calls returns how many times ExecuteStream has been invoked.
func (e *authAwareStreamExecutor) Calls() int {
	e.mu.Lock()
	defer e.mu.Unlock()
	return e.calls
}

// AuthIDs returns a snapshot copy of the recorded auth IDs in call order.
func (e *authAwareStreamExecutor) AuthIDs() []string {
	e.mu.Lock()
	defer e.mu.Unlock()
	snapshot := make([]string, len(e.authIDs))
	copy(snapshot, e.authIDs)
	return snapshot
}
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
executor := &failOnceStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
@@ -154,12 +236,78 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
PassthroughHeaders: true,
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: 1,
},
}, manager)
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
for msg := range errChan {
if msg != nil {
t.Fatalf("unexpected error: %+v", msg)
}
}
if string(got) != "ok" {
t.Fatalf("expected payload ok, got %q", string(got))
}
if executor.Calls() != 2 {
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
}
upstreamAttemptHeader := upstreamHeaders.Get("X-Upstream-Attempt")
if upstreamAttemptHeader != "2" {
t.Fatalf("expected upstream header from retry attempt, got %q", upstreamAttemptHeader)
}
}
func TestExecuteStreamWithAuthManager_HeaderPassthroughDisabledByDefault(t *testing.T) {
executor := &failOnceStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth1 := &coreauth.Auth{
ID: "auth1",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test1@example.com"},
}
if _, err := manager.Register(context.Background(), auth1); err != nil {
t.Fatalf("manager.Register(auth1): %v", err)
}
auth2 := &coreauth.Auth{
ID: "auth2",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test2@example.com"},
}
if _, err := manager.Register(context.Background(), auth2); err != nil {
t.Fatalf("manager.Register(auth2): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: 1,
},
}, manager)
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
@@ -168,7 +316,6 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
for chunk := range dataChan {
got = append(got, chunk...)
}
for msg := range errChan {
if msg != nil {
t.Fatalf("unexpected error: %+v", msg)
@@ -178,8 +325,8 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
if string(got) != "ok" {
t.Fatalf("expected payload ok, got %q", string(got))
}
if executor.Calls() != 2 {
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
if upstreamHeaders != nil {
t.Fatalf("expected nil upstream headers when passthrough is disabled, got %#v", upstreamHeaders)
}
}
@@ -220,7 +367,7 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
BootstrapRetries: 1,
},
}, manager)
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
@@ -252,3 +399,128 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
t.Fatalf("expected 1 stream attempt, got %d", executor.Calls())
}
}
// TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream registers two
// auths, pins the request to "auth1" (for which the test executor always
// returns an unauthorized error chunk), and asserts that the stream ends with
// an error, delivers no payload, and that every upstream attempt used "auth1".
func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) {
	executor := &authAwareStreamExecutor{}
	manager := coreauth.NewManager(nil, nil, nil)
	manager.RegisterExecutor(executor)

	auths := []*coreauth.Auth{
		{
			ID:       "auth1",
			Provider: "codex",
			Status:   coreauth.StatusActive,
			Metadata: map[string]any{"email": "test1@example.com"},
		},
		{
			ID:       "auth2",
			Provider: "codex",
			Status:   coreauth.StatusActive,
			Metadata: map[string]any{"email": "test2@example.com"},
		},
	}
	for _, auth := range auths {
		if _, err := manager.Register(context.Background(), auth); err != nil {
			t.Fatalf("manager.Register(%s): %v", auth.ID, err)
		}
	}
	for _, auth := range auths {
		registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	}
	t.Cleanup(func() {
		for _, auth := range auths {
			registry.GetGlobalRegistry().UnregisterClient(auth.ID)
		}
	})

	cfg := &sdkconfig.SDKConfig{
		Streaming: sdkconfig.StreamingConfig{
			BootstrapRetries: 1,
		},
	}
	handler := NewBaseAPIHandlers(cfg, manager)

	pinnedCtx := WithPinnedAuthID(context.Background(), "auth1")
	dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(pinnedCtx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
	if dataChan == nil || errChan == nil {
		t.Fatalf("expected non-nil channels")
	}

	// Drain the data channel before the error channel.
	var payload []byte
	for chunk := range dataChan {
		payload = append(payload, chunk...)
	}
	var terminalErr error
	for msg := range errChan {
		if msg != nil && msg.Error != nil {
			terminalErr = msg.Error
		}
	}

	if len(payload) != 0 {
		t.Fatalf("expected empty payload, got %q", string(payload))
	}
	if terminalErr == nil {
		t.Fatalf("expected terminal error, got nil")
	}

	attempted := executor.AuthIDs()
	if len(attempted) == 0 {
		t.Fatalf("expected at least one upstream attempt")
	}
	for _, id := range attempted {
		if id != "auth1" {
			t.Fatalf("expected all attempts on auth1, got sequence %v", attempted)
		}
	}
}
// TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID registers
// a single auth ("auth2"), installs a selected-auth callback on the context,
// and asserts the stream yields "ok" and the callback was given "auth2".
func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *testing.T) {
	executor := &authAwareStreamExecutor{}
	manager := coreauth.NewManager(nil, nil, nil)
	manager.RegisterExecutor(executor)

	onlyAuth := &coreauth.Auth{
		ID:       "auth2",
		Provider: "codex",
		Status:   coreauth.StatusActive,
		Metadata: map[string]any{"email": "test2@example.com"},
	}
	if _, err := manager.Register(context.Background(), onlyAuth); err != nil {
		t.Fatalf("manager.Register(auth2): %v", err)
	}
	registry.GetGlobalRegistry().RegisterClient(onlyAuth.ID, onlyAuth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	t.Cleanup(func() {
		registry.GetGlobalRegistry().UnregisterClient(onlyAuth.ID)
	})

	cfg := &sdkconfig.SDKConfig{
		Streaming: sdkconfig.StreamingConfig{
			BootstrapRetries: 0,
		},
	}
	handler := NewBaseAPIHandlers(cfg, manager)

	var selectedAuthID string
	record := func(authID string) {
		selectedAuthID = authID
	}
	callbackCtx := WithSelectedAuthIDCallback(context.Background(), record)

	dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(callbackCtx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
	if dataChan == nil || errChan == nil {
		t.Fatalf("expected non-nil channels")
	}

	// Drain the data channel before the error channel.
	var payload []byte
	for chunk := range dataChan {
		payload = append(payload, chunk...)
	}
	for msg := range errChan {
		if msg != nil {
			t.Fatalf("unexpected error: %+v", msg)
		}
	}

	if string(payload) != "ok" {
		t.Fatalf("expected payload ok, got %q", string(payload))
	}
	if selectedAuthID != "auth2" {
		t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
	}
}

View File

@@ -0,0 +1,80 @@
package handlers
import (
"net/http"
"strings"
)
// hopByHopHeaders lists RFC 7230 Section 6.1 hop-by-hop headers that MUST NOT
// be forwarded by proxies, plus security-sensitive headers that should not leak.
// Keys are stored in canonical MIME header form so lookups must canonicalize first.
var hopByHopHeaders = map[string]struct{}{
	// RFC 7230 hop-by-hop
	"Connection":          {},
	"Keep-Alive":          {},
	"Proxy-Authenticate":  {},
	"Proxy-Authorization": {},
	"Te":                  {},
	"Trailer":             {},
	"Transfer-Encoding":   {},
	"Upgrade":             {},
	// Security-sensitive
	"Set-Cookie": {},
	// CPA-managed (set by handlers, not upstream)
	"Content-Length":   {},
	"Content-Encoding": {},
}

// FilterUpstreamHeaders returns a copy of src with hop-by-hop and security-sensitive
// headers removed, together with every header named by a Connection header token.
//
// Header names are matched case-insensitively via their canonical form. Surviving
// entries keep the key spelling used in src, but their value slices are copied so
// the result never aliases src. Returns nil if src is nil or empty after filtering.
func FilterUpstreamHeaders(src http.Header) http.Header {
	if len(src) == 0 {
		return nil
	}
	connectionScoped := connectionScopedHeaders(src)
	dst := make(http.Header, len(src))
	for key, values := range src {
		canonicalKey := http.CanonicalHeaderKey(key)
		if _, blocked := hopByHopHeaders[canonicalKey]; blocked {
			continue
		}
		if _, scoped := connectionScoped[canonicalKey]; scoped {
			continue
		}
		// Copy the values: sharing the backing array would let a later mutation
		// of the filtered header silently change the upstream header (or vice versa).
		dst[key] = append([]string(nil), values...)
	}
	if len(dst) == 0 {
		return nil
	}
	return dst
}

// connectionScopedHeaders collects the header names listed as comma-separated
// tokens in any Connection header of src. The returned set is keyed by canonical
// header name. The Connection header itself is located case-insensitively, the
// same way FilterUpstreamHeaders matches every other key.
func connectionScopedHeaders(src http.Header) map[string]struct{} {
	scoped := make(map[string]struct{})
	for key, rawValues := range src {
		if http.CanonicalHeaderKey(key) != "Connection" {
			continue
		}
		for _, rawValue := range rawValues {
			for _, token := range strings.Split(rawValue, ",") {
				headerName := strings.TrimSpace(token)
				if headerName == "" {
					continue
				}
				scoped[http.CanonicalHeaderKey(headerName)] = struct{}{}
			}
		}
	}
	return scoped
}
// WriteUpstreamHeaders copies the headers in src into dst.
//
// It performs no filtering of its own; callers are expected to pass headers that
// are already safe to forward. A header for which dst already reports a non-empty
// value (e.g., Content-Type set by a CPA handler) is left untouched; otherwise
// every value from src is appended. The call is a no-op when either argument is
// nil, so a nil dst no longer panics on the map write.
func WriteUpstreamHeaders(dst http.Header, src http.Header) {
	if dst == nil || src == nil {
		return
	}
	for key, values := range src {
		// Don't overwrite headers already set by CPA handlers
		if dst.Get(key) != "" {
			continue
		}
		for _, v := range values {
			dst.Add(key, v)
		}
	}
}

View File

@@ -0,0 +1,55 @@
package handlers
import (
"net/http"
"testing"
)
// TestFilterUpstreamHeaders_RemovesConnectionScopedHeaders verifies that headers
// named in Connection tokens, hop-by-hop headers and Set-Cookie are dropped
// while an unrelated header survives filtering.
func TestFilterUpstreamHeaders_RemovesConnectionScopedHeaders(t *testing.T) {
	upstream := http.Header{}
	// Two Connection entries: tokens from both must be honoured.
	upstream.Add("Connection", "keep-alive, x-hop-a, x-hop-b")
	upstream.Add("Connection", "x-hop-c")
	singleValued := [][2]string{
		{"Keep-Alive", "timeout=5"},
		{"X-Hop-A", "a"},
		{"X-Hop-B", "b"},
		{"X-Hop-C", "c"},
		{"X-Request-Id", "req-1"},
		{"Set-Cookie", "session=secret"},
	}
	for _, pair := range singleValued {
		upstream.Set(pair[0], pair[1])
	}

	result := FilterUpstreamHeaders(upstream)
	if result == nil {
		t.Fatalf("expected filtered headers, got nil")
	}

	if kept := result.Get("X-Request-Id"); kept != "req-1" {
		t.Fatalf("expected X-Request-Id to be preserved, got %q", kept)
	}

	mustBeGone := []string{
		"Connection",
		"Keep-Alive",
		"X-Hop-A",
		"X-Hop-B",
		"X-Hop-C",
		"Set-Cookie",
	}
	for i := range mustBeGone {
		if leaked := result.Get(mustBeGone[i]); leaked != "" {
			t.Fatalf("expected %s to be removed, got %q", mustBeGone[i], leaked)
		}
	}
}
// TestFilterUpstreamHeaders_ReturnsNilWhenAllHeadersBlocked verifies that the
// filter yields nil, not an empty map, when no header survives.
func TestFilterUpstreamHeaders_ReturnsNilWhenAllHeadersBlocked(t *testing.T) {
	upstream := http.Header{}
	upstream.Add("Connection", "x-hop-a")
	upstream.Set("X-Hop-A", "a")
	upstream.Set("Set-Cookie", "session=secret")

	if result := FilterUpstreamHeaders(upstream); result != nil {
		t.Fatalf("expected nil when all headers are filtered, got %#v", result)
	}
}

View File

@@ -420,6 +420,7 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
// Check if this chunk has any meaningful content
hasContent := false
hasUsage := root.Get("usage").Exists()
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
chatChoices.ForEach(func(_, choice gjson.Result) bool {
// Check if delta has content or finish_reason
@@ -438,8 +439,8 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
})
}
// If no meaningful content, return nil to indicate this chunk should be skipped
if !hasContent {
// If no meaningful content and no usage, return nil to indicate this chunk should be skipped
if !hasContent && !hasUsage {
return nil
}
@@ -498,6 +499,11 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
}
// Copy usage if present
if usage := root.Get("usage"); usage.Exists() {
out, _ = sjson.SetRaw(out, "usage", usage.Raw)
}
return []byte(out)
}
@@ -513,12 +519,13 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}
@@ -528,12 +535,13 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponseViaResponses(c *gin.Context
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
converted := convertResponsesObjectToChatCompletion(cliCtx, modelName, originalChatJSON, rawJSON, resp)
if converted == nil {
h.WriteErrorResponse(c, &interfaces.ErrorMessage{
@@ -569,7 +577,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
@@ -602,6 +610,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
if !ok {
// Stream closed without data? Send DONE or just headers.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel(nil)
@@ -610,6 +619,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
// Success! Commit to streaming headers.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
flusher.Flush()
@@ -635,7 +645,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponseViaResponses(c *gin.Context, r
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
var param any
setSSEHeaders := func() {
@@ -666,6 +676,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponseViaResponses(c *gin.Context, r
case chunk, ok := <-dataChan:
if !ok {
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel(nil)
@@ -673,6 +684,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponseViaResponses(c *gin.Context, r
}
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
writeConvertedResponsesChunk(c, cliCtx, modelName, originalChatJSON, rawJSON, chunk, &param)
flusher.Flush()
@@ -698,13 +710,14 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context,
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
completionsResp := convertChatCompletionsResponseToCompletions(resp)
_, _ = c.Writer.Write(completionsResp)
cliCancel()
@@ -735,7 +748,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
@@ -766,6 +779,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
case chunk, ok := <-dataChan:
if !ok {
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel(nil)
@@ -774,6 +788,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
// Success! Set headers.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
// Write the first chunk
converted := convertChatCompletionsStreamChunkToCompletions(chunk)

View File

@@ -31,7 +31,7 @@ func (e *compactCaptureExecutor) Execute(ctx context.Context, auth *coreauth.Aut
return coreexecutor.Response{Payload: []byte(`{"ok":true}`)}, nil
}
func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
return nil, errors.New("not implemented")
}

View File

@@ -139,13 +139,14 @@ func (h *OpenAIResponsesAPIHandler) Compact(c *gin.Context) {
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact")
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact")
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}
@@ -164,13 +165,14 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel()
}
@@ -180,12 +182,13 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponseViaChat(c *gin.Con
modelName := gjson.GetBytes(chatJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
var param any
converted := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(cliCtx, modelName, originalResponsesJSON, originalResponsesJSON, resp, &param)
if converted == "" {
@@ -223,7 +226,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
// New core execution path
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
@@ -256,6 +259,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
if !ok {
// Stream closed without data? Send headers and done.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
cliCancel(nil)
@@ -264,6 +268,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
// Success! Set headers.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
// Write first chunk logic (matching forwardResponsesStream)
if bytes.HasPrefix(chunk, []byte("event:")) {
@@ -294,7 +299,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponseViaChat(c *gin.Contex
modelName := gjson.GetBytes(chatJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
var param any
setSSEHeaders := func() {
@@ -324,6 +329,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponseViaChat(c *gin.Contex
case chunk, ok := <-dataChan:
if !ok {
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
cliCancel(nil)
@@ -331,6 +337,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponseViaChat(c *gin.Contex
}
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
writeChatAsResponsesChunk(c, cliCtx, modelName, originalResponsesJSON, chunk, &param)
flusher.Flush()

View File

@@ -0,0 +1,662 @@
package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Constants used by the /v1/responses websocket bridge.
const (
// Request types accepted from the downstream client; any other "type" is
// rejected with a 400 by normalizeResponsesWebsocketRequestWithMode.
wsRequestTypeCreate = "response.create"
wsRequestTypeAppend = "response.append"
// Event type of the payload built by writeResponsesWebsocketError.
wsEventTypeError = "error"
// Upstream completion event type; forwardResponsesWebsocket rewrites it to
// wsEventTypeDone before sending the payload downstream.
wsEventTypeCompleted = "response.completed"
wsEventTypeDone = "response.done"
// Stream terminator that websocketJSONPayloadsFromChunk drops instead of forwarding.
wsDoneMarker = "[DONE]"
// Request header echoed back in the upgrade response by websocketUpgradeHeaders.
wsTurnStateHeader = "x-codex-turn-state"
// NOTE(review): presumably the key under which the accumulated websocket log
// is stored by setWebsocketRequestBody; its use is not visible here - confirm.
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
// Size threshold (in bytes) compared against a payload in websocketPayloadPreview.
wsPayloadLogMaxSize = 2048
)
// responsesWebsocketUpgrader upgrades /v1/responses HTTP requests to websocket
// connections using 4 KiB read and write buffers.
var responsesWebsocketUpgrader = websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
// Every origin is accepted: this callback never rejects a request.
// NOTE(review): confirm that skipping origin validation is intended for this endpoint.
CheckOrigin: func(r *http.Request) bool {
return true
},
}
// ResponsesWebsocket handles websocket requests for /v1/responses.
// It accepts `response.create` and `response.append` requests and streams
// response events back as JSON websocket text messages.
//
// Each connection gets a fresh UUID that is attached to every upstream
// execution as its execution session ID and is closed on the AuthManager when
// the handler returns. The handler loops over incoming messages: each one is
// normalized against the previous request and the previous response output,
// executed as a stream, and forwarded to the client. Once an auth has been
// reported through the selected-auth callback it is pinned for all later
// requests on the same connection. The loop ends when reading from the client
// fails or when forwarding a response fails.
func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
// On upgrade failure there is no connection to serve, so just return.
conn, err := responsesWebsocketUpgrader.Upgrade(c.Writer, c.Request, websocketUpgradeHeaders(c.Request))
if err != nil {
return
}
passthroughSessionID := uuid.NewString()
clientRemoteAddr := ""
// NOTE(review): c and c.Request were already dereferenced by the Upgrade call
// above, so this nil check cannot protect that earlier use.
if c != nil && c.Request != nil {
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
}
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr)
// wsTerminateErr records why the session ended; wsBodyLog accumulates every
// request/response/disconnect event and is handed to setWebsocketRequestBody on exit.
var wsTerminateErr error
var wsBodyLog strings.Builder
defer func() {
// The log line for the error case is commented out, so a session that ends
// with an error currently logs nothing here.
if wsTerminateErr != nil {
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
} else {
log.Infof("responses websocket: session closing id=%s", passthroughSessionID)
}
if h != nil && h.AuthManager != nil {
h.AuthManager.CloseExecutionSession(passthroughSessionID)
log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID)
}
setWebsocketRequestBody(c, wsBodyLog.String())
if errClose := conn.Close(); errClose != nil {
log.Warnf("responses websocket: close connection error: %v", errClose)
}
}()
// Conversation state carried across messages on this connection.
var lastRequest []byte
lastResponseOutput := []byte("[]")
pinnedAuthID := ""
for {
msgType, payload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
// Any read error ends the session; only expected close codes are logged.
wsTerminateErr = errReadMessage
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error()))
if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage)
} else {
// log.Warnf("responses websocket: read message failed id=%s error=%v", passthroughSessionID, errReadMessage)
}
return
}
// Only text and binary frames carry requests; everything else is ignored.
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
continue
}
// log.Infof(
// "responses websocket: downstream_in id=%s type=%d event=%s payload=%s",
// passthroughSessionID,
// msgType,
// websocketPayloadEventType(payload),
// websocketPayloadPreview(payload),
// )
appendWebsocketEvent(&wsBodyLog, "request", payload)
// With nil attributes and metadata this evaluates to false, so incremental
// input is only allowed once a pinned auth advertises websocket support.
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
}
}
var requestJSON []byte
var updatedLastRequest []byte
var errMsg *interfaces.ErrorMessage
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithMode(
payload,
lastRequest,
lastResponseOutput,
allowIncrementalInputWithPreviousResponseID,
)
if errMsg != nil {
// A malformed request is reported to the client but does not end the
// session, unless the error itself cannot be written.
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
appendWebsocketEvent(&wsBodyLog, "response", errorPayload)
log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
passthroughSessionID,
websocket.TextMessage,
websocketPayloadEventType(errorPayload),
websocketPayloadPreview(errorPayload),
)
if errWrite != nil {
log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
passthroughSessionID,
websocketPayloadEventType(errorPayload),
errWrite,
)
return
}
continue
}
lastRequest = updatedLastRequest
modelName := gjson.GetBytes(requestJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx)
cliCtx = handlers.WithExecutionSessionID(cliCtx, passthroughSessionID)
// Reuse the previously selected auth when known; otherwise register a
// callback that pins whichever auth gets selected for this request.
// NOTE(review): pinnedAuthID is written from the callback; whether it can run
// on another goroutine is not visible here - confirm there is no data race.
if pinnedAuthID != "" {
cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID)
} else {
cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) {
pinnedAuthID = strings.TrimSpace(authID)
})
}
// Upstream response headers (second return value) are not used on this path.
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
if errForward != nil {
wsTerminateErr = errForward
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
return
}
// Remember the completed output so the next request can be merged with it.
lastResponseOutput = completedOutput
}
}
// websocketUpgradeHeaders builds the response headers sent with the websocket
// upgrade. The result is never nil; it carries the trimmed turn-state header
// from req when the client supplied a non-blank one, and is empty otherwise.
func websocketUpgradeHeaders(req *http.Request) http.Header {
	upgradeHeaders := make(http.Header)
	if req != nil {
		// Echo the sticky turn-state so it survives reconnects by the client.
		if state := strings.TrimSpace(req.Header.Get(wsTurnStateHeader)); state != "" {
			upgradeHeaders.Set(wsTurnStateHeader, state)
		}
	}
	return upgradeHeaders
}
// normalizeResponsesWebsocketRequest normalizes a websocket request with
// incremental input (previous_response_id) handling enabled.
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
	const allowIncrementalInput = true
	return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInput)
}
// normalizeResponsesWebsocketRequestWithMode dispatches a raw websocket message
// by its "type" field. The first response.create on a connection (no previous
// request recorded) is normalized as a fresh request; a later response.create
// and every response.append are normalized against the previous request and
// response output. Any other type yields a 400 error and leaves lastRequest
// unchanged. It returns the upstream request body, the value to remember as the
// new last request, and an optional error.
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
	kind := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
	isCreate := kind == wsRequestTypeCreate
	if !isCreate && kind != wsRequestTypeAppend {
		return nil, lastRequest, &interfaces.ErrorMessage{
			StatusCode: http.StatusBadRequest,
			Error:      fmt.Errorf("unsupported websocket request type: %s", kind),
		}
	}
	if isCreate && len(lastRequest) == 0 {
		return normalizeResponseCreateRequest(rawJSON)
	}
	return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
}
// normalizeResponseCreateRequest prepares the first request of a websocket
// session: the "type" field is removed, "stream" is forced to true and a missing
// "input" defaults to an empty array. A request without a non-blank "model" is
// rejected with a 400. On success it returns the request body plus an
// independent copy to keep as the last request.
func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
	request, errStrip := sjson.DeleteBytes(rawJSON, "type")
	if errStrip != nil {
		// Fall back to an untouched copy when the field cannot be removed.
		request = bytes.Clone(rawJSON)
	}
	request, _ = sjson.SetBytes(request, "stream", true)
	if hasInput := gjson.GetBytes(request, "input").Exists(); !hasInput {
		request, _ = sjson.SetRawBytes(request, "input", []byte("[]"))
	}

	if model := strings.TrimSpace(gjson.GetBytes(request, "model").String()); model == "" {
		return nil, nil, &interfaces.ErrorMessage{
			StatusCode: http.StatusBadRequest,
			Error:      fmt.Errorf("missing model in response.create request"),
		}
	}

	snapshot := bytes.Clone(request)
	return request, snapshot, nil
}
// normalizeResponseSubsequentRequest prepares every request after the first one
// on a websocket session.
//
// The request must follow a recorded previous request and must carry an array
// "input"; otherwise a 400 error is returned and lastRequest is handed back
// unchanged. Two modes exist:
//
//   - Incremental: when allowed and the request has a non-blank
//     previous_response_id, the payload is forwarded as-is (minus "type").
//   - Transcript: otherwise the new input is appended to the previous request's
//     input followed by the previous response output, and previous_response_id
//     is removed.
//
// In both modes "model" and "instructions" are inherited from the previous
// request when absent and "stream" is forced to true. On success it returns the
// request body plus an independent copy to keep as the new last request.
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
	// reject builds the common 400 result that keeps the session state intact.
	reject := func(err error) ([]byte, []byte, *interfaces.ErrorMessage) {
		return nil, lastRequest, &interfaces.ErrorMessage{
			StatusCode: http.StatusBadRequest,
			Error:      err,
		}
	}

	if len(lastRequest) == 0 {
		return reject(fmt.Errorf("websocket request received before response.create"))
	}
	nextInput := gjson.GetBytes(rawJSON, "input")
	if !nextInput.Exists() || !nextInput.IsArray() {
		return reject(fmt.Errorf("websocket request requires array field: input"))
	}

	// Websocket v2 mode uses response.create with previous_response_id + incremental input.
	// Do not expand it into a full input transcript; upstream expects the incremental payload.
	if allowIncrementalInputWithPreviousResponseID {
		if prev := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()); prev != "" {
			normalized := inheritWebsocketRequestDefaults(stripWebsocketRequestType(rawJSON), lastRequest)
			return normalized, bytes.Clone(normalized), nil
		}
	}

	// Transcript mode: previous input + previous output + new input.
	existingInput := gjson.GetBytes(lastRequest, "input")
	mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
	if errMerge != nil {
		return reject(fmt.Errorf("invalid previous response output: %w", errMerge))
	}
	mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw)
	if errMerge != nil {
		return reject(fmt.Errorf("invalid request input: %w", errMerge))
	}

	normalized := stripWebsocketRequestType(rawJSON)
	normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id")
	normalized, errSet := sjson.SetRawBytes(normalized, "input", []byte(mergedInput))
	if errSet != nil {
		return reject(fmt.Errorf("failed to merge websocket input: %w", errSet))
	}
	normalized = inheritWebsocketRequestDefaults(normalized, lastRequest)
	return normalized, bytes.Clone(normalized), nil
}

// stripWebsocketRequestType returns rawJSON without its "type" field, or an
// untouched copy of rawJSON when the field cannot be removed.
func stripWebsocketRequestType(rawJSON []byte) []byte {
	normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
	if errDelete != nil {
		return bytes.Clone(rawJSON)
	}
	return normalized
}

// inheritWebsocketRequestDefaults fills "model" and "instructions" from
// lastRequest when normalized lacks them, then forces "stream" to true.
func inheritWebsocketRequestDefaults(normalized []byte, lastRequest []byte) []byte {
	if !gjson.GetBytes(normalized, "model").Exists() {
		modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
		if modelName != "" {
			normalized, _ = sjson.SetBytes(normalized, "model", modelName)
		}
	}
	if !gjson.GetBytes(normalized, "instructions").Exists() {
		instructions := gjson.GetBytes(lastRequest, "instructions")
		if instructions.Exists() {
			normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
		}
	}
	normalized, _ = sjson.SetBytes(normalized, "stream", true)
	return normalized
}
// websocketUpstreamSupportsIncrementalInput reports whether an auth advertises
// websocket support through its "websockets" flag.
//
// A parseable boolean in attributes wins. Otherwise the metadata entry is
// consulted, which may be a bool or a boolean-like string. Missing, nil,
// unparseable or differently typed values all yield false.
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
	// Reading from a nil map is safe, so no length check is needed.
	if attr := strings.TrimSpace(attributes["websockets"]); attr != "" {
		if enabled, errParse := strconv.ParseBool(attr); errParse == nil {
			return enabled
		}
	}

	switch flag := metadata["websockets"].(type) {
	case bool:
		return flag
	case string:
		enabled, errParse := strconv.ParseBool(strings.TrimSpace(flag))
		return errParse == nil && enabled
	}
	return false
}
// mergeJSONArrayRaw concatenates two raw JSON arrays and returns the combined
// array as a JSON string, with the elements of existingRaw first.
//
// Blank inputs are treated as empty arrays. A JSON null decodes to no elements.
// The result is always a JSON array: merging two element-less inputs yields
// "[]" rather than "null". An error is returned when either input is not a JSON
// array (or null), or when re-encoding fails.
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
	existingRaw = strings.TrimSpace(existingRaw)
	appendRaw = strings.TrimSpace(appendRaw)
	if existingRaw == "" {
		existingRaw = "[]"
	}
	if appendRaw == "" {
		appendRaw = "[]"
	}

	var existing []json.RawMessage
	if err := json.Unmarshal([]byte(existingRaw), &existing); err != nil {
		return "", err
	}
	var appendItems []json.RawMessage
	if err := json.Unmarshal([]byte(appendRaw), &appendItems); err != nil {
		return "", err
	}

	// Build into a fresh non-nil slice: a nil slice would marshal as "null",
	// which happens when the existing side decodes from JSON null.
	merged := make([]json.RawMessage, 0, len(existing)+len(appendItems))
	merged = append(merged, existing...)
	merged = append(merged, appendItems...)

	out, err := json.Marshal(merged)
	if err != nil {
		return "", err
	}
	return string(out), nil
}
// normalizeJSONArrayRaw returns raw, trimmed of surrounding whitespace, when it
// parses as a JSON array; every other input (blank or non-array) becomes "[]".
func normalizeJSONArrayRaw(raw []byte) string {
	candidate := strings.TrimSpace(string(raw))
	if candidate != "" {
		if parsed := gjson.Parse(candidate); parsed.Type == gjson.JSON && parsed.IsArray() {
			return candidate
		}
	}
	return "[]"
}
// forwardResponsesWebsocket relays one upstream streaming response to the
// websocket client until the stream ends, an error arrives, or the request
// context is cancelled.
//
// Each data chunk is split into JSON payloads that are written as text
// messages; a response.completed payload is renamed to response.done and its
// response.output array is captured. Upstream errors, and a data stream that
// closes before completion, are reported to the client as error events.
//
// It returns the captured output ("[]" when none was seen) and a non-nil error
// only when the client context is done or a websocket write fails; an upstream
// error that was delivered to the client successfully returns a nil error so
// the session can continue. cancel is invoked on every return path.
func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
c *gin.Context,
conn *websocket.Conn,
cancel handlers.APIHandlerCancelFunc,
data <-chan []byte,
errs <-chan *interfaces.ErrorMessage,
wsBodyLog *strings.Builder,
sessionID string,
) ([]byte, error) {
completed := false
completedOutput := []byte("[]")
for {
// NOTE(review): if the producer closes data while an error is still queued on
// errs, select may pick either ready case, so the generic "stream closed"
// error could be reported instead of the upstream one - confirm producer ordering.
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return completedOutput, c.Request.Context().Err()
case errMsg, ok := <-errs:
if !ok {
// A closed error channel is disabled (nil never becomes ready) so the
// loop keeps draining data.
errs = nil
continue
}
if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
sessionID,
websocket.TextMessage,
websocketPayloadEventType(errorPayload),
websocketPayloadPreview(errorPayload),
)
if errWrite != nil {
// log.Warnf(
// "responses websocket: downstream_out write failed id=%s event=%s error=%v",
// sessionID,
// websocketPayloadEventType(errorPayload),
// errWrite,
// )
cancel(errMsg.Error)
return completedOutput, errWrite
}
}
// Any value received on errs ends this response, whether or not it carried an error.
if errMsg != nil {
cancel(errMsg.Error)
} else {
cancel(nil)
}
return completedOutput, nil
case chunk, ok := <-data:
if !ok {
if !completed {
// The stream ended without a response.completed event: report it to
// the client as an error event with a 408 status.
errMsg := &interfaces.ErrorMessage{
StatusCode: http.StatusRequestTimeout,
Error: fmt.Errorf("stream closed before response.completed"),
}
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
sessionID,
websocket.TextMessage,
websocketPayloadEventType(errorPayload),
websocketPayloadPreview(errorPayload),
)
if errWrite != nil {
log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
sessionID,
websocketPayloadEventType(errorPayload),
errWrite,
)
cancel(errMsg.Error)
return completedOutput, errWrite
}
cancel(errMsg.Error)
return completedOutput, nil
}
cancel(nil)
return completedOutput, nil
}
// One chunk may carry several SSE lines; forward each JSON payload separately.
payloads := websocketJSONPayloadsFromChunk(chunk)
for i := range payloads {
eventType := gjson.GetBytes(payloads[i], "type").String()
if eventType == wsEventTypeCompleted {
// log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone)
payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)
completed = true
completedOutput = responseCompletedOutputFromPayload(payloads[i])
}
markAPIResponseTimestamp(c)
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
// log.Infof(
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
// sessionID,
// websocket.TextMessage,
// websocketPayloadEventType(payloads[i]),
// websocketPayloadPreview(payloads[i]),
// )
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
sessionID,
websocketPayloadEventType(payloads[i]),
errWrite,
)
cancel(errWrite)
return completedOutput, errWrite
}
}
}
}
}
// responseCompletedOutputFromPayload extracts the "response.output" array from
// a completion payload as an independent byte slice. It returns "[]" when the
// field is missing or is not an array.
func responseCompletedOutputFromPayload(payload []byte) []byte {
	field := gjson.GetBytes(payload, "response.output")
	if !field.Exists() || !field.IsArray() {
		return []byte("[]")
	}
	return bytes.Clone([]byte(field.Raw))
}
// websocketJSONPayloadsFromChunk extracts the JSON payloads carried by an SSE
// style chunk.
//
// The chunk is first examined line by line: blank lines and "event:" lines are
// skipped, a leading "data:" prefix is stripped, and the done marker and any
// non-JSON remainder are dropped. When no line yields a payload, the whole
// chunk (minus an optional "data:" prefix) is tried as a single JSON document,
// which covers JSON spread over several lines. Returned slices are copies and
// do not alias chunk.
func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte {
	var (
		dataPrefix  = []byte("data:")
		eventPrefix = []byte("event:")
		doneMarker  = []byte(wsDoneMarker)
	)

	payloads := make([][]byte, 0, 2)
	for _, rawLine := range bytes.Split(chunk, []byte("\n")) {
		candidate := bytes.TrimSpace(rawLine)
		if len(candidate) == 0 || bytes.HasPrefix(candidate, eventPrefix) {
			continue
		}
		if bytes.HasPrefix(candidate, dataPrefix) {
			candidate = bytes.TrimSpace(candidate[len(dataPrefix):])
		}
		if len(candidate) == 0 || bytes.Equal(candidate, doneMarker) || !json.Valid(candidate) {
			continue
		}
		payloads = append(payloads, bytes.Clone(candidate))
	}
	if len(payloads) > 0 {
		return payloads
	}

	// Fallback: treat the entire chunk as one (possibly multi-line) document.
	whole := bytes.TrimSpace(chunk)
	if bytes.HasPrefix(whole, dataPrefix) {
		whole = bytes.TrimSpace(whole[len(dataPrefix):])
	}
	if len(whole) == 0 || bytes.Equal(whole, doneMarker) || !json.Valid(whole) {
		return payloads
	}
	return append(payloads, bytes.Clone(whole))
}
// writeResponsesWebsocketError sends an error event to the websocket client and
// returns the bytes that were written.
//
// The status defaults to 500 and is replaced by errMsg.StatusCode when that is
// positive. The message defaults to the HTTP status text and is replaced by
// errMsg.Error when it is non-blank. The first value of every errMsg.Addon
// entry is forwarded under "headers". The "error" field is taken from the body
// built by handlers.BuildErrorResponseBody (its inner "error" member when
// present) and falls back to a generic server_error object.
//
// It returns (nil, err) when the payload cannot be marshalled; otherwise the
// marshalled payload together with the result of conn.WriteMessage.
func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) {
	statusCode := http.StatusInternalServerError
	if errMsg != nil && errMsg.StatusCode > 0 {
		statusCode = errMsg.StatusCode
	}
	message := http.StatusText(statusCode)
	if errMsg != nil && errMsg.Error != nil {
		if text := errMsg.Error.Error(); strings.TrimSpace(text) != "" {
			message = text
		}
	}

	event := map[string]any{
		"type":   wsEventTypeError,
		"status": statusCode,
	}

	if errMsg != nil {
		forwarded := map[string]any{}
		for name, values := range errMsg.Addon {
			if len(values) > 0 {
				forwarded[name] = values[0]
			}
		}
		if len(forwarded) > 0 {
			event["headers"] = forwarded
		}
	}

	if raw := handlers.BuildErrorResponseBody(statusCode, message); len(raw) > 0 && json.Valid(raw) {
		var parsed map[string]any
		if json.Unmarshal(raw, &parsed) == nil {
			if detail, found := parsed["error"]; found {
				event["error"] = detail
			} else {
				event["error"] = parsed
			}
		}
	}
	if _, hasError := event["error"]; !hasError {
		event["error"] = map[string]any{
			"type":    "server_error",
			"message": message,
		}
	}

	encoded, errMarshal := json.Marshal(event)
	if errMarshal != nil {
		return nil, errMarshal
	}
	return encoded, conn.WriteMessage(websocket.TextMessage, encoded)
}
// appendWebsocketEvent adds one event record to the websocket body log.
//
// A record is the line "websocket.<eventType>" followed by the
// whitespace-trimmed payload and a trailing newline; records are separated by
// a blank line. A nil builder or a payload that is empty after trimming is
// ignored.
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
	if builder == nil {
		return
	}
	body := bytes.TrimSpace(payload)
	if len(body) == 0 {
		return
	}

	separator := ""
	if builder.Len() > 0 {
		separator = "\n"
	}
	builder.WriteString(separator + "websocket." + eventType + "\n")
	builder.Write(body)
	builder.WriteByte('\n')
}
// websocketPayloadEventType returns the trimmed "type" field of a JSON payload,
// or "-" when the field is missing or blank.
func websocketPayloadEventType(payload []byte) string {
	if name := strings.TrimSpace(gjson.GetBytes(payload, "type").String()); name != "" {
		return name
	}
	return "-"
}
// websocketPayloadPreview renders a single-line preview of payload for logs.
//
// Surrounding whitespace is removed, newlines and carriage returns are shown
// as the escapes \n and \r, and payloads longer than wsPayloadLogMaxSize bytes
// are cut at that size with a "...(truncated,total=N)" suffix carrying the
// trimmed length. A payload that is empty after trimming yields "<empty>".
func websocketPayloadPreview(payload []byte) string {
	body := bytes.TrimSpace(payload)
	total := len(body)
	if total == 0 {
		return "<empty>"
	}

	truncated := total > wsPayloadLogMaxSize
	if truncated {
		body = body[:wsPayloadLogMaxSize]
	}
	escaped := strings.NewReplacer("\n", "\\n", "\r", "\\r").Replace(string(body))
	if !truncated {
		return escaped
	}
	return fmt.Sprintf("%s...(truncated,total=%d)", escaped, total)
}
// setWebsocketRequestBody stores the whitespace-trimmed body on the gin context
// under wsRequestBodyKey as a []byte. A nil context or a blank body is ignored.
func setWebsocketRequestBody(c *gin.Context, body string) {
	if c == nil {
		return
	}
	if cleaned := strings.TrimSpace(body); cleaned != "" {
		c.Set(wsRequestBodyKey, []byte(cleaned))
	}
}
// markAPIResponseTimestamp records the current time on the gin context under
// "API_RESPONSE_TIMESTAMP" the first time it is called; an existing value is
// left untouched. A nil context is ignored.
func markAPIResponseTimestamp(c *gin.Context) {
	const timestampKey = "API_RESPONSE_TIMESTAMP"
	if c == nil {
		return
	}
	if _, recorded := c.Get(timestampKey); !recorded {
		c.Set(timestampKey, time.Now())
	}
}

View File

@@ -0,0 +1,249 @@
package openai
import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
// TestNormalizeResponsesWebsocketRequestCreate verifies that an initial
// response.create message loses its "type" field, is forced to stream=true,
// keeps its model, and is returned as the request snapshot.
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
	request := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)

	normalized, snapshot, errMsg := normalizeResponsesWebsocketRequest(request, nil, nil)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}
	field := func(path string) gjson.Result {
		return gjson.GetBytes(normalized, path)
	}

	if field("type").Exists() {
		t.Fatalf("normalized create request must not include type field")
	}
	if !field("stream").Bool() {
		t.Fatalf("normalized create request must force stream=true")
	}
	if model := field("model").String(); model != "test-model" {
		t.Fatalf("unexpected model: %s", model)
	}
	if !bytes.Equal(snapshot, normalized) {
		t.Fatalf("last request snapshot should match normalized request")
	}
}
// TestNormalizeResponsesWebsocketRequestCreateWithHistory verifies that a
// follow-up response.create inherits the model from the previous request and
// that its input is the previous input, then the previous response output,
// then the new items.
func TestNormalizeResponsesWebsocketRequestCreateWithHistory(t *testing.T) {
	previousRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
	previousOutput := []byte(`[
{"type":"function_call","id":"fc-1","call_id":"call-1"},
{"type":"message","id":"assistant-1"}
]`)
	request := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)

	normalized, snapshot, errMsg := normalizeResponsesWebsocketRequest(request, previousRequest, previousOutput)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}
	if gjson.GetBytes(normalized, "type").Exists() {
		t.Fatalf("normalized subsequent create request must not include type field")
	}
	if model := gjson.GetBytes(normalized, "model").String(); model != "test-model" {
		t.Fatalf("unexpected model: %s", model)
	}

	wantIDs := []string{"msg-1", "fc-1", "assistant-1", "tool-out-1"}
	items := gjson.GetBytes(normalized, "input").Array()
	if len(items) != len(wantIDs) {
		t.Fatalf("merged input len = %d, want 4", len(items))
	}
	for i, wantID := range wantIDs {
		if items[i].Get("id").String() != wantID {
			t.Fatalf("unexpected merged input order")
		}
	}
	if !bytes.Equal(snapshot, normalized) {
		t.Fatalf("next request snapshot should match normalized request")
	}
}
// TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental
// verifies that, with incremental mode enabled, previous_response_id is kept,
// only the new input item is sent, and model and instructions are carried over
// from the previous request.
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental(t *testing.T) {
	previousRequest := []byte(`{"model":"test-model","stream":true,"instructions":"be helpful","input":[{"type":"message","id":"msg-1"}]}`)
	previousOutput := []byte(`[
{"type":"function_call","id":"fc-1","call_id":"call-1"},
{"type":"message","id":"assistant-1"}
]`)
	request := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)

	normalized, snapshot, errMsg := normalizeResponsesWebsocketRequestWithMode(request, previousRequest, previousOutput, true)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}
	field := func(path string) gjson.Result {
		return gjson.GetBytes(normalized, path)
	}

	if field("type").Exists() {
		t.Fatalf("normalized request must not include type field")
	}
	if field("previous_response_id").String() != "resp-1" {
		t.Fatalf("previous_response_id must be preserved in incremental mode")
	}
	items := field("input").Array()
	if len(items) != 1 {
		t.Fatalf("incremental input len = %d, want 1", len(items))
	}
	if id := items[0].Get("id").String(); id != "tool-out-1" {
		t.Fatalf("unexpected incremental input item id: %s", id)
	}
	if model := field("model").String(); model != "test-model" {
		t.Fatalf("unexpected model: %s", model)
	}
	if instructions := field("instructions").String(); instructions != "be helpful" {
		t.Fatalf("unexpected instructions: %s", instructions)
	}
	if !bytes.Equal(snapshot, normalized) {
		t.Fatalf("next request snapshot should match normalized request")
	}
}
// TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncrementalDisabled
// verifies that, with incremental mode disabled, previous_response_id is
// removed and the input is merged from the previous input, the previous
// response output and the new items, in that order.
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncrementalDisabled(t *testing.T) {
	previousRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
	previousOutput := []byte(`[
{"type":"function_call","id":"fc-1","call_id":"call-1"},
{"type":"message","id":"assistant-1"}
]`)
	request := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)

	normalized, snapshot, errMsg := normalizeResponsesWebsocketRequestWithMode(request, previousRequest, previousOutput, false)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}
	if gjson.GetBytes(normalized, "previous_response_id").Exists() {
		t.Fatalf("previous_response_id must be removed when incremental mode is disabled")
	}

	wantIDs := []string{"msg-1", "fc-1", "assistant-1", "tool-out-1"}
	items := gjson.GetBytes(normalized, "input").Array()
	if len(items) != len(wantIDs) {
		t.Fatalf("merged input len = %d, want 4", len(items))
	}
	for i, wantID := range wantIDs {
		if items[i].Get("id").String() != wantID {
			t.Fatalf("unexpected merged input order")
		}
	}
	if !bytes.Equal(snapshot, normalized) {
		t.Fatalf("next request snapshot should match normalized request")
	}
}
// TestNormalizeResponsesWebsocketRequestAppend verifies that response.append
// yields the previous input, then the previous response output, then the
// appended items, and that the result becomes the new request snapshot.
func TestNormalizeResponsesWebsocketRequestAppend(t *testing.T) {
	previousRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
	previousOutput := []byte(`[
{"type":"message","id":"assistant-1"},
{"type":"function_call_output","id":"tool-out-1"}
]`)
	request := []byte(`{"type":"response.append","input":[{"type":"message","id":"msg-2"},{"type":"message","id":"msg-3"}]}`)

	normalized, snapshot, errMsg := normalizeResponsesWebsocketRequest(request, previousRequest, previousOutput)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}

	wantIDs := []string{"msg-1", "assistant-1", "tool-out-1", "msg-2", "msg-3"}
	items := gjson.GetBytes(normalized, "input").Array()
	if len(items) != len(wantIDs) {
		t.Fatalf("merged input len = %d, want 5", len(items))
	}
	for i, wantID := range wantIDs {
		if items[i].Get("id").String() != wantID {
			t.Fatalf("unexpected merged input order")
		}
	}
	if !bytes.Equal(snapshot, normalized) {
		t.Fatalf("next request snapshot should match normalized append request")
	}
}
// TestNormalizeResponsesWebsocketRequestAppendWithoutCreate verifies that
// response.append without any previous request is rejected with HTTP 400.
func TestNormalizeResponsesWebsocketRequestAppendWithoutCreate(t *testing.T) {
	request := []byte(`{"type":"response.append","input":[]}`)

	_, _, errMsg := normalizeResponsesWebsocketRequest(request, nil, nil)
	if errMsg == nil {
		t.Fatalf("expected error for append without previous request")
	}
	if got := errMsg.StatusCode; got != http.StatusBadRequest {
		t.Fatalf("status = %d, want %d", got, http.StatusBadRequest)
	}
}
// TestWebsocketJSONPayloadsFromChunk verifies that an SSE-framed chunk yields
// only its JSON data line: the event line and the [DONE] line are dropped.
func TestWebsocketJSONPayloadsFromChunk(t *testing.T) {
	sse := []byte("event: response.created\n\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\ndata: [DONE]\n")

	got := websocketJSONPayloadsFromChunk(sse)
	if len(got) != 1 {
		t.Fatalf("payloads len = %d, want 1", len(got))
	}
	if eventType := gjson.GetBytes(got[0], "type").String(); eventType != "response.created" {
		t.Fatalf("unexpected payload type: %s", eventType)
	}
}
// TestWebsocketJSONPayloadsFromPlainJSONChunk verifies that a chunk holding a
// bare JSON document (no SSE framing) is returned as a single payload.
func TestWebsocketJSONPayloadsFromPlainJSONChunk(t *testing.T) {
	plain := []byte(`{"type":"response.completed","response":{"id":"resp-1"}}`)

	got := websocketJSONPayloadsFromChunk(plain)
	if len(got) != 1 {
		t.Fatalf("payloads len = %d, want 1", len(got))
	}
	if eventType := gjson.GetBytes(got[0], "type").String(); eventType != "response.completed" {
		t.Fatalf("unexpected payload type: %s", eventType)
	}
}
func TestResponseCompletedOutputFromPayload(t *testing.T) {
payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"message","id":"out-1"}]}}`)
output := responseCompletedOutputFromPayload(payload)
items := gjson.ParseBytes(output).Array()
if len(items) != 1 {
t.Fatalf("output len = %d, want 1", len(items))
}
if items[0].Get("id").String() != "out-1" {
t.Fatalf("unexpected output id: %s", items[0].Get("id").String())
}
}
// TestAppendWebsocketEvent verifies that request and response events are
// written as "websocket.<kind>" headers followed by the trimmed payload.
func TestAppendWebsocketEvent(t *testing.T) {
	var log strings.Builder
	appendWebsocketEvent(&log, "request", []byte(" {\"type\":\"response.create\"}\n"))
	appendWebsocketEvent(&log, "response", []byte("{\"type\":\"response.created\"}"))

	body := log.String()
	if !strings.Contains(body, "websocket.request\n{\"type\":\"response.create\"}\n") {
		t.Fatalf("request event not found in body: %s", body)
	}
	if !strings.Contains(body, "websocket.response\n{\"type\":\"response.created\"}\n") {
		t.Fatalf("response event not found in body: %s", body)
	}
}
// TestSetWebsocketRequestBody verifies that a blank body leaves the context
// untouched and that a non-blank body is stored as a []byte under
// wsRequestBodyKey.
func TestSetWebsocketRequestBody(t *testing.T) {
	gin.SetMode(gin.TestMode)
	c, _ := gin.CreateTestContext(httptest.NewRecorder())

	setWebsocketRequestBody(c, " \n ")
	if _, found := c.Get(wsRequestBodyKey); found {
		t.Fatalf("request body key should not be set for empty body")
	}

	setWebsocketRequestBody(c, "event body")
	stored, found := c.Get(wsRequestBodyKey)
	if !found {
		t.Fatalf("request body key not set")
	}
	storedBytes, isBytes := stored.([]byte)
	if !isBytes {
		t.Fatalf("request body key type mismatch")
	}
	if got := string(storedBytes); got != "event body" {
		t.Fatalf("request body = %q, want %q", got, "event body")
	}
}

121
sdk/auth/kilo.go Normal file
View File

@@ -0,0 +1,121 @@
package auth
import (
"context"
"fmt"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// KiloAuthenticator implements the login flow for Kilo AI accounts.
// It carries no state; every Login call creates its own Kilo auth client.
type KiloAuthenticator struct{}
// NewKiloAuthenticator returns a ready-to-use, stateless Kilo authenticator.
func NewKiloAuthenticator() *KiloAuthenticator {
	return new(KiloAuthenticator)
}
// Provider returns the provider identifier "kilo"; Login stores it in the
// Provider field of the Auth it creates.
func (a *KiloAuthenticator) Provider() string {
return "kilo"
}
// RefreshLead always returns nil.
// NOTE(review): presumably nil tells the caller that Kilo credentials need no
// proactive refresh lead time — confirm against the code that consumes it.
func (a *KiloAuthenticator) RefreshLead() *time.Duration {
return nil
}
// Login runs the interactive Kilo device-flow login and returns the resulting
// auth record.
//
// Steps: start the device flow and print the verification URL and code, poll
// until a token is issued, fetch the profile, pick an organization, then fetch
// the account defaults. With several organizations the user is asked through
// opts.Prompt; an invalid answer or a missing prompt falls back to the first
// organization. A failure to fetch defaults is only reported as a warning.
//
// cfg must be non-nil. A nil ctx is replaced by context.Background() and a nil
// opts by an empty LoginOptions. Errors from the device flow, polling and the
// profile request are wrapped; a prompt error is returned unchanged.
func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
	if cfg == nil {
		return nil, fmt.Errorf("cliproxy auth: configuration is required")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if opts == nil {
		opts = &LoginOptions{}
	}

	client := kilo.NewKiloAuth()

	fmt.Println("Initiating Kilo device authentication...")
	flow, errFlow := client.InitiateDeviceFlow(ctx)
	if errFlow != nil {
		return nil, fmt.Errorf("failed to initiate device flow: %w", errFlow)
	}
	fmt.Printf("Please visit: %s\n", flow.VerificationURL)
	fmt.Printf("And enter code: %s\n", flow.Code)
	fmt.Println("Waiting for authorization...")

	status, errPoll := client.PollForToken(ctx, flow.Code)
	if errPoll != nil {
		return nil, fmt.Errorf("authentication failed: %w", errPoll)
	}
	fmt.Printf("Authentication successful for %s\n", status.UserEmail)

	profile, errProfile := client.GetProfile(ctx, status.Token)
	if errProfile != nil {
		return nil, fmt.Errorf("failed to fetch profile: %w", errProfile)
	}

	// Resolve the organization: none -> empty ID, one -> that one, several -> ask.
	orgs := profile.Orgs
	var orgID string
	switch {
	case len(orgs) == 1:
		orgID = orgs[0].ID
	case len(orgs) > 1:
		fmt.Println("Multiple organizations found. Please select one:")
		for i, org := range orgs {
			fmt.Printf("[%d] %s (%s)\n", i+1, org.Name, org.ID)
		}
		if opts.Prompt == nil {
			orgID = orgs[0].ID
			fmt.Printf("Non-interactive mode, defaulting to organization: %s\n", orgs[0].Name)
			break
		}
		answer, errPrompt := opts.Prompt("Enter the number of the organization: ")
		if errPrompt != nil {
			return nil, errPrompt
		}
		var choice int
		if _, errScan := fmt.Sscan(answer, &choice); errScan != nil || choice < 1 || choice > len(orgs) {
			orgID = orgs[0].ID
			fmt.Printf("Invalid choice, defaulting to %s\n", orgs[0].Name)
		} else {
			// The menu is 1-based.
			orgID = orgs[choice-1].ID
		}
	}

	defaults, errDefaults := client.GetDefaults(ctx, status.Token, orgID)
	if errDefaults != nil {
		fmt.Printf("Warning: failed to fetch defaults: %v\n", errDefaults)
		defaults = &kilo.Defaults{}
	}

	fileName := kilo.CredentialFileName(status.UserEmail)
	return &coreauth.Auth{
		ID:       fileName,
		Provider: a.Provider(),
		FileName: fileName,
		Storage: &kilo.KiloTokenStorage{
			Token:          status.Token,
			OrganizationID: orgID,
			Model:          defaults.Model,
			Email:          status.UserEmail,
			Type:           "kilo",
		},
		Metadata: map[string]any{
			"email":           status.UserEmail,
			"organization_id": orgID,
			"model":           defaults.Model,
		},
	}, nil
}

View File

@@ -30,8 +30,9 @@ type ProviderExecutor interface {
Identifier() string
// Execute handles non-streaming execution and returns the provider response payload.
Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
// ExecuteStream handles streaming execution and returns a channel of provider chunks.
ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error)
// ExecuteStream handles streaming execution and returns a StreamResult containing
// upstream headers and a channel of provider chunks.
ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error)
// Refresh attempts to refresh provider credentials and returns the updated auth state.
Refresh(ctx context.Context, auth *Auth) (*Auth, error)
// CountTokens returns the token count for the given request.
@@ -41,6 +42,17 @@ type ProviderExecutor interface {
HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error)
}
// ExecutionSessionCloser allows executors to release per-session runtime resources.
// It is optional: the Manager type-asserts each executor to this interface and
// silently skips executors that do not implement it.
type ExecutionSessionCloser interface {
// CloseExecutionSession releases the resources tied to sessionID. The Manager
// passes CloseAllExecutionSessionsID when an executor is replaced.
CloseExecutionSession(sessionID string)
}
const (
// CloseAllExecutionSessionsID asks an executor to release all active execution sessions.
// Executors that do not support this marker may ignore it.
// Manager.RegisterExecutor passes it to the executor being replaced when that
// executor implements ExecutionSessionCloser.
CloseAllExecutionSessionsID = "__all_execution_sessions__"
)
// RefreshEvaluator allows runtime state to override refresh decisions.
type RefreshEvaluator interface {
ShouldRefresh(now time.Time, auth *Auth) bool
@@ -389,9 +401,23 @@ func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
if executor == nil {
return
}
provider := strings.TrimSpace(executor.Identifier())
if provider == "" {
return
}
var replaced ProviderExecutor
m.mu.Lock()
defer m.mu.Unlock()
m.executors[executor.Identifier()] = executor
replaced = m.executors[provider]
m.executors[provider] = executor
m.mu.Unlock()
if replaced == nil || replaced == executor {
return
}
if closer, ok := replaced.(ExecutionSessionCloser); ok && closer != nil {
closer.CloseExecutionSession(CloseAllExecutionSessionsID)
}
}
// UnregisterExecutor removes the executor associated with the provider key.
@@ -533,7 +559,7 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
// ExecuteStream performs a streaming execution using the configured selector and executor.
// It supports multiple providers for the same model and round-robins the starting provider per model.
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
normalized := m.normalizeProviders(providers)
if len(normalized) == 0 {
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
@@ -543,9 +569,9 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
var lastErr error
for attempt := 0; ; attempt++ {
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
result, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
if errStream == nil {
return chunks, nil
return result, nil
}
lastErr = errStream
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, normalized, req.Model, maxWait)
@@ -581,6 +607,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
tried[auth.ID] = struct{}{}
execCtx := ctx
@@ -636,6 +663,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
tried[auth.ID] = struct{}{}
execCtx := ctx
@@ -672,7 +700,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
}
}
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
if len(providers) == 0 {
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
@@ -691,6 +719,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
tried[auth.ID] = struct{}{}
execCtx := ctx
@@ -702,7 +731,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return nil, errCtx
@@ -750,8 +779,11 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
}
}(execCtx, auth.Clone(), provider, chunks)
return out, nil
}(execCtx, auth.Clone(), provider, streamResult.Chunks)
return &cliproxyexecutor.StreamResult{
Headers: streamResult.Headers,
Chunks: out,
}, nil
}
}
@@ -794,6 +826,38 @@ func hasRequestedModelMetadata(meta map[string]any) bool {
}
}
// pinnedAuthIDFromMetadata returns the trimmed auth ID stored under
// cliproxyexecutor.PinnedAuthMetadataKey. The value may be a string or a
// []byte; a missing key, a nil value or any other type yields "".
func pinnedAuthIDFromMetadata(meta map[string]any) string {
	// Indexing a nil or empty map yields a nil interface, which matches no case.
	switch pinned := meta[cliproxyexecutor.PinnedAuthMetadataKey].(type) {
	case string:
		return strings.TrimSpace(pinned)
	case []byte:
		return strings.TrimSpace(string(pinned))
	}
	return ""
}
// publishSelectedAuthMetadata records the auth chosen for an execution in the
// request metadata and notifies an optional listener.
//
// The trimmed authID is stored under cliproxyexecutor.SelectedAuthMetadataKey.
// If meta holds a non-nil func(string) under
// cliproxyexecutor.SelectedAuthCallbackMetadataKey, it is called with the same
// ID. Nothing happens when meta is nil or authID is blank.
func publishSelectedAuthMetadata(meta map[string]any, authID string) {
	// Only a nil map cannot be written to. An empty but non-nil map is a valid
	// destination and must still receive the selected auth ID.
	if meta == nil {
		return
	}
	authID = strings.TrimSpace(authID)
	if authID == "" {
		return
	}
	meta[cliproxyexecutor.SelectedAuthMetadataKey] = authID
	// The nil check covers a typed-nil func stored in the map.
	if callback, ok := meta[cliproxyexecutor.SelectedAuthCallbackMetadataKey].(func(string)); ok && callback != nil {
		callback(authID)
	}
}
func rewriteModelForAuth(model string, auth *Auth) string {
if auth == nil || model == "" {
return model
@@ -1550,7 +1614,56 @@ func (m *Manager) GetByID(id string) (*Auth, bool) {
return auth.Clone(), true
}
// Executor returns the executor registered for provider.
//
// The key is trimmed and looked up as given; if that misses, its lower-cased
// form is tried. It reports false for a nil manager, a blank key, an unknown
// provider or a nil registered executor.
func (m *Manager) Executor(provider string) (ProviderExecutor, bool) {
	if m == nil {
		return nil, false
	}
	key := strings.TrimSpace(provider)
	if key == "" {
		return nil, false
	}

	m.mu.RLock()
	found, ok := m.executors[key]
	if !ok {
		if lowered := strings.ToLower(key); lowered != key {
			found, ok = m.executors[lowered]
		}
	}
	m.mu.RUnlock()

	if ok && found != nil {
		return found, true
	}
	return nil, false
}
// CloseExecutionSession forwards sessionID to every registered executor that
// implements ExecutionSessionCloser. A nil manager or a blank session ID is a
// no-op. The executors are collected under the read lock and called after it
// has been released.
func (m *Manager) CloseExecutionSession(sessionID string) {
	sessionID = strings.TrimSpace(sessionID)
	if m == nil || sessionID == "" {
		return
	}

	m.mu.RLock()
	closers := make([]ExecutionSessionCloser, 0, len(m.executors))
	for _, registered := range m.executors {
		if closer, ok := registered.(ExecutionSessionCloser); ok && closer != nil {
			closers = append(closers, closer)
		}
	}
	m.mu.RUnlock()

	for _, closer := range closers {
		closer.CloseExecutionSession(sessionID)
	}
}
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
m.mu.RLock()
executor, okExecutor := m.executors[provider]
if !okExecutor {
@@ -1571,6 +1684,9 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
if candidate.Provider != provider || candidate.Disabled {
continue
}
if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
continue
}
if _, used := tried[candidate.ID]; used {
continue
}
@@ -1606,6 +1722,8 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
}
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
providerSet := make(map[string]struct{}, len(providers))
for _, provider := range providers {
p := strings.TrimSpace(strings.ToLower(provider))
@@ -1633,6 +1751,9 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
if candidate == nil || candidate.Disabled {
continue
}
if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
continue
}
providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
if providerKey == "" {
continue

View File

@@ -0,0 +1,104 @@
package auth
import (
"context"
"net/http"
"sync"
"testing"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// replaceAwareExecutor is a ProviderExecutor test double that records every
// session ID passed to CloseExecutionSession. All other methods are stubs.
type replaceAwareExecutor struct {
	id string

	// mu guards closedSessionIDs.
	mu               sync.Mutex
	closedSessionIDs []string
}

// Identifier returns the provider key this executor was created with.
func (e *replaceAwareExecutor) Identifier() string {
	return e.id
}

// Execute is a stub that returns an empty response.
func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	var empty cliproxyexecutor.Response
	return empty, nil
}

// ExecuteStream is a stub that returns an already-closed chunk channel.
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	chunks := make(chan cliproxyexecutor.StreamChunk)
	close(chunks)
	result := &cliproxyexecutor.StreamResult{Chunks: chunks}
	return result, nil
}

// Refresh is a stub that returns auth unchanged.
func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

// CountTokens is a stub that returns an empty response.
func (e *replaceAwareExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	var empty cliproxyexecutor.Response
	return empty, nil
}

// HttpRequest is a stub that returns no response and no error.
func (e *replaceAwareExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
	return nil, nil
}

// CloseExecutionSession records sessionID so tests can inspect close calls.
func (e *replaceAwareExecutor) CloseExecutionSession(sessionID string) {
	e.mu.Lock()
	e.closedSessionIDs = append(e.closedSessionIDs, sessionID)
	e.mu.Unlock()
}

// ClosedSessionIDs returns a copy of the session IDs recorded so far.
func (e *replaceAwareExecutor) ClosedSessionIDs() []string {
	e.mu.Lock()
	snapshot := append(make([]string, 0, len(e.closedSessionIDs)), e.closedSessionIDs...)
	e.mu.Unlock()
	return snapshot
}
// TestManagerRegisterExecutorClosesReplacedExecutionSessions verifies that
// registering a second executor under the same provider key sends exactly one
// close-all request to the replaced executor and none to the new one.
func TestManagerRegisterExecutorClosesReplacedExecutionSessions(t *testing.T) {
	t.Parallel()

	manager := NewManager(nil, nil, nil)
	previous := &replaceAwareExecutor{id: "codex"}
	next := &replaceAwareExecutor{id: "codex"}
	manager.RegisterExecutor(previous)
	manager.RegisterExecutor(next)

	closed := previous.ClosedSessionIDs()
	if got := len(closed); got != 1 {
		t.Fatalf("expected replaced executor close calls = 1, got %d", got)
	}
	if marker := closed[0]; marker != CloseAllExecutionSessionsID {
		t.Fatalf("expected close marker %q, got %q", CloseAllExecutionSessionsID, marker)
	}
	if len(next.ClosedSessionIDs()) != 0 {
		t.Fatalf("expected current executor to stay open")
	}
}
// TestManagerExecutorReturnsRegisteredExecutor verifies that Executor finds a
// registered executor even when the key differs in case, and reports false
// for an unknown provider.
func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) {
	t.Parallel()

	manager := NewManager(nil, nil, nil)
	current := &replaceAwareExecutor{id: "codex"}
	manager.RegisterExecutor(current)

	resolved, found := manager.Executor("CODEX")
	if !found {
		t.Fatal("expected registered executor to be found")
	}
	concrete, isReplaceAware := resolved.(*replaceAwareExecutor)
	if !isReplaceAware {
		t.Fatalf("expected resolved executor type %T, got %T", current, resolved)
	}
	if concrete != current {
		t.Fatal("expected resolved executor to match registered executor")
	}

	if _, foundUnknown := manager.Executor("unknown"); foundUnknown {
		t.Fatal("expected unknown provider lookup to fail")
	}
}

View File

@@ -134,6 +134,62 @@ func canonicalModelKey(model string) string {
return modelName
}
// authWebsocketsEnabled reports whether the auth opts in to websockets.
//
// The "websockets" attribute wins when it parses as a boolean. Otherwise the
// "websockets" metadata entry is used, accepted as a bool or as a string that
// parses as a boolean. Anything else, including a nil auth, yields false.
func authWebsocketsEnabled(auth *Auth) bool {
	if auth == nil {
		return false
	}

	// Reading from a nil map is safe and yields the zero value.
	if attr := strings.TrimSpace(auth.Attributes["websockets"]); attr != "" {
		if enabled, errParse := strconv.ParseBool(attr); errParse == nil {
			return enabled
		}
	}

	switch flag := auth.Metadata["websockets"].(type) {
	case bool:
		return flag
	case string:
		if enabled, errParse := strconv.ParseBool(strings.TrimSpace(flag)); errParse == nil {
			return enabled
		}
	}
	return false
}
// preferCodexWebsocketAuths narrows the candidate list to websocket-enabled
// auths for codex requests that arrive over a downstream websocket.
//
// The list is returned unchanged when it is empty, when ctx is not marked as a
// downstream websocket request, when provider is not "codex" (compared
// case-insensitively after trimming), or when no candidate has websockets
// enabled.
func preferCodexWebsocketAuths(ctx context.Context, provider string, available []*Auth) []*Auth {
	switch {
	case len(available) == 0,
		!cliproxyexecutor.DownstreamWebsocket(ctx),
		!strings.EqualFold(strings.TrimSpace(provider), "codex"):
		return available
	}

	preferred := make([]*Auth, 0, len(available))
	for _, candidate := range available {
		if authWebsocketsEnabled(candidate) {
			preferred = append(preferred, candidate)
		}
	}
	if len(preferred) == 0 {
		return available
	}
	return preferred
}
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
available = make(map[int][]*Auth)
for i := 0; i < len(auths); i++ {
@@ -193,13 +249,13 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
// Pick selects the next available auth for the provider in a round-robin manner.
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = ctx
_ = opts
now := time.Now()
available, err := getAvailableAuths(auths, provider, model, now)
if err != nil {
return nil, err
}
available = preferCodexWebsocketAuths(ctx, provider, available)
key := provider + ":" + canonicalModelKey(model)
s.mu.Lock()
if s.cursors == nil {
@@ -226,13 +282,13 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
// Pick selects the first available auth for the provider in a deterministic manner.
//
// Candidates come from getAvailableAuths evaluated at the current time and are
// then narrowed by preferCodexWebsocketAuths. Any error from getAvailableAuths
// is returned unchanged. opts is currently unused.
//
// NOTE(review): indexing available[0] assumes getAvailableAuths returns an
// error rather than an empty slice — confirm.
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
	_ = opts
	available, err := getAvailableAuths(auths, provider, model, time.Now())
	if err != nil {
		return nil, err
	}
	available = preferCodexWebsocketAuths(ctx, provider, available)
	return available[0], nil
}

View File

@@ -213,6 +213,23 @@ func (a *Auth) DisableCoolingOverride() (bool, bool) {
return false, false
}
// ToolPrefixDisabled reports whether the proxy_ tool name prefix should be
// skipped for this auth. When true, tool names are sent to Anthropic unchanged.
//
// The flag is read from metadata key "tool_prefix_disabled" first and
// "tool-prefix-disabled" second; the first entry that parseBoolAny accepts
// decides. A nil auth, nil metadata or no parsable entry yields false.
func (a *Auth) ToolPrefixDisabled() bool {
	if a == nil || a.Metadata == nil {
		return false
	}
	keys := [...]string{"tool_prefix_disabled", "tool-prefix-disabled"}
	for i := range keys {
		raw, present := a.Metadata[keys[i]]
		if !present {
			continue
		}
		if disabled, valid := parseBoolAny(raw); valid {
			return disabled
		}
	}
	return false
}
// RequestRetryOverride returns the auth-file scoped request_retry override when present.
// The value is read from metadata key "request_retry" (or legacy "request-retry").
func (a *Auth) RequestRetryOverride() (int, bool) {

View File

@@ -0,0 +1,35 @@
package auth
import "testing"
// TestToolPrefixDisabled covers nil and empty auths plus both metadata key
// spellings with bool and string values.
func TestToolPrefixDisabled(t *testing.T) {
	cases := []struct {
		auth    *Auth
		want    bool
		failure string
	}{
		{nil, false, "nil auth should return false"},
		{&Auth{}, false, "empty auth should return false"},
		{&Auth{Metadata: map[string]any{"tool_prefix_disabled": true}}, true, "should return true when set to true"},
		{&Auth{Metadata: map[string]any{"tool_prefix_disabled": "true"}}, true, "should return true when set to string 'true'"},
		{&Auth{Metadata: map[string]any{"tool-prefix-disabled": true}}, true, "should return true with kebab-case key"},
		{&Auth{Metadata: map[string]any{"tool_prefix_disabled": false}}, false, "should return false when set to false"},
	}
	for _, tc := range cases {
		if tc.auth.ToolPrefixDisabled() != tc.want {
			t.Error(tc.failure)
		}
	}
}

View File

@@ -0,0 +1,23 @@
package executor
import "context"
// downstreamWebsocketContextKey is the unexported key type under which the
// downstream-websocket marker is stored in a context.
type downstreamWebsocketContextKey struct{}

// WithDownstreamWebsocket returns a context derived from ctx that is marked as
// belonging to a request received over a downstream websocket connection.
// A nil ctx is replaced by context.Background() before the marker is attached.
func WithDownstreamWebsocket(ctx context.Context) context.Context {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	return context.WithValue(parent, downstreamWebsocketContextKey{}, true)
}

// DownstreamWebsocket reports whether ctx carries the marker set by
// WithDownstreamWebsocket. It returns false for a nil ctx, for a context
// without the marker, and for a marker value that is not the boolean true.
func DownstreamWebsocket(ctx context.Context) bool {
	if ctx == nil {
		return false
	}
	if marked, isBool := ctx.Value(downstreamWebsocketContextKey{}).(bool); isBool {
		return marked
	}
	return false
}

View File

@@ -10,6 +10,17 @@ import (
// Well-known metadata keys.
const (
	// RequestedModelMetadataKey is the Options.Metadata key holding the model
	// name as requested by the client.
	RequestedModelMetadataKey = "requested_model"

	// PinnedAuthMetadataKey is the key used to lock execution to one specific
	// auth ID.
	PinnedAuthMetadataKey = "pinned_auth_id"

	// SelectedAuthMetadataKey is the key under which the auth ID chosen by the
	// scheduler is stored.
	SelectedAuthMetadataKey = "selected_auth_id"

	// SelectedAuthCallbackMetadataKey is the key carrying an optional callback
	// that is invoked with the selected auth ID.
	SelectedAuthCallbackMetadataKey = "selected_auth_callback"

	// ExecutionSessionMetadataKey is the key identifying a long-lived
	// downstream execution session.
	ExecutionSessionMetadataKey = "execution_session_id"
)
// Request encapsulates the translated payload that will be sent to a provider executor.
type Request struct {
// Model is the upstream model identifier after translation.
@@ -46,6 +57,8 @@ type Response struct {
Payload []byte
// Metadata exposes optional structured data for translators.
Metadata map[string]any
// Headers carries upstream HTTP response headers for passthrough to clients.
Headers http.Header
}
// StreamChunk represents a single streaming payload unit emitted by provider executors.
@@ -56,6 +69,15 @@ type StreamChunk struct {
Err error
}
// StreamResult bundles the two parts of a streaming response: the channel on
// which StreamChunk values are delivered and the upstream HTTP response
// headers captured before streaming begins.
type StreamResult struct {
// Headers carries the upstream HTTP response headers from the initial
// connection.
// NOTE(review): presumably nil when the upstream exposes no headers — confirm
// against the executors that populate it.
Headers http.Header
// Chunks is the receive-only channel of streaming payload units.
// NOTE(review): presumably closed by the producer when the stream ends —
// confirm against the executors that populate it.
Chunks <-chan StreamChunk
}
// StatusError represents an error that carries an HTTP-like status code.
// Provider executors should implement this when possible to enable
// better auth state updates on failures (e.g., 401/402/429).

View File

@@ -336,6 +336,9 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
if _, err := s.coreManager.Update(ctx, existing); err != nil {
log.Errorf("failed to disable auth %s: %v", id, err)
}
if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") {
s.ensureExecutorsForAuth(existing)
}
}
}
@@ -368,7 +371,24 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName
}
func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
if s == nil || a == nil {
s.ensureExecutorsForAuthWithMode(a, false)
}
func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace bool) {
if s == nil || s.coreManager == nil || a == nil {
return
}
if strings.EqualFold(strings.TrimSpace(a.Provider), "codex") {
if !forceReplace {
existingExecutor, hasExecutor := s.coreManager.Executor("codex")
if hasExecutor {
_, isCodexAutoExecutor := existingExecutor.(*executor.CodexAutoExecutor)
if isCodexAutoExecutor {
return
}
}
}
s.coreManager.RegisterExecutor(executor.NewCodexAutoExecutor(s.cfg))
return
}
// Skip disabled auth entries when (re)binding executors.
@@ -403,8 +423,6 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
s.coreManager.RegisterExecutor(executor.NewAntigravityExecutor(s.cfg))
case "claude":
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
case "codex":
s.coreManager.RegisterExecutor(executor.NewCodexExecutor(s.cfg))
case "qwen":
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
case "iflow":
@@ -413,6 +431,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg))
case "kiro":
s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg))
case "kilo":
s.coreManager.RegisterExecutor(executor.NewKiloExecutor(s.cfg))
case "github-copilot":
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
default:
@@ -430,8 +450,15 @@ func (s *Service) rebindExecutors() {
return
}
auths := s.coreManager.List()
reboundCodex := false
for _, auth := range auths {
s.ensureExecutorsForAuth(auth)
if auth != nil && strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
if reboundCodex {
continue
}
reboundCodex = true
}
s.ensureExecutorsForAuthWithMode(auth, true)
}
}
@@ -844,6 +871,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "kiro":
models = s.fetchKiroModels(a)
models = applyExcludedModels(models, excluded)
case "kilo":
models = executor.FetchKiloModels(context.Background(), a, s.cfg)
models = applyExcludedModels(models, excluded)
default:
// Handle OpenAI-compatibility providers by name using config
if s.cfg != nil {

View File

@@ -0,0 +1,64 @@
package cliproxy
import (
"testing"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// TestEnsureExecutorsForAuth_CodexDoesNotReplaceInNormalMode binds the same
// codex auth twice through ensureExecutorsForAuth and expects the manager to
// report the same "codex" executor after both calls.
func TestEnsureExecutorsForAuth_CodexDoesNotReplaceInNormalMode(t *testing.T) {
	svc := &Service{
		cfg:         &config.Config{},
		coreManager: coreauth.NewManager(nil, nil, nil),
	}
	codexAuth := &coreauth.Auth{
		ID:       "codex-auth-1",
		Provider: "codex",
		Status:   coreauth.StatusActive,
	}

	// First bind: an executor must now be registered.
	svc.ensureExecutorsForAuth(codexAuth)
	before, foundBefore := svc.coreManager.Executor("codex")
	if before == nil || !foundBefore {
		t.Fatal("expected codex executor after first bind")
	}

	// Second bind in normal mode: the registered executor must be untouched.
	svc.ensureExecutorsForAuth(codexAuth)
	after, foundAfter := svc.coreManager.Executor("codex")
	if after == nil || !foundAfter {
		t.Fatal("expected codex executor after second bind")
	}

	if after != before {
		t.Fatal("expected codex executor to stay unchanged in normal mode")
	}
}
// TestEnsureExecutorsForAuthWithMode_CodexForceReplace binds a codex auth once
// in normal mode and once with forceReplace set, and expects the manager to
// report a different "codex" executor after the forced rebind.
func TestEnsureExecutorsForAuthWithMode_CodexForceReplace(t *testing.T) {
	svc := &Service{
		cfg:         &config.Config{},
		coreManager: coreauth.NewManager(nil, nil, nil),
	}
	codexAuth := &coreauth.Auth{
		ID:       "codex-auth-2",
		Provider: "codex",
		Status:   coreauth.StatusActive,
	}

	// Initial bind in normal mode.
	svc.ensureExecutorsForAuth(codexAuth)
	initial, foundInitial := svc.coreManager.Executor("codex")
	if initial == nil || !foundInitial {
		t.Fatal("expected codex executor after first bind")
	}

	// Forced rebind must swap in a fresh executor instance.
	svc.ensureExecutorsForAuthWithMode(codexAuth, true)
	replaced, foundReplaced := svc.coreManager.Executor("codex")
	if replaced == nil || !foundReplaced {
		t.Fatal("expected codex executor after forced rebind")
	}

	if replaced == initial {
		t.Fatal("expected codex executor replacement in force mode")
	}
}

View File

@@ -90,3 +90,26 @@ func TestApplyOAuthModelAlias_ForkAddsMultipleAliases(t *testing.T) {
t.Fatalf("expected forked model name %q, got %q", "models/g5-2", out[2].Name)
}
}
// TestApplyOAuthModelAlias_DefaultGitHubCopilotAliasViaSanitize starts from an
// empty config, runs SanitizeOAuthModelAlias, and expects applyOAuthModelAlias
// for the "github-copilot" oauth channel to return the original
// "claude-opus-4.6" model followed by a "claude-opus-4-6" alias.
func TestApplyOAuthModelAlias_DefaultGitHubCopilotAliasViaSanitize(t *testing.T) {
	cfg := new(config.Config)
	cfg.SanitizeOAuthModelAlias()

	input := []*ModelInfo{
		{ID: "claude-opus-4.6", Name: "models/claude-opus-4.6"},
	}
	got := applyOAuthModelAlias(cfg, "github-copilot", "oauth", input)
	if len(got) != 2 {
		t.Fatalf("expected 2 models (original + default alias), got %d", len(got))
	}

	original, alias := got[0], got[1]
	if original.ID != "claude-opus-4.6" {
		t.Fatalf("expected first model id %q, got %q", "claude-opus-4.6", original.ID)
	}
	if alias.ID != "claude-opus-4-6" {
		t.Fatalf("expected second model id %q, got %q", "claude-opus-4-6", alias.ID)
	}
	if alias.Name != "models/claude-opus-4-6" {
		t.Fatalf("expected aliased model name %q, got %q", "models/claude-opus-4-6", alias.Name)
	}
}