Compare commits

...

120 Commits

Author SHA1 Message Date
Luis Pater
68934942d0 Merge branch 'pr-402-local'
# Conflicts:
#	internal/config/oauth_model_alias_migration.go
2026-03-02 20:45:37 +08:00
Luis Pater
09fec34e1c chore(docs): update sponsor info and GLM model details in README files 2026-03-02 20:30:07 +08:00
hkfires
9229708b6c revert(executor): re-apply PR #1735 antigravity changes with cleanup 2026-03-02 19:30:32 +08:00
hkfires
914db94e79 refactor(headers): streamline User-Agent handling and introduce GeminiCLI versioning 2026-03-02 13:04:30 +08:00
hkfires
660bd7eff5 refactor(config): remove oauth-model-alias migration logic and related tests 2026-03-02 13:02:15 +08:00
hkfires
b907d21851 revert(executor): revert antigravity_executor.go changes from PR #1735 2026-03-02 12:54:15 +08:00
Luis Pater
d6cc976d1f chore(executor): remove unused header scrubbing function 2026-03-02 03:40:54 +08:00
Luis Pater
8aa2cce8c5 Merge PR #1735 into dev with conflict resolution and fixes 2026-03-02 03:22:51 +08:00
Luis Pater
bf9b2c49df Merge branch 'router-for-me:main' into main 2026-03-01 21:40:14 +08:00
Luis Pater
77b42c6165 fix(claude): handle X-CPA-CLAUDE-1M header and ensure proper beta merging logic 2026-03-01 21:39:33 +08:00
Luis Pater
446150a747 Merge branch 'router-for-me:main' into main 2026-03-01 20:35:05 +08:00
Luis Pater
1cbc4834e1 Merge pull request #1771 from edlsh/fix/claude-cache-control-1769
Fix Claude OAuth cache_control regressions and gzip error decoding
2026-03-01 20:17:22 +08:00
hkfires
a8a5d03c33 chore: ignore .idea directory in git and docker builds 2026-03-01 12:42:59 +08:00
edlsh
76aa917882 Optimize cache-control JSON mutations in Claude executor 2026-02-28 22:47:04 -05:00
edlsh
6ac9b31e4e Handle compressed error decode failures safely 2026-02-28 22:43:46 -05:00
edlsh
0ad3e8457f Clarify cloaking system block cache-control comments 2026-02-28 22:34:14 -05:00
edlsh
444a47ae63 Fix Claude cache-control guardrails and gzip error decoding 2026-02-28 22:32:33 -05:00
Luis Pater
725f4fdff4 Merge pull request #1768 from router-for-me/claude
fix(translator): handle Claude thinking type "auto" like adaptive
2026-03-01 11:03:13 +08:00
Luis Pater
c23e46f45d Merge pull request #1767 from router-for-me/antigravity
fix(antigravity): update model configurations and add new models for Antigravity
2026-03-01 11:02:20 +08:00
hkfires
b148820c35 fix(translator): handle Claude thinking type "auto" like adaptive 2026-03-01 10:30:19 +08:00
hkfires
134f41496d fix(antigravity): update model configurations and add new models for Antigravity 2026-03-01 10:05:29 +08:00
Luis Pater
c5838dd58d Merge pull request #400 from router-for-me/plus
v6.8.35
2026-03-01 09:42:03 +08:00
Luis Pater
b6ca5ef7ce Merge branch 'main' into plus 2026-03-01 09:41:52 +08:00
Luis Pater
1ae994b4aa fix(antigravity): adjust thinkingBudget default to 64000 and update model definitions for Claude 2026-03-01 09:39:39 +08:00
Luis Pater
84e9793e61 Merge pull request #399 from router-for-me/plus
v6.8.34
2026-03-01 02:44:36 +08:00
Luis Pater
32e64dacfd Merge branch 'main' into plus 2026-03-01 02:44:26 +08:00
Luis Pater
cc1d8f6629 Fixed: #1747
feat(auth): add configurable max-retry-credentials for finer control over cross-credential retries
2026-03-01 02:42:36 +08:00
Luis Pater
5446cd2b02 Merge pull request #1761 from margbug01/fix/thinking-chain-display
fix: support thinking.type=auto from Amp client and decouple thinking translation from unsigned history
2026-03-01 02:30:42 +08:00
margbug01
8de0885b7d fix: support thinking.type="auto" from Amp client for Antigravity Claude models
## Problem

When using Antigravity Claude models through CLIProxyAPI, the thinking
chain (reasoning content) does not display in the Amp client.

## Root Cause

The Amp client sends `thinking: {"type": "auto"}` in its requests,
but `ConvertClaudeRequestToAntigravity` only handled `"enabled"` and
`"adaptive"` types in its switch statement. The `"auto"` type was
silently ignored, resulting in no `thinkingConfig` being set in the
translated Gemini request. Without `thinkingConfig`, the Antigravity
API returns responses without any thinking content.

Additionally, the Antigravity API for Claude models does not support
`thinkingBudget: -1` (auto mode sentinel). It requires a concrete
positive budget value. The fix uses 128000 as the budget for "auto"
mode, which `ApplyThinking` will then normalize to stay within the
model's actual limits (e.g., capped to `maxOutputTokens - 1`).

## Changes

### internal/translator/antigravity/claude/antigravity_claude_request.go

1. **Add "auto" case** to the thinking type switch statement.
   Sets `thinkingBudget: 128000` and `includeThoughts: true`.
   The budget is subsequently normalized by `ApplyThinking` based
   on model-specific limits.

2. **Add "auto" to hasThinking check** so that interleaved thinking
   hints are injected for tool-use scenarios when Amp sends
   `thinking.type="auto"`.

### internal/registry/model_definitions_static_data.go

3. **Add Thinking configuration** for `claude-sonnet-4-6`,
   `claude-sonnet-4-5`, and `claude-opus-4-6` in
   `GetAntigravityModelConfig()` -- these were previously missing,
   causing `ApplyThinking` to skip thinking config entirely.

## Testing

- Deployed to Railway test instance (cpa-thinking-test)
- Verified via debug logging that:
  - Amp sends `thinking: {"type": "auto"}`
  - CPA now translates this to `thinkingConfig: {thinkingBudget: 128000, includeThoughts: true}`
  - `ApplyThinking` normalizes the budget to model-specific limits
  - Antigravity API receives the correct thinkingConfig

Amp-Thread-ID: https://ampcode.com/threads/T-019ca511-710d-776d-a07c-4b750f871a93
Co-authored-by: Amp <amp@ampcode.com>
2026-03-01 02:18:43 +08:00
Luis Pater
16243f18fd Merge branch 'router-for-me:main' into main 2026-03-01 01:46:23 +08:00
Luis Pater
a6ce5f36e6 Fixed: #1758
fix(codex): filter billing headers from system result text and update template logic
2026-03-01 01:45:35 +08:00
Luis Pater
e73cf42e28 Merge pull request #1750 from tpm2dot0/fix/claude-code-request-fingerprint-alignment
fix(cloak): align outgoing requests with real Claude Code 2.1.63
2026-03-01 01:27:28 +08:00
exe.dev user
b45343e812 fix(cloak): align outgoing requests with real Claude Code 2.1.63 fingerprint
Captured and compared outgoing requests from CLIProxyAPI against real
Claude Code 2.1.63 and fixed all detectable differences:

Headers:
- Update anthropic-beta to match 2.1.63: replace fine-grained-tool-streaming
  and prompt-caching-2024-07-31 with context-management-2025-06-27 and
  prompt-caching-scope-2026-01-05
- Remove X-Stainless-Helper-Method header (real Claude Code does not send it)
- Update default User-Agent from "claude-cli/2.1.44 (external, sdk-cli)" to
  "claude-cli/2.1.63 (external, cli)"
- Force Claude Code User-Agent for non-Claude clients to avoid leaking
  real client identity (e.g. curl, OpenAI SDKs) during cloaking

Body:
- Inject x-anthropic-billing-header as system[0] (matches real format)
- Change system prompt identifier from "You are Claude Code..." to
  "You are a Claude agent, built on Anthropic's Claude Agent SDK."
- Add cache_control with ttl:"1h" to match real request format
- Fix user_id format: user_[64hex]_account_[uuid]_session_[uuid]
  (was missing account UUID)
- Disable tool name prefix (set claudeToolPrefix to empty string)

TLS:
- Switch utls fingerprint from HelloFirefox_Auto to HelloChrome_Auto
  (closer to Node.js/OpenSSL used by real Claude Code)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-28 09:19:06 +00:00
Luis Pater
8599b1560e Fixed: #1716
feat(kimi): add support for explicit disabled thinking and reasoning effort handling
2026-02-28 05:29:07 +08:00
Luis Pater
8bde8c37c0 Fixed: #1711
fix(server): use resolved log directory for request logger initialization and test fallback logic
2026-02-28 05:21:01 +08:00
Luis Pater
82df5bf88a Merge pull request #395 from Xm798/feat/kiro
feat(kiro): add IDC auth code flow, redesign fingerprint and API protocol
2026-02-27 20:50:43 +08:00
Luis Pater
acb1066de8 Merge branch 'router-for-me:main' into main 2026-02-27 20:49:03 +08:00
Luis Pater
27c68f5bb2 fix(auth): replace MarkResult with hook OnResult for result handling 2026-02-27 20:47:46 +08:00
maplelove
68dd2bfe82 fix(translator): allow passthrough of custom generationConfig for all Gemini-like providers 2026-02-27 17:13:42 +08:00
Luis Pater
65a87815e7 Merge pull request #394 from router-for-me/plus
v6.8.31
2026-02-27 16:18:48 +08:00
Luis Pater
b80793ca82 Merge branch 'main' into plus 2026-02-27 16:18:12 +08:00
Luis Pater
601550f238 Merge pull request #393 from cielhaidir/main
feat(kiro): add new Kiro models definition
2026-02-27 16:17:05 +08:00
Luis Pater
41b1cf2273 Merge pull request #1734 from huangusaki/main
feat(registry): add gemini-3.1-flash-image support
2026-02-27 16:12:05 +08:00
maplelove
2baf35b3ef fix(executor): bump antigravity UA to 1.19.6 and align image_gen payload 2026-02-27 14:09:37 +08:00
maplelove
846e75b893 feat(gemini): route gemini-3.1-flash-image identically to gemini-3-pro-image 2026-02-27 13:32:06 +08:00
maplelove
fc0257d6d9 refactor: consolidate duplicate UA and header scrubbing into shared misc functions 2026-02-27 10:57:13 +08:00
maplelove
f3c164d345 feat(antigravity): update to v1.19.5 with new models and Claude 4-6 migration 2026-02-27 10:34:27 +08:00
maplelove
4040b1e766 Merge remote-tracking branch 'upstream/dev' into dev
# Conflicts:
#	internal/runtime/executor/antigravity_executor.go
2026-02-27 10:29:50 +08:00
huang_usaki
3b4f9f43db feat(registry): add gemini-3.1-flash-image support 2026-02-27 10:20:46 +08:00
“cielhaidir”
37a09ecb23 feat(kiro): add new Kiro models definition 2026-02-27 10:18:59 +08:00
Luis Pater
0da34d3c2d Merge pull request #1668 from lyd123qw2008/fix/codex-usage-limit-retry-after
fix(codex): honor usage_limit_reached resets_at for retry_after
2026-02-27 06:01:44 +08:00
Luis Pater
74bf7eda8f Merge pull request #1686 from lyd123qw2008/fix/auth-refresh-concurrency-limit
fix(auth): limit auto-refresh concurrency to prevent refresh storms
2026-02-27 05:59:20 +08:00
Cyrus
9032042cfa feat(kiro): add Sonnet 4.6 model alias
- Add kiro-claude-sonnet-4-6 alias mapping to claude-sonnet-4-6
2026-02-27 01:02:21 +08:00
Cyrus
030bf5e6c7 feat(kiro): add IDC auth and endpoint improvements, redesign fingerprint system
- Add IAM Identity Center (IDC) authentication with CLI flags (--kiro-idc-login, --kiro-idc-start-url, --kiro-idc-region) and login flow
- Add ProfileArn auto-fetching in Execute/ExecuteStream for imported IDC accounts
- Simplify endpoint preference with map-based alias lookup and getAuthValue helper
- Redesign fingerprint as global singleton with external config and per-account deterministic generation
- Add StartURL and FingerprintConfig fields to Kiro config
- Add AgentContinuationID/AgentTaskType support in Kiro translators
- Add comprehensive tests for executor, fingerprint, SSO OIDC, and AWS helpers
- Add CLI login documentation to README
2026-02-27 00:58:03 +08:00
Luis Pater
d3100085b0 Merge pull request #392 from router-for-me/plus
v6.8.30
2026-02-26 23:16:26 +08:00
Luis Pater
f481d25133 Merge branch 'main' into plus 2026-02-26 23:16:17 +08:00
Luis Pater
8c6c90da74 fix(registry): clean up outdated model definitions in static data 2026-02-26 23:12:40 +08:00
Luis Pater
24bcfd9c03 Merge pull request #1699 from 123hi123/fix/antigravity-primary-model-fallback
fix(antigravity): keep primary model list and backfill empty auths
2026-02-26 04:28:29 +08:00
Luis Pater
816fb4c5da Merge pull request #1682 from sususu98/fix/tool-result-image-parts
fix(antigravity): place tool_result images in functionResponse.parts and unify mimeType
2026-02-25 23:14:35 +08:00
Luis Pater
c1bb77c7c9 Merge pull request #291 from howarddong711/feat/copilot-email-name
feat(copilot): fetch and persist user email and display name on login
2026-02-25 22:23:25 +08:00
Luis Pater
6bcac3a55a Merge branch 'router-for-me:main' into main 2026-02-25 22:21:31 +08:00
Howard Dong
fc346f4537 fix(copilot): add username fallback and consistent file name prefix
- Add 'github-user' fallback in WaitForAuthorization when FetchUserInfo
  returns empty Login (fixes malformed 'github-copilot-.json' filenames)
- Standardize Web API file name to 'github-copilot-<user>.json' to match
  CLI path convention (was 'github-<user>.json')

Addresses Gemini Code Assist review comments on PR #291.
2026-02-25 17:17:51 +08:00
Howard Dong
43e531a3b6 feat(copilot): fetch and persist user email and display name on login
- Expand OAuth scope to include read:user for full profile access
- Add GitHubUserInfo struct with Login, Email, Name fields
- Update FetchUserInfo to return complete user profile
- Add Email and Name fields to CopilotTokenStorage and CopilotAuthBundle
- Fix provider string bug: 'github' -> 'github-copilot' in auth_files.go
- Fix semantic bug: email field was storing username
- Update Label to prefer email over username in both CLI and Web API paths
- Add 9 unit tests covering new functionality
2026-02-25 17:09:40 +08:00
Luis Pater
d24ea4ce2a Merge pull request #1664 from ciberponk/pr/responses-compaction-compat
feat: add codex responses compatibility for compaction payloads
2026-02-25 01:21:59 +08:00
Luis Pater
2c30c981ae Merge pull request #1687 from lyd123qw2008/fix/codex-refresh-token-reused-no-retry
fix(codex): stop retrying refresh_token_reused errors
2026-02-25 01:19:30 +08:00
Luis Pater
aa1da8a858 Merge pull request #1685 from lyd123qw2008/fix/auth-auto-refresh-interval
fix(auth): respect configured auto-refresh interval
2026-02-25 01:13:47 +08:00
Luis Pater
f1e9a787d7 Merge pull request #1676 from piexian/feat/qwen-quota-handling-clean
feat(qwen): add rate limiting and quota error handling
2026-02-25 01:07:55 +08:00
Luis Pater
4eeec297de Merge pull request #288 from router-for-me/plus
v6.8.27
2026-02-25 01:04:57 +08:00
Luis Pater
77cc4ce3a0 Merge branch 'main' into plus 2026-02-25 01:04:15 +08:00
Luis Pater
37dfea1d3f Merge pull request #287 from possible055/main
fix(kiro): support OR-group field matching in truncation detector
2026-02-25 01:02:49 +08:00
Luis Pater
e6626c672a Merge pull request #269 from ClubWeGo/fix/filterOrphanedToolResults
fix: filter out orphaned tool results from history and current context
2026-02-25 01:02:11 +08:00
Luis Pater
c66cb0afd2 Merge pull request #1683 from dusty-du/codex/device-login-flow
Add additive Codex device-code login flow
2026-02-25 00:50:48 +08:00
Luis Pater
fb48eee973 Merge pull request #1680 from canxin121/fix/responses-stream-error-chunks
fix(responses): emit schema-valid SSE chunks
2026-02-25 00:49:06 +08:00
Luis Pater
bb44e5ec44 Merge pull request #1701 from router-for-me/openai
Revert "Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping"
2026-02-25 00:46:13 +08:00
apparition
c785c1a3ca fix(kiro): support OR-group field matching in truncation detector
- Change RequiredFieldsByTool value type from []string to [][]string
- Outer slice = AND (all groups required); inner slice = OR (any one satisfies)
- Fix Bash entry to accept "cmd" or "command", resolving soft-truncation loop
- Update findMissingRequiredFields logic and inline docs accordingly
2026-02-24 22:48:05 +08:00
comalot
514ae341c8 fix(antigravity): deep copy cached model metadata 2026-02-24 20:14:01 +08:00
hkfires
0659ffab75 Revert "Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping" 2026-02-24 19:47:53 +08:00
comalot
8ce07f38dd fix(antigravity): keep primary model list and backfill empty auths 2026-02-24 16:16:44 +08:00
Luis Pater
7cb398d167 Merge pull request #1663 from rensumo/main
feat: implement credential-based round-robin for gemini-cli
2026-02-24 06:02:50 +08:00
Luis Pater
c3e12c5e58 Merge pull request #1654 from alexey-yanchenko/feature/pass-file-inputs
Pass file input from /chat/completions and /responses to codex and claude
2026-02-24 05:53:11 +08:00
Luis Pater
1825fc7503 Merge pull request #1643 from alexey-yanchenko/fix/gemini-prompt-tokens
Fix usage convertation from gemini response to openai format
2026-02-24 05:46:13 +08:00
Luis Pater
48732ba05e Merge pull request #1527 from HEUDavid/feat/auth-hook
feat(auth): add post-auth hook mechanism
2026-02-24 05:33:13 +08:00
canxin121
acf483c9e6 fix(responses): reject invalid SSE data JSON
Guard the openai-response streaming path against truncated/invalid SSE data payloads by validating data: JSON before forwarding; surface a 502 terminal error instead of letting clients crash with JSON parse errors.
2026-02-24 01:42:54 +08:00
lyd123qw2008
3b3e0d1141 test(codex): log non-retryable refresh error and cover single-attempt behavior 2026-02-23 22:41:33 +08:00
lyd123qw2008
7acd428507 fix(codex): stop retrying refresh_token_reused errors 2026-02-23 22:31:30 +08:00
lyd123qw2008
0aaf177640 fix(auth): limit auto-refresh concurrency to prevent refresh storms 2026-02-23 22:28:41 +08:00
lyd123qw2008
450d1227bd fix(auth): respect configured auto-refresh interval 2026-02-23 22:07:50 +08:00
test
492b9c46f0 Add additive Codex device-code login flow 2026-02-23 06:30:04 -05:00
Darley
6e634fe3f9 fix: filter out orphaned tool results from history and current context 2026-02-23 14:33:59 +08:00
sususu98
4e26182d14 fix(antigravity): place tool_result images in functionResponse.parts and unify mimeType
Move base64 image data from Claude tool_result into functionResponse.parts
as inlineData instead of outer sibling parts, preventing context bloat.
Unify all inlineData field naming to camelCase mimeType across Claude,
OpenAI, and Gemini translators. Add comprehensive edge case tests and
Gemini-side regression test for functionResponse.parts preservation.
2026-02-23 13:38:21 +08:00
maplelove
8f97a5f77c feat(registry): expose input modalities, token limits, and generation methods for Antigravity models 2026-02-23 13:33:51 +08:00
canxin121
eb7571936c revert: translator changes (path guard)
CI blocks PRs that modify internal/translator. Revert translator edits and keep only the /v1/responses streaming error-chunk fix; file an issue for translator conformance work.
2026-02-23 13:30:43 +08:00
canxin121
5382764d8a fix(responses): include model and usage in translated streams
Ensure response.created and response.completed chunks produced by the OpenAI/Gemini/Claude translators always include required fields (response.model and response.usage) so clients validating Responses SSE do not fail schema validation.
2026-02-23 13:22:06 +08:00
canxin121
49c8ec69d0 fix(openai): emit valid responses stream error chunks
When /v1/responses streaming fails after headers are sent, we now emit a type=error chunk instead of an HTTP-style {error:{...}} payload, preventing AI SDK chunk validation errors.
2026-02-23 12:59:50 +08:00
piexian
3b421c8181 feat(qwen): add rate limiting and quota error handling
- Add 60 requests/minute rate limiting per credential using sliding window
- Detect insufficient_quota errors and set cooldown until next day (Beijing time)
- Map quota errors (HTTP 403/429) to 429 with retryAfter for conductor integration
- Cache Beijing timezone at package level to avoid repeated syscalls
- Add redactAuthID function to protect credentials in logs
- Extract wrapQwenError helper to consolidate error handling
2026-02-23 00:38:46 +08:00
maplelove
2a4d3e60f3 Merge remote-tracking branch 'upstream/dev' into dev 2026-02-23 00:01:47 +08:00
maplelove
8b5af2ab84 fix(executor): match real Antigravity OAuth UA, remove redundant header scrubbing on new requests 2026-02-22 23:20:12 +08:00
maplelove
d887716ebd refactor(executor): switch HttpRequest to whitelist-based header filtering 2026-02-22 21:00:12 +08:00
maplelove
5dc1848466 feat(scrub): add comprehensive browser fingerprint and client identity header scrubbing 2026-02-22 20:51:00 +08:00
maplelove
9491517b26 fix(executor): use singleton transport to prevent OOM from connection pool leaks 2026-02-22 20:17:30 +08:00
maplelove
9370b5bd04 fix(executor): completely scrub all proxy tracing headers in executor 2026-02-22 19:43:10 +08:00
maplelove
abb51a0d93 fix(executor): correctly disable http2 ALPN in Antigravity client to resolve connection reset errors 2026-02-22 19:23:48 +08:00
maplelove
c8d809131b fix(executor): improve antigravity reverse proxy emulation
- force http/1.1 instead of http/2

- explicit connection close

- strip proxy headers X-Forwarded-For and X-Real-IP

- add project id to fetch models payload
2026-02-22 18:41:58 +08:00
maplelove
dd71c73a9f fix: align gemini-cli upstream communication headers
Removed legacy Client-Metadata and explicit API-Client headers. Dynamically generating accurate User-Agent strings matching the official cli.
2026-02-22 17:07:17 +08:00
fan
afc8a0f9be refactor: simplify context_management compatibility handling 2026-02-21 22:20:48 +08:00
lyd123qw2008
a99522224f refactor(codex): make retry-after parsing deterministic for tests 2026-02-21 14:13:38 +08:00
lyd123qw2008
f5d46b9ca2 fix(codex): honor usage_limit_reached resets_at for retry_after 2026-02-21 13:50:23 +08:00
ciberponk
d693d7993b feat: support responses compaction payload compatibility for codex translator 2026-02-21 12:56:10 +08:00
rensumo
5936f9895c feat: implement credential-based round-robin for gemini-cli virtual auths
Changes the RoundRobinSelector to use two-level round-robin when
gemini-cli virtual auths are detected (via gemini_virtual_parent attr):
- Level 1: cycle across credential groups (parent accounts)
- Level 2: cycle within each group's project auths

Credentials start from a random offset (rand.IntN) for fair distribution.
Non-virtual auths and single-credential scenarios fall back to flat RR.

Adds 3 test cases covering multi-credential grouping, single-parent
fallback, and mixed virtual/non-virtual fallback.
2026-02-21 12:49:48 +08:00
Alexey Yanchenko
0cbfe7f457 Pass file input from /chat/completions and /responses to codex and claude 2026-02-20 10:25:44 +07:00
Alexey Yanchenko
b9ae4ab803 Fix usage convertation from gemini response to openai format 2026-02-19 15:34:59 +07:00
HEUDavid
65debb874f feat/auth-hook: refactor RequstInfo to preserve original HTTP semantics 2026-02-12 07:11:17 +08:00
HEUDavid
3caadac003 feat/auth-hook: add post auth hook [CR] 2026-02-12 07:11:17 +08:00
HEUDavid
6a9e3a6b84 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
269972440a feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
cce13e6ad2 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
8a565dcad8 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
d536110404 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
48e957ddff feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
94563d622c feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
113 changed files with 7685 additions and 1797 deletions

View File

@@ -31,6 +31,7 @@ bin/*
.agent/*
.agents/*
.opencode/*
.idea/*
.bmad/*
_bmad/*
_bmad-output/*

1
.gitignore vendored
View File

@@ -44,6 +44,7 @@ GEMINI.md
.agents/*
.agents/*
.opencode/*
.idea/*
.bmad/*
_bmad/*
_bmad-output/*

View File

@@ -10,23 +10,59 @@ The Plus release stays in lockstep with the mainline features.
## Differences from the Mainline
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
[![z.ai](https://assets.router-for.me/english-5.png)](https://z.ai/subscribe?ic=8JVLJQFSKB)
## New Features (Plus Enhanced)
- **OAuth Web Authentication**: Browser-based OAuth login for Kiro with beautiful web UI
- **Rate Limiter**: Built-in request rate limiting to prevent API abuse
- **Background Token Refresh**: Automatic token refresh 10 minutes before expiration
- **Metrics & Monitoring**: Request metrics collection for monitoring and debugging
- **Device Fingerprint**: Device fingerprint generation for enhanced security
- **Cooldown Management**: Smart cooldown mechanism for API rate limits
- **Usage Checker**: Real-time usage monitoring and quota management
- **Model Converter**: Unified model name conversion across providers
- **UTF-8 Stream Processing**: Improved streaming response handling
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10/month. It provides access to their flagship GLM-4.7 & GLM-5 Only Available for Pro Usersmodel across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
## Kiro Authentication
### CLI Login
> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions.
**AWS Builder ID** (recommended):
```bash
# Device code flow
./CLIProxyAPI --kiro-aws-login
# Authorization code flow
./CLIProxyAPI --kiro-aws-authcode
```
**Import token from Kiro IDE:**
```bash
./CLIProxyAPI --kiro-import
```
To get a token from Kiro IDE:
1. Open Kiro IDE and login with Google (or GitHub)
2. Find the token file: `~/.kiro/kiro-auth-token.json`
3. Run: `./CLIProxyAPI --kiro-import`
**AWS IAM Identity Center (IDC):**
```bash
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
# Specify region
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
```
**Additional flags:**
| Flag | Description |
|------|-------------|
| `--no-browser` | Don't open browser automatically, print URL instead |
| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session |
| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) |
| `--kiro-idc-region` | IDC region (default: `us-east-1`) |
| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` |
### Web-based OAuth Login
Access the Kiro OAuth web interface at:

View File

@@ -10,22 +10,58 @@
## 与主线版本版本差异
- 新增 GitHub Copilot 支持OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
[![bigmodel.cn](https://assets.router-for.me/chinese-5.png)](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
## 新增功能 (Plus 增强版)
- **OAuth Web 认证**: 基于浏览器的 Kiro OAuth 登录,提供美观的 Web UI
- **请求限流器**: 内置请求限流,防止 API 滥用
- **后台令牌刷新**: 过期前 10 分钟自动刷新令牌
- **监控指标**: 请求指标收集,用于监控和调试
- **设备指纹**: 设备指纹生成,增强安全性
- **冷却管理**: 智能冷却机制,应对 API 速率限制
- **用量检查器**: 实时用量监控和配额管理
- **模型转换器**: 跨供应商的统一模型名称转换
- **UTF-8 流处理**: 改进的流式响应处理
GLM CODING PLAN 是专为AI编码打造的订阅套餐每月最低仅需20元即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7受限于算力目前仅限Pro用户开放为开发者提供顶尖的编码体验。
## Kiro 认证
智谱AI为本产品提供了特别优惠使用以下链接购买可以享受九折优惠https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
### 命令行登录
> **注意:** 由于 AWS Cognito 限制Google/GitHub 登录不可用于第三方应用。
**AWS Builder ID**(推荐):
```bash
# 设备码流程
./CLIProxyAPI --kiro-aws-login
# 授权码流程
./CLIProxyAPI --kiro-aws-authcode
```
**从 Kiro IDE 导入令牌:**
```bash
./CLIProxyAPI --kiro-import
```
获取令牌步骤:
1. 打开 Kiro IDE使用 Google或 GitHub登录
2. 找到令牌文件:`~/.kiro/kiro-auth-token.json`
3. 运行:`./CLIProxyAPI --kiro-import`
**AWS IAM Identity Center (IDC)**
```bash
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
# 指定区域
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
```
**附加参数:**
| 参数 | 说明 |
|------|------|
| `--no-browser` | 不自动打开浏览器,打印 URL |
| `--no-incognito` | 使用已有浏览器会话Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 |
| `--kiro-idc-start-url` | IDC Start URL`--kiro-idc-login` 必需) |
| `--kiro-idc-region` | IDC 区域(默认:`us-east-1` |
| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` |
### 网页端 OAuth 登录

View File

@@ -72,6 +72,7 @@ func main() {
// Command-line flags to control the application's behavior.
var login bool
var codexLogin bool
var codexDeviceLogin bool
var claudeLogin bool
var qwenLogin bool
var kiloLogin bool
@@ -86,6 +87,10 @@ func main() {
var kiroAWSLogin bool
var kiroAWSAuthCode bool
var kiroImport bool
var kiroIDCLogin bool
var kiroIDCStartURL string
var kiroIDCRegion string
var kiroIDCFlow string
var githubCopilotLogin bool
var projectID string
var vertexImport string
@@ -99,6 +104,7 @@ func main() {
// Define command-line flags for different operation modes.
flag.BoolVar(&login, "login", false, "Login Google Account")
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
@@ -115,6 +121,10 @@ func main() {
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
flag.BoolVar(&kiroAWSAuthCode, "kiro-aws-authcode", false, "Login to Kiro using AWS Builder ID (authorization code flow, better UX)")
flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
flag.BoolVar(&kiroIDCLogin, "kiro-idc-login", false, "Login to Kiro using IAM Identity Center (IDC)")
flag.StringVar(&kiroIDCStartURL, "kiro-idc-start-url", "", "IDC start URL (required with --kiro-idc-login)")
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
@@ -502,6 +512,9 @@ func main() {
} else if codexLogin {
// Handle Codex login
cmd.DoCodexLogin(cfg, options)
} else if codexDeviceLogin {
// Handle Codex device-code login
cmd.DoCodexDeviceLogin(cfg, options)
} else if claudeLogin {
// Handle Claude login
cmd.DoClaudeLogin(cfg, options)
@@ -521,24 +534,34 @@ func main() {
// Note: This config mutation is safe - auth commands exit after completion
// and don't share config with StartService (which is in the else branch)
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroLogin(cfg, options)
} else if kiroGoogleLogin {
// For Kiro auth, default to incognito mode for multi-account support
// Users can explicitly override with --no-incognito
// Note: This config mutation is safe - auth commands exit after completion
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroGoogleLogin(cfg, options)
} else if kiroAWSLogin {
// For Kiro auth, default to incognito mode for multi-account support
// Users can explicitly override with --no-incognito
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroAWSLogin(cfg, options)
} else if kiroAWSAuthCode {
// For Kiro auth with authorization code flow (better UX)
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroAWSAuthCodeLogin(cfg, options)
} else if kiroImport {
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroImport(cfg, options)
} else if kiroIDCLogin {
// For Kiro IDC auth, default to incognito mode for multi-account support
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroIDCLogin(cfg, options, kiroIDCStartURL, kiroIDCRegion, kiroIDCFlow)
} else {
// In cloud deploy mode without config file, just wait for shutdown signals
if isCloudDeploy && !configFileExists {

View File

@@ -80,6 +80,10 @@ passthrough-headers: false
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
request-retry: 3
# Maximum number of different credentials to try for one failed request.
# Set to 0 to keep legacy behavior (try all available credentials).
max-retry-credentials: 0
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
max-retry-interval: 30
@@ -179,6 +183,8 @@ nonstream-keepalive-interval: 0
#kiro:
# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file
# agent-task-type: "" # optional: "vibe" or empty (API default)
# start-url: "https://your-company.awsapps.com/start" # optional: IDC start URL (preset for login)
# region: "us-east-1" # optional: OIDC region for IDC login and token refresh
# - access-token: "aoaAAAAA..." # or provide tokens directly
# refresh-token: "aorAAAAA..."
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."

View File

@@ -48,14 +48,11 @@ import (
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
const (
anthropicCallbackPort = 54545
geminiCallbackPort = 8085
codexCallbackPort = 1455
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
geminiCLIVersion = "v1internal"
geminiCLIUserAgent = "google-api-nodejs-client/9.15.1"
geminiCLIApiClient = "gl-node/22.17.0"
geminiCLIClientMetadata = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
anthropicCallbackPort = 54545
geminiCallbackPort = 8085
codexCallbackPort = 1455
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
geminiCLIVersion = "v1internal"
)
type callbackForwarder struct {
@@ -412,6 +409,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
if !auth.LastRefreshedAt.IsZero() {
entry["last_refresh"] = auth.LastRefreshedAt
}
if !auth.NextRetryAfter.IsZero() {
entry["next_retry_after"] = auth.NextRetryAfter
}
if path != "" {
entry["path"] = path
entry["source"] = "file"
@@ -951,11 +951,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
if store == nil {
return "", fmt.Errorf("token store unavailable")
}
if h.postAuthHook != nil {
if err := h.postAuthHook(ctx, record); err != nil {
return "", fmt.Errorf("post-auth hook failed: %w", err)
}
}
return store.Save(ctx, record)
}
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Claude authentication...")
@@ -1100,6 +1106,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
@@ -1358,6 +1365,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
func (h *Handler) RequestCodexToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Codex authentication...")
@@ -1503,6 +1511,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Antigravity authentication...")
@@ -1667,6 +1676,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
func (h *Handler) RequestQwenToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Qwen authentication...")
@@ -1722,6 +1732,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
func (h *Handler) RequestKimiToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Kimi authentication...")
@@ -1798,6 +1809,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
func (h *Handler) RequestIFlowToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing iFlow authentication...")
@@ -1917,8 +1929,6 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
// Initialize Copilot auth service
// We need to import "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" first if not present
// Assuming copilot package is imported as "copilot"
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
// Initiate device flow
@@ -1932,7 +1942,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
authURL := deviceCode.VerificationURI
userCode := deviceCode.UserCode
RegisterOAuthSession(state, "github")
RegisterOAuthSession(state, "github-copilot")
go func() {
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
@@ -1944,9 +1954,13 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
return
}
username, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
userInfo, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
if errUser != nil {
log.Warnf("Failed to fetch user info: %v", errUser)
}
username := userInfo.Login
if username == "" {
username = "github-user"
}
@@ -1955,18 +1969,26 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
TokenType: tokenData.TokenType,
Scope: tokenData.Scope,
Username: username,
Email: userInfo.Email,
Name: userInfo.Name,
Type: "github-copilot",
}
fileName := fmt.Sprintf("github-%s.json", username)
fileName := fmt.Sprintf("github-copilot-%s.json", username)
label := userInfo.Email
if label == "" {
label = username
}
record := &coreauth.Auth{
ID: fileName,
Provider: "github",
Provider: "github-copilot",
Label: label,
FileName: fileName,
Storage: tokenStorage,
Metadata: map[string]any{
"email": username,
"email": userInfo.Email,
"username": username,
"name": userInfo.Name,
},
}
@@ -1980,7 +2002,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
fmt.Println("You can now use GitHub Copilot services through this CLI")
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("github")
CompleteOAuthSessionsByProvider("github-copilot")
}()
c.JSON(200, gin.H{
@@ -2359,9 +2381,7 @@ func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string
return fmt.Errorf("create request: %w", errRequest)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", geminiCLIUserAgent)
req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient)
req.Header.Set("Client-Metadata", geminiCLIClientMetadata)
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
resp, errDo := httpClient.Do(req)
if errDo != nil {
@@ -2431,7 +2451,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
return false, fmt.Errorf("failed to create request: %w", errRequest)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", geminiCLIUserAgent)
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
resp, errDo := httpClient.Do(req)
if errDo != nil {
return false, fmt.Errorf("failed to execute request: %w", errDo)
@@ -2452,7 +2472,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
return false, fmt.Errorf("failed to create request: %w", errRequest)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", geminiCLIUserAgent)
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
resp, errDo = httpClient.Do(req)
if errDo != nil {
return false, fmt.Errorf("failed to execute request: %w", errDo)
@@ -2521,6 +2541,14 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "wait"})
}
// PopulateAuthContext extracts request info and adds it to the context
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
info := &coreauth.RequestInfo{
Query: c.Request.URL.Query(),
Headers: c.Request.Header,
}
return coreauth.WithRequestInfo(ctx, info)
}
const kiroCallbackPort = 9876
func (h *Handler) RequestKiroToken(c *gin.Context) {

View File

@@ -47,6 +47,7 @@ type Handler struct {
allowRemoteOverride bool
envSecret string
logDir string
postAuthHook coreauth.PostAuthHook
}
// NewHandler creates a new management handler instance.
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
h.logDir = dir
}
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
h.postAuthHook = hook
}
// Middleware enforces access control for management endpoints.
// All requests (local and remote) require a valid management key.
// Additionally, remote access requires allow-remote-management=true.

View File

@@ -15,6 +15,7 @@ import (
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
log "github.com/sirupsen/logrus"
)
@@ -77,6 +78,9 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
req.Header.Del("X-Api-Key")
req.Header.Del("X-Goog-Api-Key")
// Remove proxy, client identity, and browser fingerprint headers
misc.ScrubProxyAndFingerprintHeaders(req)
// Remove query-based credentials if they match the authenticated client API key.
// This prevents leaking client auth material to the Amp upstream while avoiding
// breaking unrelated upstream query parameters.

View File

@@ -52,6 +52,7 @@ type serverOptionConfig struct {
keepAliveEnabled bool
keepAliveTimeout time.Duration
keepAliveOnTimeout func()
postAuthHook auth.PostAuthHook
}
// ServerOption customises HTTP server construction.
@@ -59,10 +60,8 @@ type ServerOption func(*serverOptionConfig)
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
configDir := filepath.Dir(configPath)
if base := util.WritablePath(); base != "" {
return logging.NewFileRequestLogger(cfg.RequestLog, filepath.Join(base, "logs"), configDir, cfg.ErrorLogsMaxFiles)
}
return logging.NewFileRequestLogger(cfg.RequestLog, "logs", configDir, cfg.ErrorLogsMaxFiles)
logsDir := logging.ResolveLogDirectory(cfg)
return logging.NewFileRequestLogger(cfg.RequestLog, logsDir, configDir, cfg.ErrorLogsMaxFiles)
}
// WithMiddleware appends additional Gin middleware during server construction.
@@ -112,6 +111,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
}
}
// WithPostAuthHook registers a hook to be called after auth record creation.
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
return func(cfg *serverOptionConfig) {
cfg.postAuthHook = hook
}
}
// Server represents the main API server.
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
type Server struct {
@@ -252,7 +258,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
s.oldConfigYaml, _ = yaml.Marshal(cfg)
s.applyAccessConfig(nil, cfg)
if authManager != nil {
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
}
managementasset.SetCurrentConfig(cfg)
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
@@ -263,6 +269,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
}
logDir := logging.ResolveLogDirectory(cfg)
s.mgmt.SetLogDirectory(logDir)
if optionState.postAuthHook != nil {
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
}
s.localPassword = optionState.localPassword
// Setup routes
@@ -935,7 +944,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
}
if s.handlers != nil && s.handlers.AuthManager != nil {
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
}
// Update log level dynamically when debug flag changes

View File

@@ -7,9 +7,11 @@ import (
"path/filepath"
"strings"
"testing"
"time"
gin "github.com/gin-gonic/gin"
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
@@ -109,3 +111,100 @@ func TestAmpProviderModelRoutes(t *testing.T) {
})
}
}
func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
t.Setenv("WRITABLE_PATH", "")
t.Setenv("writable_path", "")
originalWD, errGetwd := os.Getwd()
if errGetwd != nil {
t.Fatalf("failed to get current working directory: %v", errGetwd)
}
tmpDir := t.TempDir()
if errChdir := os.Chdir(tmpDir); errChdir != nil {
t.Fatalf("failed to switch working directory: %v", errChdir)
}
defer func() {
if errChdirBack := os.Chdir(originalWD); errChdirBack != nil {
t.Fatalf("failed to restore working directory: %v", errChdirBack)
}
}()
// Force ResolveLogDirectory to fallback to auth-dir/logs by making ./logs not a writable directory.
if errWriteFile := os.WriteFile(filepath.Join(tmpDir, "logs"), []byte("not-a-directory"), 0o644); errWriteFile != nil {
t.Fatalf("failed to create blocking logs file: %v", errWriteFile)
}
configDir := filepath.Join(tmpDir, "config")
if errMkdirConfig := os.MkdirAll(configDir, 0o755); errMkdirConfig != nil {
t.Fatalf("failed to create config dir: %v", errMkdirConfig)
}
configPath := filepath.Join(configDir, "config.yaml")
authDir := filepath.Join(tmpDir, "auth")
if errMkdirAuth := os.MkdirAll(authDir, 0o700); errMkdirAuth != nil {
t.Fatalf("failed to create auth dir: %v", errMkdirAuth)
}
cfg := &proxyconfig.Config{
SDKConfig: proxyconfig.SDKConfig{
RequestLog: false,
},
AuthDir: authDir,
ErrorLogsMaxFiles: 10,
}
logger := defaultRequestLoggerFactory(cfg, configPath)
fileLogger, ok := logger.(*internallogging.FileRequestLogger)
if !ok {
t.Fatalf("expected *FileRequestLogger, got %T", logger)
}
errLog := fileLogger.LogRequestWithOptions(
"/v1/chat/completions",
http.MethodPost,
map[string][]string{"Content-Type": []string{"application/json"}},
[]byte(`{"input":"hello"}`),
http.StatusBadGateway,
map[string][]string{"Content-Type": []string{"application/json"}},
[]byte(`{"error":"upstream failure"}`),
nil,
nil,
nil,
true,
"issue-1711",
time.Now(),
time.Now(),
)
if errLog != nil {
t.Fatalf("failed to write forced error request log: %v", errLog)
}
authLogsDir := filepath.Join(authDir, "logs")
authEntries, errReadAuthDir := os.ReadDir(authLogsDir)
if errReadAuthDir != nil {
t.Fatalf("failed to read auth logs dir %s: %v", authLogsDir, errReadAuthDir)
}
foundErrorLogInAuthDir := false
for _, entry := range authEntries {
if strings.HasPrefix(entry.Name(), "error-") && strings.HasSuffix(entry.Name(), ".log") {
foundErrorLogInAuthDir = true
break
}
}
if !foundErrorLogInAuthDir {
t.Fatalf("expected forced error log in auth fallback dir %s, got entries: %+v", authLogsDir, authEntries)
}
configLogsDir := filepath.Join(configDir, "logs")
configEntries, errReadConfigDir := os.ReadDir(configLogsDir)
if errReadConfigDir != nil && !os.IsNotExist(errReadConfigDir) {
t.Fatalf("failed to inspect config logs dir %s: %v", configLogsDir, errReadConfigDir)
}
for _, entry := range configEntries {
if strings.HasPrefix(entry.Name(), "error-") && strings.HasSuffix(entry.Name(), ".log") {
t.Fatalf("unexpected forced error log in config dir %s", configLogsDir)
}
}
}

View File

@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Claude token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
// Encode and write the token data as JSON
if err = json.NewEncoder(f).Encode(ts); err != nil {
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -15,7 +15,7 @@ import (
"golang.org/x/net/proxy"
)
// utlsRoundTripper implements http.RoundTripper using utls with Firefox fingerprint
// utlsRoundTripper implements http.RoundTripper using utls with Chrome fingerprint
// to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
type utlsRoundTripper struct {
// mu protects the connections map and pending map
@@ -100,7 +100,9 @@ func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.Clie
return h2Conn, nil
}
// createConnection creates a new HTTP/2 connection with Firefox TLS fingerprint
// createConnection creates a new HTTP/2 connection with Chrome TLS fingerprint.
// Chrome's TLS fingerprint is closer to Node.js/OpenSSL (which real Claude Code uses)
// than Firefox, reducing the mismatch between TLS layer and HTTP headers.
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
conn, err := t.dialer.Dial("tcp", addr)
if err != nil {
@@ -108,7 +110,7 @@ func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientCon
}
tlsConfig := &tls.Config{ServerName: host}
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloFirefox_Auto)
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloChrome_Auto)
if err := tlsConn.Handshake(); err != nil {
conn.Close()
@@ -156,7 +158,7 @@ func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
}
// NewAnthropicHttpClient creates an HTTP client that bypasses TLS fingerprinting
// for Anthropic domains by using utls with Firefox fingerprint.
// for Anthropic domains by using utls with Chrome fingerprint.
// It accepts optional SDK configuration for proxy settings.
func NewAnthropicHttpClient(cfg *config.SDKConfig) *http.Client {
return &http.Client{

View File

@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
// authorization code and PKCE verifier.
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
}
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
// a caller-provided redirect URI. This supports alternate auth flows such as device
// login while preserving the existing token parsing and storage behavior.
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
if pkceCodes == nil {
return nil, fmt.Errorf("PKCE codes are required for token exchange")
}
if strings.TrimSpace(redirectURI) == "" {
return nil, fmt.Errorf("redirect URI is required for token exchange")
}
// Prepare token exchange request
data := url.Values{
"grant_type": {"authorization_code"},
"client_id": {ClientID},
"code": {code},
"redirect_uri": {RedirectURI},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"code_verifier": {pkceCodes.CodeVerifier},
}
@@ -266,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
if err == nil {
return tokenData, nil
}
if isNonRetryableRefreshErr(err) {
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
return nil, err
}
lastErr = err
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
@@ -274,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
}
func isNonRetryableRefreshErr(err error) bool {
if err == nil {
return false
}
raw := strings.ToLower(err.Error())
return strings.Contains(raw, "refresh_token_reused")
}
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
// This is typically called after a successful token refresh to persist the new credentials.
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {

View File

@@ -0,0 +1,44 @@
package codex
import (
"context"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
var calls int32
auth := &CodexAuth{
httpClient: &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
atomic.AddInt32(&calls, 1)
return &http.Response{
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
Header: make(http.Header),
Request: req,
}, nil
}),
},
}
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
if err == nil {
t.Fatalf("expected error for non-retryable refresh failure")
}
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected 1 refresh attempt, got %d", got)
}
}

View File

@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
Type string `json:"type"`
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Codex token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -82,15 +82,21 @@ func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *Devi
}
// Fetch the GitHub username
username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
userInfo, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
if err != nil {
log.Warnf("copilot: failed to fetch user info: %v", err)
username = "unknown"
}
username := userInfo.Login
if username == "" {
username = "github-user"
}
return &CopilotAuthBundle{
TokenData: tokenData,
Username: username,
Email: userInfo.Email,
Name: userInfo.Name,
}, nil
}
@@ -150,12 +156,12 @@ func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bo
return false, "", nil
}
username, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
userInfo, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
if err != nil {
return false, "", err
}
return true, username, nil
return true, userInfo.Login, nil
}
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
@@ -165,6 +171,8 @@ func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotToke
TokenType: bundle.TokenData.TokenType,
Scope: bundle.TokenData.Scope,
Username: bundle.Username,
Email: bundle.Email,
Name: bundle.Name,
Type: "github-copilot",
}
}

View File

@@ -53,7 +53,7 @@ func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
data := url.Values{}
data.Set("client_id", copilotClientID)
data.Set("scope", "user:email")
data.Set("scope", "read:user user:email")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
if err != nil {
@@ -211,15 +211,25 @@ func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode st
}, nil
}
// FetchUserInfo retrieves the GitHub username for the authenticated user.
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
// GitHubUserInfo holds GitHub user profile information.
type GitHubUserInfo struct {
// Login is the GitHub username.
Login string
// Email is the primary email address (may be empty if not public).
Email string
// Name is the display name.
Name string
}
// FetchUserInfo retrieves the GitHub user profile for the authenticated user.
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (GitHubUserInfo, error) {
if accessToken == "" {
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
if err != nil {
return "", NewAuthenticationError(ErrUserInfoFailed, err)
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
@@ -227,7 +237,7 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
resp, err := c.httpClient.Do(req)
if err != nil {
return "", NewAuthenticationError(ErrUserInfoFailed, err)
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
@@ -237,19 +247,25 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
if !isHTTPSuccess(resp.StatusCode) {
bodyBytes, _ := io.ReadAll(resp.Body)
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
}
var userInfo struct {
var raw struct {
Login string `json:"login"`
Email string `json:"email"`
Name string `json:"name"`
}
if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
return "", NewAuthenticationError(ErrUserInfoFailed, err)
if err = json.NewDecoder(resp.Body).Decode(&raw); err != nil {
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
if userInfo.Login == "" {
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
if raw.Login == "" {
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
}
return userInfo.Login, nil
return GitHubUserInfo{
Login: raw.Login,
Email: raw.Email,
Name: raw.Name,
}, nil
}

View File

@@ -0,0 +1,213 @@
package copilot
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// roundTripFunc lets us inject a custom transport for testing.
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
// newTestClient returns an *http.Client whose requests are redirected to the given test server,
// regardless of the original URL host.
func newTestClient(srv *httptest.Server) *http.Client {
return &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
req2 := req.Clone(req.Context())
req2.URL.Scheme = "http"
req2.URL.Host = strings.TrimPrefix(srv.URL, "http://")
return srv.Client().Transport.RoundTrip(req2)
}),
}
}
// TestFetchUserInfo_FullProfile verifies that FetchUserInfo returns login, email, and name.
func TestFetchUserInfo_FullProfile(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"login": "octocat",
"email": "octocat@github.com",
"name": "The Octocat",
})
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
info, err := client.FetchUserInfo(context.Background(), "test-token")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.Login != "octocat" {
t.Errorf("Login: got %q, want %q", info.Login, "octocat")
}
if info.Email != "octocat@github.com" {
t.Errorf("Email: got %q, want %q", info.Email, "octocat@github.com")
}
if info.Name != "The Octocat" {
t.Errorf("Name: got %q, want %q", info.Name, "The Octocat")
}
}
// TestFetchUserInfo_EmptyEmail verifies graceful handling when email is absent (private account).
func TestFetchUserInfo_EmptyEmail(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
// GitHub returns null for private emails.
_, _ = w.Write([]byte(`{"login":"privateuser","email":null,"name":"Private User"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
info, err := client.FetchUserInfo(context.Background(), "test-token")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.Login != "privateuser" {
t.Errorf("Login: got %q, want %q", info.Login, "privateuser")
}
if info.Email != "" {
t.Errorf("Email: got %q, want empty string", info.Email)
}
if info.Name != "Private User" {
t.Errorf("Name: got %q, want %q", info.Name, "Private User")
}
}
// TestFetchUserInfo_EmptyToken verifies error is returned for empty access token.
func TestFetchUserInfo_EmptyToken(t *testing.T) {
client := &DeviceFlowClient{httpClient: http.DefaultClient}
_, err := client.FetchUserInfo(context.Background(), "")
if err == nil {
t.Fatal("expected error for empty token, got nil")
}
}
// TestFetchUserInfo_EmptyLogin verifies error is returned when API returns no login.
func TestFetchUserInfo_EmptyLogin(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"email":"someone@example.com","name":"No Login"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
_, err := client.FetchUserInfo(context.Background(), "test-token")
if err == nil {
t.Fatal("expected error for empty login, got nil")
}
}
// TestFetchUserInfo_HTTPError verifies error is returned on non-2xx response.
func TestFetchUserInfo_HTTPError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
_, err := client.FetchUserInfo(context.Background(), "bad-token")
if err == nil {
t.Fatal("expected error for 401 response, got nil")
}
}
// TestCopilotTokenStorage_EmailNameFields verifies Email and Name serialise correctly.
func TestCopilotTokenStorage_EmailNameFields(t *testing.T) {
ts := &CopilotTokenStorage{
AccessToken: "ghu_abc",
TokenType: "bearer",
Scope: "read:user user:email",
Username: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
Type: "github-copilot",
}
data, err := json.Marshal(ts)
if err != nil {
t.Fatalf("marshal error: %v", err)
}
var out map[string]any
if err = json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
for _, key := range []string{"access_token", "username", "email", "name", "type"} {
if _, ok := out[key]; !ok {
t.Errorf("expected key %q in JSON output, not found", key)
}
}
if out["email"] != "octocat@github.com" {
t.Errorf("email: got %v, want %q", out["email"], "octocat@github.com")
}
if out["name"] != "The Octocat" {
t.Errorf("name: got %v, want %q", out["name"], "The Octocat")
}
}
// TestCopilotTokenStorage_OmitEmptyEmailName verifies email/name are omitted when empty (omitempty).
func TestCopilotTokenStorage_OmitEmptyEmailName(t *testing.T) {
ts := &CopilotTokenStorage{
AccessToken: "ghu_abc",
Username: "octocat",
Type: "github-copilot",
}
data, err := json.Marshal(ts)
if err != nil {
t.Fatalf("marshal error: %v", err)
}
var out map[string]any
if err = json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if _, ok := out["email"]; ok {
t.Error("email key should be omitted when empty (omitempty), but was present")
}
if _, ok := out["name"]; ok {
t.Error("name key should be omitted when empty (omitempty), but was present")
}
}
// TestCopilotAuthBundle_EmailNameFields verifies bundle carries email and name through the pipeline.
func TestCopilotAuthBundle_EmailNameFields(t *testing.T) {
bundle := &CopilotAuthBundle{
TokenData: &CopilotTokenData{AccessToken: "ghu_abc"},
Username: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
}
if bundle.Email != "octocat@github.com" {
t.Errorf("bundle.Email: got %q, want %q", bundle.Email, "octocat@github.com")
}
if bundle.Name != "The Octocat" {
t.Errorf("bundle.Name: got %q, want %q", bundle.Name, "The Octocat")
}
}
// TestGitHubUserInfo_Struct verifies the exported GitHubUserInfo struct fields are accessible.
func TestGitHubUserInfo_Struct(t *testing.T) {
info := GitHubUserInfo{
Login: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
}
if info.Login == "" || info.Email == "" || info.Name == "" {
t.Error("GitHubUserInfo fields should not be empty")
}
}

View File

@@ -26,6 +26,10 @@ type CopilotTokenStorage struct {
ExpiresAt string `json:"expires_at,omitempty"`
// Username is the GitHub username associated with this token.
Username string `json:"username"`
// Email is the GitHub email address associated with this token.
Email string `json:"email,omitempty"`
// Name is the GitHub display name associated with this token.
Name string `json:"name,omitempty"`
// Type indicates the authentication provider type, always "github-copilot" for this storage.
Type string `json:"type"`
}
@@ -46,6 +50,10 @@ type CopilotAuthBundle struct {
TokenData *CopilotTokenData
// Username is the GitHub username.
Username string
// Email is the GitHub email address.
Email string
// Name is the GitHub display name.
Name string
}
// DeviceCodeResponse represents GitHub's device code response.

View File

@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
// Type indicates the authentication provider type, always "gemini" for this storage.
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
ts.Type = "gemini"
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
return fmt.Errorf("failed to create directory: %v", err)
}
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
}
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
enc := json.NewEncoder(f)
enc.SetIndent("", " ")
if err := enc.Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
Scope string `json:"scope"`
Cookie string `json:"cookie"`
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serialises the token storage to disk.
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
}
defer func() { _ = f.Close() }()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("iflow token: encode token failed: %w", err)
}
return nil

View File

@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
Expired string `json:"expired,omitempty"`
// Type indicates the authentication provider type, always "kimi" for this storage.
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// KimiTokenData holds the raw OAuth token response from Kimi.
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
encoder := json.NewEncoder(f)
encoder.SetIndent("", " ")
if err = encoder.Encode(ts); err != nil {
if err = encoder.Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -7,10 +7,13 @@ import (
"encoding/json"
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
@@ -47,7 +50,7 @@ type KiroTokenData struct {
Email string `json:"email,omitempty"`
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
StartURL string `json:"startUrl,omitempty"`
// Region is the AWS region for IDC authentication (only for IDC auth method)
// Region is the OIDC region for IDC login and token refresh
Region string `json:"region,omitempty"`
}
@@ -520,3 +523,159 @@ func GenerateTokenFileName(tokenData *KiroTokenData) string {
// Priority 3: Fallback to authMethod only with sequence
return fmt.Sprintf("kiro-%s-%05d.json", authMethod, seq)
}
// DefaultKiroRegion is the fallback region when none is specified.
const DefaultKiroRegion = "us-east-1"
// GetCodeWhispererLegacyEndpoint returns the legacy CodeWhisperer JSON-RPC endpoint.
// This endpoint supports JSON-RPC style requests with x-amz-target headers.
// The Q endpoint (q.{region}.amazonaws.com) does NOT support JSON-RPC style.
func GetCodeWhispererLegacyEndpoint(region string) string {
if region == "" {
region = DefaultKiroRegion
}
return "https://codewhisperer." + region + ".amazonaws.com"
}
// ProfileARN represents a parsed AWS CodeWhisperer profile ARN.
// ARN format: arn:partition:service:region:account-id:resource-type/resource-id
// Example: arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL
type ProfileARN struct {
// Raw is the original ARN string
Raw string
// Partition is the AWS partition (aws)
Partition string
// Service is the AWS service name (codewhisperer)
Service string
// Region is the AWS region (us-east-1, ap-southeast-1, etc.)
Region string
// AccountID is the AWS account ID
AccountID string
// ResourceType is the resource type (profile)
ResourceType string
// ResourceID is the resource identifier (e.g., ABCDEFGHIJKL)
ResourceID string
}
// ParseProfileARN parses an AWS ARN string into a ProfileARN struct.
// Returns nil if the ARN is empty, invalid, or not a codewhisperer ARN.
func ParseProfileARN(arn string) *ProfileARN {
if arn == "" {
return nil
}
// ARN format: arn:partition:service:region:account-id:resource
// Minimum 6 parts separated by ":"
parts := strings.Split(arn, ":")
if len(parts) < 6 {
log.Warnf("invalid ARN format: %s", arn)
return nil
}
// Validate ARN prefix
if parts[0] != "arn" {
return nil
}
// Validate partition
partition := parts[1]
if partition == "" {
return nil
}
// Validate service is codewhisperer
service := parts[2]
if service != "codewhisperer" {
return nil
}
// Validate region format (must contain "-")
region := parts[3]
if region == "" || !strings.Contains(region, "-") {
return nil
}
// Account ID
accountID := parts[4]
// Parse resource (format: resource-type/resource-id)
// Join remaining parts in case resource contains ":"
resource := strings.Join(parts[5:], ":")
resourceType := ""
resourceID := ""
if idx := strings.Index(resource, "/"); idx > 0 {
resourceType = resource[:idx]
resourceID = resource[idx+1:]
} else {
resourceType = resource
}
return &ProfileARN{
Raw: arn,
Partition: partition,
Service: service,
Region: region,
AccountID: accountID,
ResourceType: resourceType,
ResourceID: resourceID,
}
}
// GetKiroAPIEndpoint returns the Q API endpoint for the specified region.
// If region is empty, defaults to us-east-1.
func GetKiroAPIEndpoint(region string) string {
if region == "" {
region = DefaultKiroRegion
}
return "https://q." + region + ".amazonaws.com"
}
// GetKiroAPIEndpointFromProfileArn extracts region from profileArn and returns the endpoint.
// Returns default us-east-1 endpoint if region cannot be extracted.
func GetKiroAPIEndpointFromProfileArn(profileArn string) string {
region := ExtractRegionFromProfileArn(profileArn)
return GetKiroAPIEndpoint(region)
}
// ExtractRegionFromProfileArn extracts the AWS region from a ProfileARN string.
// Returns empty string if ARN is invalid or region cannot be extracted.
func ExtractRegionFromProfileArn(profileArn string) string {
parsed := ParseProfileARN(profileArn)
if parsed == nil {
return ""
}
return parsed.Region
}
// ExtractRegionFromMetadata extracts API region from auth metadata.
// Priority: api_region > profile_arn > DefaultKiroRegion
func ExtractRegionFromMetadata(metadata map[string]interface{}) string {
if metadata == nil {
return DefaultKiroRegion
}
// Priority 1: Explicit api_region override
if r, ok := metadata["api_region"].(string); ok && r != "" {
return r
}
// Priority 2: Extract from ProfileARN
if profileArn, ok := metadata["profile_arn"].(string); ok && profileArn != "" {
if region := ExtractRegionFromProfileArn(profileArn); region != "" {
return region
}
}
return DefaultKiroRegion
}
func buildURL(endpoint, path string, queryParams map[string]string) string {
fullURL := fmt.Sprintf("%s/%s", endpoint, path)
if len(queryParams) > 0 {
values := url.Values{}
for key, value := range queryParams {
if value == "" {
continue
}
values.Set(key, value)
}
if encoded := values.Encode(); encoded != "" {
fullURL = fullURL + "?" + encoded
}
}
return fullURL
}

View File

@@ -19,15 +19,8 @@ import (
)
const (
// awsKiroEndpoint is used for CodeWhisperer management APIs (GetUsageLimits, ListProfiles, etc.)
// Note: This is different from the Amazon Q streaming endpoint (q.us-east-1.amazonaws.com)
// used in kiro_executor.go for GenerateAssistantResponse. Both endpoints are correct
// for their respective API operations.
awsKiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com"
defaultTokenFile = "~/.aws/sso/cache/kiro-auth-token.json"
targetGetUsage = "AmazonCodeWhispererService.GetUsageLimits"
targetListModels = "AmazonCodeWhispererService.ListAvailableModels"
targetGenerateChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse"
pathGetUsageLimits = "getUsageLimits"
pathListAvailableModels = "ListAvailableModels"
)
// KiroAuth handles AWS CodeWhisperer authentication and API communication.
@@ -35,7 +28,6 @@ const (
// and communicating with the CodeWhisperer API.
type KiroAuth struct {
httpClient *http.Client
endpoint string
}
// NewKiroAuth creates a new Kiro authentication service.
@@ -49,7 +41,6 @@ type KiroAuth struct {
func NewKiroAuth(cfg *config.Config) *KiroAuth {
return &KiroAuth{
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}),
endpoint: awsKiroEndpoint,
}
}
@@ -110,33 +101,30 @@ func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool {
return time.Now().After(expiresAt)
}
// makeRequest sends a request to the CodeWhisperer API.
// This is an internal method for making authenticated API calls.
// makeRequest sends a REST-style GET request to the CodeWhisperer API.
//
// Parameters:
// - ctx: The context for the request
// - target: The API target (e.g., "AmazonCodeWhispererService.GetUsageLimits")
// - accessToken: The OAuth access token
// - payload: The request payload
// - path: The API path (e.g., "getUsageLimits")
// - tokenData: The token data containing access token, refresh token, and profile ARN
// - queryParams: Query parameters to add to the URL
//
// Returns:
// - []byte: The response body
// - error: An error if the request fails
func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken string, payload interface{}) ([]byte, error) {
jsonBody, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
func (k *KiroAuth) makeRequest(ctx context.Context, path string, tokenData *KiroTokenData, queryParams map[string]string) ([]byte, error) {
// Get endpoint from profileArn (defaults to us-east-1 if empty)
profileArn := queryParams["profileArn"]
endpoint := GetKiroAPIEndpointFromProfileArn(profileArn)
url := buildURL(endpoint, path, queryParams)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, k.endpoint, strings.NewReader(string(jsonBody)))
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
req.Header.Set("x-amz-target", target)
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken)
setRuntimeHeaders(req, tokenData.AccessToken, accountKey)
resp, err := k.httpClient.Do(req)
if err != nil {
@@ -171,13 +159,13 @@ func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken s
// - *KiroUsageInfo: The usage information
// - error: An error if the request fails
func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) {
payload := map[string]interface{}{
queryParams := map[string]string{
"origin": "AI_EDITOR",
"profileArn": tokenData.ProfileArn,
"resourceType": "AGENTIC_REQUEST",
}
body, err := k.makeRequest(ctx, targetGetUsage, tokenData.AccessToken, payload)
body, err := k.makeRequest(ctx, pathGetUsageLimits, tokenData, queryParams)
if err != nil {
return nil, err
}
@@ -221,12 +209,12 @@ func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData)
// - []*KiroModel: The list of available models
// - error: An error if the request fails
func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) {
payload := map[string]interface{}{
queryParams := map[string]string{
"origin": "AI_EDITOR",
"profileArn": tokenData.ProfileArn,
}
body, err := k.makeRequest(ctx, targetListModels, tokenData.AccessToken, payload)
body, err := k.makeRequest(ctx, pathListAvailableModels, tokenData, queryParams)
if err != nil {
return nil, err
}

View File

@@ -3,6 +3,7 @@ package kiro
import (
"encoding/base64"
"encoding/json"
"strings"
"testing"
)
@@ -217,7 +218,8 @@ func TestGenerateTokenFileName(t *testing.T) {
tests := []struct {
name string
tokenData *KiroTokenData
expected string
exact string // exact match (for cases with email)
prefix string // prefix match (for cases without email, where sequence is appended)
}{
{
name: "IDC with email",
@@ -226,7 +228,7 @@ func TestGenerateTokenFileName(t *testing.T) {
Email: "user@example.com",
StartURL: "https://d-1234567890.awsapps.com/start",
},
expected: "kiro-idc-user-example-com.json",
exact: "kiro-idc-user-example-com.json",
},
{
name: "IDC without email but with startUrl",
@@ -235,7 +237,7 @@ func TestGenerateTokenFileName(t *testing.T) {
Email: "",
StartURL: "https://d-1234567890.awsapps.com/start",
},
expected: "kiro-idc-d-1234567890.json",
prefix: "kiro-idc-d-1234567890-",
},
{
name: "IDC with company name in startUrl",
@@ -244,7 +246,7 @@ func TestGenerateTokenFileName(t *testing.T) {
Email: "",
StartURL: "https://my-company.awsapps.com/start",
},
expected: "kiro-idc-my-company.json",
prefix: "kiro-idc-my-company-",
},
{
name: "IDC without email and without startUrl",
@@ -253,7 +255,7 @@ func TestGenerateTokenFileName(t *testing.T) {
Email: "",
StartURL: "",
},
expected: "kiro-idc.json",
prefix: "kiro-idc-",
},
{
name: "Builder ID with email",
@@ -262,7 +264,7 @@ func TestGenerateTokenFileName(t *testing.T) {
Email: "user@gmail.com",
StartURL: "https://view.awsapps.com/start",
},
expected: "kiro-builder-id-user-gmail-com.json",
exact: "kiro-builder-id-user-gmail-com.json",
},
{
name: "Builder ID without email",
@@ -271,7 +273,7 @@ func TestGenerateTokenFileName(t *testing.T) {
Email: "",
StartURL: "https://view.awsapps.com/start",
},
expected: "kiro-builder-id.json",
prefix: "kiro-builder-id-",
},
{
name: "Social auth with email",
@@ -279,7 +281,7 @@ func TestGenerateTokenFileName(t *testing.T) {
AuthMethod: "google",
Email: "user@gmail.com",
},
expected: "kiro-google-user-gmail-com.json",
exact: "kiro-google-user-gmail-com.json",
},
{
name: "Empty auth method",
@@ -287,7 +289,7 @@ func TestGenerateTokenFileName(t *testing.T) {
AuthMethod: "",
Email: "",
},
expected: "kiro-unknown.json",
prefix: "kiro-unknown-",
},
{
name: "Email with special characters",
@@ -296,16 +298,454 @@ func TestGenerateTokenFileName(t *testing.T) {
Email: "user.name+tag@sub.example.com",
StartURL: "https://d-1234567890.awsapps.com/start",
},
expected: "kiro-idc-user-name+tag-sub-example-com.json",
exact: "kiro-idc-user-name+tag-sub-example-com.json",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenerateTokenFileName(tt.tokenData)
if result != tt.expected {
t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.expected)
if tt.exact != "" {
if result != tt.exact {
t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.exact)
}
} else if tt.prefix != "" {
if !strings.HasPrefix(result, tt.prefix) || !strings.HasSuffix(result, ".json") {
t.Errorf("GenerateTokenFileName() = %q, want prefix %q with .json suffix", result, tt.prefix)
}
}
})
}
}
func TestParseProfileARN(t *testing.T) {
tests := []struct {
name string
arn string
expected *ProfileARN
}{
{
name: "Empty ARN",
arn: "",
expected: nil,
},
{
name: "Invalid format - too few parts",
arn: "arn:aws:codewhisperer",
expected: nil,
},
{
name: "Invalid prefix - not arn",
arn: "notarn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
expected: nil,
},
{
name: "Invalid service - not codewhisperer",
arn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket",
expected: nil,
},
{
name: "Invalid region - no hyphen",
arn: "arn:aws:codewhisperer:useast1:123456789012:profile/ABC",
expected: nil,
},
{
name: "Empty partition",
arn: "arn::codewhisperer:us-east-1:123456789012:profile/ABC",
expected: nil,
},
{
name: "Empty region",
arn: "arn:aws:codewhisperer::123456789012:profile/ABC",
expected: nil,
},
{
name: "Valid ARN - us-east-1",
arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL",
Partition: "aws",
Service: "codewhisperer",
Region: "us-east-1",
AccountID: "123456789012",
ResourceType: "profile",
ResourceID: "ABCDEFGHIJKL",
},
},
{
name: "Valid ARN - ap-southeast-1",
arn: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ",
Partition: "aws",
Service: "codewhisperer",
Region: "ap-southeast-1",
AccountID: "987654321098",
ResourceType: "profile",
ResourceID: "ZYXWVUTSRQ",
},
},
{
name: "Valid ARN - eu-west-1",
arn: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123",
Partition: "aws",
Service: "codewhisperer",
Region: "eu-west-1",
AccountID: "111222333444",
ResourceType: "profile",
ResourceID: "PROFILE123",
},
},
{
name: "Valid ARN - aws-cn partition",
arn: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID",
expected: &ProfileARN{
Raw: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID",
Partition: "aws-cn",
Service: "codewhisperer",
Region: "cn-north-1",
AccountID: "123456789012",
ResourceType: "profile",
ResourceID: "CHINAID",
},
},
{
name: "Valid ARN - resource without slash",
arn: "arn:aws:codewhisperer:us-west-2:123456789012:profile",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:us-west-2:123456789012:profile",
Partition: "aws",
Service: "codewhisperer",
Region: "us-west-2",
AccountID: "123456789012",
ResourceType: "profile",
ResourceID: "",
},
},
{
name: "Valid ARN - resource with colon",
arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra",
Partition: "aws",
Service: "codewhisperer",
Region: "us-east-1",
AccountID: "123456789012",
ResourceType: "profile",
ResourceID: "ABC:extra",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ParseProfileARN(tt.arn)
if tt.expected == nil {
if result != nil {
t.Errorf("ParseProfileARN(%q) = %+v, want nil", tt.arn, result)
}
return
}
if result == nil {
t.Errorf("ParseProfileARN(%q) = nil, want %+v", tt.arn, tt.expected)
return
}
if result.Raw != tt.expected.Raw {
t.Errorf("Raw = %q, want %q", result.Raw, tt.expected.Raw)
}
if result.Partition != tt.expected.Partition {
t.Errorf("Partition = %q, want %q", result.Partition, tt.expected.Partition)
}
if result.Service != tt.expected.Service {
t.Errorf("Service = %q, want %q", result.Service, tt.expected.Service)
}
if result.Region != tt.expected.Region {
t.Errorf("Region = %q, want %q", result.Region, tt.expected.Region)
}
if result.AccountID != tt.expected.AccountID {
t.Errorf("AccountID = %q, want %q", result.AccountID, tt.expected.AccountID)
}
if result.ResourceType != tt.expected.ResourceType {
t.Errorf("ResourceType = %q, want %q", result.ResourceType, tt.expected.ResourceType)
}
if result.ResourceID != tt.expected.ResourceID {
t.Errorf("ResourceID = %q, want %q", result.ResourceID, tt.expected.ResourceID)
}
})
}
}
func TestExtractRegionFromProfileArn(t *testing.T) {
tests := []struct {
name string
profileArn string
expected string
}{
{
name: "Empty ARN",
profileArn: "",
expected: "",
},
{
name: "Invalid ARN",
profileArn: "invalid-arn",
expected: "",
},
{
name: "Valid ARN - us-east-1",
profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
expected: "us-east-1",
},
{
name: "Valid ARN - ap-southeast-1",
profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
expected: "ap-southeast-1",
},
{
name: "Valid ARN - eu-central-1",
profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
expected: "eu-central-1",
},
{
name: "Non-codewhisperer ARN",
profileArn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ExtractRegionFromProfileArn(tt.profileArn)
if result != tt.expected {
t.Errorf("ExtractRegionFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected)
}
})
}
}
func TestGetKiroAPIEndpoint(t *testing.T) {
tests := []struct {
name string
region string
expected string
}{
{
name: "Empty region - defaults to us-east-1",
region: "",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "us-east-1",
region: "us-east-1",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "us-west-2",
region: "us-west-2",
expected: "https://q.us-west-2.amazonaws.com",
},
{
name: "ap-southeast-1",
region: "ap-southeast-1",
expected: "https://q.ap-southeast-1.amazonaws.com",
},
{
name: "eu-west-1",
region: "eu-west-1",
expected: "https://q.eu-west-1.amazonaws.com",
},
{
name: "cn-north-1",
region: "cn-north-1",
expected: "https://q.cn-north-1.amazonaws.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetKiroAPIEndpoint(tt.region)
if result != tt.expected {
t.Errorf("GetKiroAPIEndpoint(%q) = %q, want %q", tt.region, result, tt.expected)
}
})
}
}
func TestGetKiroAPIEndpointFromProfileArn(t *testing.T) {
tests := []struct {
name string
profileArn string
expected string
}{
{
name: "Empty ARN - defaults to us-east-1",
profileArn: "",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "Invalid ARN - defaults to us-east-1",
profileArn: "invalid-arn",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "Valid ARN - us-east-1",
profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "Valid ARN - ap-southeast-1",
profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
expected: "https://q.ap-southeast-1.amazonaws.com",
},
{
name: "Valid ARN - eu-central-1",
profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
expected: "https://q.eu-central-1.amazonaws.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetKiroAPIEndpointFromProfileArn(tt.profileArn)
if result != tt.expected {
t.Errorf("GetKiroAPIEndpointFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected)
}
})
}
}
func TestGetCodeWhispererLegacyEndpoint(t *testing.T) {
tests := []struct {
name string
region string
expected string
}{
{
name: "Empty region - defaults to us-east-1",
region: "",
expected: "https://codewhisperer.us-east-1.amazonaws.com",
},
{
name: "us-east-1",
region: "us-east-1",
expected: "https://codewhisperer.us-east-1.amazonaws.com",
},
{
name: "us-west-2",
region: "us-west-2",
expected: "https://codewhisperer.us-west-2.amazonaws.com",
},
{
name: "ap-northeast-1",
region: "ap-northeast-1",
expected: "https://codewhisperer.ap-northeast-1.amazonaws.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetCodeWhispererLegacyEndpoint(tt.region)
if result != tt.expected {
t.Errorf("GetCodeWhispererLegacyEndpoint(%q) = %q, want %q", tt.region, result, tt.expected)
}
})
}
}
func TestExtractRegionFromMetadata(t *testing.T) {
tests := []struct {
name string
metadata map[string]interface{}
expected string
}{
{
name: "Nil metadata - defaults to us-east-1",
metadata: nil,
expected: "us-east-1",
},
{
name: "Empty metadata - defaults to us-east-1",
metadata: map[string]interface{}{},
expected: "us-east-1",
},
{
name: "Priority 1: api_region override",
metadata: map[string]interface{}{
"api_region": "eu-west-1",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
},
expected: "eu-west-1",
},
{
name: "Priority 2: profile_arn when api_region is empty",
metadata: map[string]interface{}{
"api_region": "",
"profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
},
expected: "ap-southeast-1",
},
{
name: "Priority 2: profile_arn when api_region is missing",
metadata: map[string]interface{}{
"profile_arn": "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
},
expected: "eu-central-1",
},
{
name: "Fallback: default when profile_arn is invalid",
metadata: map[string]interface{}{
"profile_arn": "invalid-arn",
},
expected: "us-east-1",
},
{
name: "Fallback: default when profile_arn is empty",
metadata: map[string]interface{}{
"profile_arn": "",
},
expected: "us-east-1",
},
{
name: "OIDC region is NOT used for API region",
metadata: map[string]interface{}{
"region": "ap-northeast-2", // OIDC region - should be ignored
},
expected: "us-east-1",
},
{
name: "api_region takes precedence over OIDC region",
metadata: map[string]interface{}{
"api_region": "us-west-2",
"region": "ap-northeast-2", // OIDC region - should be ignored
},
expected: "us-west-2",
},
{
name: "Non-string api_region is ignored",
metadata: map[string]interface{}{
"api_region": 123, // wrong type
"profile_arn": "arn:aws:codewhisperer:ap-south-1:123456789012:profile/ABC",
},
expected: "ap-south-1",
},
{
name: "Non-string profile_arn is ignored",
metadata: map[string]interface{}{
"profile_arn": 123, // wrong type
},
expected: "us-east-1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ExtractRegionFromMetadata(tt.metadata)
if result != tt.expected {
t.Errorf("ExtractRegionFromMetadata(%v) = %q, want %q", tt.metadata, result, tt.expected)
}
})
}
}

View File

@@ -9,30 +9,23 @@ import (
"net/http"
"time"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
codeWhispererAPI = "https://codewhisperer.us-east-1.amazonaws.com"
kiroVersion = "0.6.18"
)
// CodeWhispererClient handles CodeWhisperer API calls.
type CodeWhispererClient struct {
httpClient *http.Client
machineID string
}
// UsageLimitsResponse represents the getUsageLimits API response.
type UsageLimitsResponse struct {
DaysUntilReset *int `json:"daysUntilReset,omitempty"`
NextDateReset *float64 `json:"nextDateReset,omitempty"`
UserInfo *UserInfo `json:"userInfo,omitempty"`
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"`
DaysUntilReset *int `json:"daysUntilReset,omitempty"`
NextDateReset *float64 `json:"nextDateReset,omitempty"`
UserInfo *UserInfo `json:"userInfo,omitempty"`
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"`
}
// UserInfo contains user information from the API.
@@ -49,13 +42,13 @@ type SubscriptionInfo struct {
// UsageBreakdown contains usage details.
type UsageBreakdown struct {
UsageLimit *int `json:"usageLimit,omitempty"`
CurrentUsage *int `json:"currentUsage,omitempty"`
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"`
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"`
NextDateReset *float64 `json:"nextDateReset,omitempty"`
DisplayName string `json:"displayName,omitempty"`
ResourceType string `json:"resourceType,omitempty"`
UsageLimit *int `json:"usageLimit,omitempty"`
CurrentUsage *int `json:"currentUsage,omitempty"`
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"`
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"`
NextDateReset *float64 `json:"nextDateReset,omitempty"`
DisplayName string `json:"displayName,omitempty"`
ResourceType string `json:"resourceType,omitempty"`
}
// NewCodeWhispererClient creates a new CodeWhisperer client.
@@ -64,40 +57,34 @@ func NewCodeWhispererClient(cfg *config.Config, machineID string) *CodeWhisperer
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
if machineID == "" {
machineID = uuid.New().String()
}
return &CodeWhispererClient{
httpClient: client,
machineID: machineID,
}
}
// generateInvocationID generates a unique invocation ID.
func generateInvocationID() string {
return uuid.New().String()
}
// GetUsageLimits fetches usage limits and user info from CodeWhisperer API.
// This is the recommended way to get user email after login.
func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken string) (*UsageLimitsResponse, error) {
url := fmt.Sprintf("%s/getUsageLimits?isEmailRequired=true&origin=AI_EDITOR&resourceType=AGENTIC_REQUEST", codeWhispererAPI)
func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken, clientID, refreshToken, profileArn string) (*UsageLimitsResponse, error) {
queryParams := map[string]string{
"origin": "AI_EDITOR",
"resourceType": "AGENTIC_REQUEST",
}
// Determine endpoint based on profileArn region
endpoint := GetKiroAPIEndpointFromProfileArn(profileArn)
if profileArn != "" {
queryParams["profileArn"] = profileArn
} else {
queryParams["isEmailRequired"] = "true"
}
url := buildURL(endpoint, pathGetUsageLimits, queryParams)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Set headers to match Kiro IDE
xAmzUserAgent := fmt.Sprintf("aws-sdk-js/1.0.0 KiroIDE-%s-%s", kiroVersion, c.machineID)
userAgent := fmt.Sprintf("aws-sdk-js/1.0.0 ua/2.1 os/windows lang/js md/nodejs#20.16.0 api/codewhispererruntime#1.0.0 m/E KiroIDE-%s-%s", kiroVersion, c.machineID)
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("x-amz-user-agent", xAmzUserAgent)
req.Header.Set("User-Agent", userAgent)
req.Header.Set("amz-sdk-invocation-id", generateInvocationID())
req.Header.Set("amz-sdk-request", "attempt=1; max=1")
req.Header.Set("Connection", "close")
accountKey := GetAccountKey(clientID, refreshToken)
setRuntimeHeaders(req, accessToken, accountKey)
log.Debugf("codewhisperer: GET %s", url)
@@ -128,8 +115,8 @@ func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken st
// FetchUserEmailFromAPI fetches user email using CodeWhisperer getUsageLimits API.
// This is more reliable than JWT parsing as it uses the official API.
func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken string) string {
resp, err := c.GetUsageLimits(ctx, accessToken)
func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken, clientID, refreshToken string) string {
resp, err := c.GetUsageLimits(ctx, accessToken, clientID, refreshToken, "")
if err != nil {
log.Debugf("codewhisperer: failed to get usage limits: %v", err)
return ""
@@ -146,10 +133,10 @@ func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessT
// FetchUserEmailWithFallback fetches user email with multiple fallback methods.
// Priority: 1. CodeWhisperer API 2. userinfo endpoint 3. JWT parsing
func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken string) string {
func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken, clientID, refreshToken string) string {
// Method 1: Try CodeWhisperer API (most reliable)
cwClient := NewCodeWhispererClient(cfg, "")
email := cwClient.FetchUserEmailFromAPI(ctx, accessToken)
email := cwClient.FetchUserEmailFromAPI(ctx, accessToken, clientID, refreshToken)
if email != "" {
return email
}

View File

@@ -2,77 +2,105 @@ package kiro
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"math/rand"
"net/http"
"runtime"
"slices"
"sync"
"time"
"github.com/google/uuid"
)
// Fingerprint 多维度指纹信息
// Fingerprint holds multi-dimensional fingerprint data for runtime request disguise.
type Fingerprint struct {
SDKVersion string // 1.0.20-1.0.27
OIDCSDKVersion string // 3.7xx (AWS SDK JS)
RuntimeSDKVersion string // 1.0.x (runtime API)
StreamingSDKVersion string // 1.0.x (streaming API)
OSType string // darwin/windows/linux
OSVersion string // 10.0.22621
NodeVersion string // 18.x/20.x/22.x
KiroVersion string // 0.3.x-0.8.x
OSVersion string
NodeVersion string
KiroVersion string
KiroHash string // SHA256
AcceptLanguage string
ScreenResolution string // 1920x1080
ColorDepth int // 24
HardwareConcurrency int // CPU 核心数
TimezoneOffset int
}
// FingerprintManager 指纹管理器
// FingerprintConfig holds external fingerprint overrides.
type FingerprintConfig struct {
OIDCSDKVersion string
RuntimeSDKVersion string
StreamingSDKVersion string
OSType string
OSVersion string
NodeVersion string
KiroVersion string
KiroHash string
}
// FingerprintManager manages per-account fingerprint generation and caching.
type FingerprintManager struct {
mu sync.RWMutex
fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
rng *rand.Rand
config *FingerprintConfig // External config (Optional)
}
var (
sdkVersions = []string{
"1.0.20", "1.0.21", "1.0.22", "1.0.23",
"1.0.24", "1.0.25", "1.0.26", "1.0.27",
// SDK versions
oidcSDKVersions = []string{
"3.980.0", "3.975.0", "3.972.0", "3.808.0",
"3.738.0", "3.737.0", "3.736.0", "3.735.0",
}
// SDKVersions for getUsageLimits/ListAvailableModels/GetProfile (runtime API)
runtimeSDKVersions = []string{"1.0.0"}
// SDKVersions for generateAssistantResponse (streaming API)
streamingSDKVersions = []string{"1.0.27"}
// Valid OS types
osTypes = []string{"darwin", "windows", "linux"}
// OS versions
osVersions = map[string][]string{
"darwin": {"14.0", "14.1", "14.2", "14.3", "14.4", "14.5", "15.0", "15.1"},
"windows": {"10.0.19041", "10.0.19042", "10.0.19043", "10.0.19044", "10.0.22621", "10.0.22631"},
"linux": {"5.15.0", "6.1.0", "6.2.0", "6.5.0", "6.6.0", "6.8.0"},
"darwin": {"25.2.0", "25.1.0", "25.0.0", "24.5.0", "24.4.0", "24.3.0"},
"windows": {"10.0.26200", "10.0.26100", "10.0.22631", "10.0.22621", "10.0.19045"},
"linux": {"6.12.0", "6.11.0", "6.8.0", "6.6.0", "6.5.0", "6.1.0"},
}
// Node versions
nodeVersions = []string{
"18.17.0", "18.18.0", "18.19.0", "18.20.0",
"20.9.0", "20.10.0", "20.11.0", "20.12.0", "20.13.0",
"22.0.0", "22.1.0", "22.2.0", "22.3.0",
"22.21.1", "22.21.0", "22.20.0", "22.19.0", "22.18.0",
"20.18.0", "20.17.0", "20.16.0",
}
// Kiro IDE versions
kiroVersions = []string{
"0.3.0", "0.3.1", "0.4.0", "0.4.1", "0.5.0", "0.5.1",
"0.6.0", "0.6.1", "0.7.0", "0.7.1", "0.8.0", "0.8.1",
"0.10.32", "0.10.16", "0.10.10",
"0.9.47", "0.9.40", "0.9.2",
"0.8.206", "0.8.140", "0.8.135", "0.8.86",
}
acceptLanguages = []string{
"en-US,en;q=0.9",
"en-GB,en;q=0.9",
"zh-CN,zh;q=0.9,en;q=0.8",
"zh-TW,zh;q=0.9,en;q=0.8",
"ja-JP,ja;q=0.9,en;q=0.8",
"ko-KR,ko;q=0.9,en;q=0.8",
"de-DE,de;q=0.9,en;q=0.8",
"fr-FR,fr;q=0.9,en;q=0.8",
}
screenResolutions = []string{
"1920x1080", "2560x1440", "3840x2160",
"1366x768", "1440x900", "1680x1050",
"2560x1600", "3440x1440",
}
colorDepths = []int{24, 32}
hardwareConcurrencies = []int{4, 6, 8, 10, 12, 16, 20, 24, 32}
timezoneOffsets = []int{-480, -420, -360, -300, -240, 0, 60, 120, 480, 540}
// Global singleton
globalFingerprintManager *FingerprintManager
globalFingerprintManagerOnce sync.Once
)
// NewFingerprintManager 创建指纹管理器
func GlobalFingerprintManager() *FingerprintManager {
globalFingerprintManagerOnce.Do(func() {
globalFingerprintManager = NewFingerprintManager()
})
return globalFingerprintManager
}
func SetGlobalFingerprintConfig(cfg *FingerprintConfig) {
GlobalFingerprintManager().SetConfig(cfg)
}
// SetConfig applies the config and clears the fingerprint cache.
func (fm *FingerprintManager) SetConfig(cfg *FingerprintConfig) {
fm.mu.Lock()
defer fm.mu.Unlock()
fm.config = cfg
// Clear cached fingerprints so they regenerate with the new config
fm.fingerprints = make(map[string]*Fingerprint)
}
func NewFingerprintManager() *FingerprintManager {
return &FingerprintManager{
fingerprints: make(map[string]*Fingerprint),
@@ -80,7 +108,7 @@ func NewFingerprintManager() *FingerprintManager {
}
}
// GetFingerprint 获取或生成 Token 关联的指纹
// GetFingerprint returns the fingerprint for tokenKey, creating one if it doesn't exist.
func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
fm.mu.RLock()
if fp, exists := fm.fingerprints[tokenKey]; exists {
@@ -101,97 +129,150 @@ func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
return fp
}
// generateFingerprint 生成新的指纹
func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
osType := fm.randomChoice(osTypes)
osVersion := fm.randomChoice(osVersions[osType])
kiroVersion := fm.randomChoice(kiroVersions)
if fm.config != nil {
return fm.generateFromConfig(tokenKey)
}
return fm.generateRandom(tokenKey)
}
fp := &Fingerprint{
SDKVersion: fm.randomChoice(sdkVersions),
OSType: osType,
OSVersion: osVersion,
NodeVersion: fm.randomChoice(nodeVersions),
KiroVersion: kiroVersion,
AcceptLanguage: fm.randomChoice(acceptLanguages),
ScreenResolution: fm.randomChoice(screenResolutions),
ColorDepth: fm.randomIntChoice(colorDepths),
HardwareConcurrency: fm.randomIntChoice(hardwareConcurrencies),
TimezoneOffset: fm.randomIntChoice(timezoneOffsets),
// generateFromConfig uses config values, falling back to random for empty fields.
func (fm *FingerprintManager) generateFromConfig(tokenKey string) *Fingerprint {
cfg := fm.config
// Helper: config value or random selection
configOrRandom := func(configVal string, choices []string) string {
if configVal != "" {
return configVal
}
return choices[fm.rng.Intn(len(choices))]
}
fp.KiroHash = fm.generateKiroHash(tokenKey, kiroVersion, osType)
return fp
osType := cfg.OSType
if osType == "" {
osType = runtime.GOOS
if !slices.Contains(osTypes, osType) {
osType = osTypes[fm.rng.Intn(len(osTypes))]
}
}
osVersion := cfg.OSVersion
if osVersion == "" {
if versions, ok := osVersions[osType]; ok {
osVersion = versions[fm.rng.Intn(len(versions))]
}
}
kiroHash := cfg.KiroHash
if kiroHash == "" {
hash := sha256.Sum256([]byte(tokenKey))
kiroHash = hex.EncodeToString(hash[:])
}
return &Fingerprint{
OIDCSDKVersion: configOrRandom(cfg.OIDCSDKVersion, oidcSDKVersions),
RuntimeSDKVersion: configOrRandom(cfg.RuntimeSDKVersion, runtimeSDKVersions),
StreamingSDKVersion: configOrRandom(cfg.StreamingSDKVersion, streamingSDKVersions),
OSType: osType,
OSVersion: osVersion,
NodeVersion: configOrRandom(cfg.NodeVersion, nodeVersions),
KiroVersion: configOrRandom(cfg.KiroVersion, kiroVersions),
KiroHash: kiroHash,
}
}
// generateKiroHash 生成 Kiro Hash
func (fm *FingerprintManager) generateKiroHash(tokenKey, kiroVersion, osType string) string {
data := fmt.Sprintf("%s:%s:%s:%d", tokenKey, kiroVersion, osType, time.Now().UnixNano())
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
// generateRandom generates a deterministic fingerprint seeded by accountKey hash.
func (fm *FingerprintManager) generateRandom(accountKey string) *Fingerprint {
// Use accountKey hash as seed for deterministic random selection
hash := sha256.Sum256([]byte(accountKey))
seed := int64(binary.BigEndian.Uint64(hash[:8]))
rng := rand.New(rand.NewSource(seed))
osType := runtime.GOOS
if !slices.Contains(osTypes, osType) {
osType = osTypes[rng.Intn(len(osTypes))]
}
osVersion := osVersions[osType][rng.Intn(len(osVersions[osType]))]
return &Fingerprint{
OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
OSType: osType,
OSVersion: osVersion,
NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
KiroHash: hex.EncodeToString(hash[:]),
}
}
// randomChoice 随机选择字符串
func (fm *FingerprintManager) randomChoice(choices []string) string {
return choices[fm.rng.Intn(len(choices))]
// GenerateAccountKey returns a 16-char hex key derived from SHA256(seed).
func GenerateAccountKey(seed string) string {
hash := sha256.Sum256([]byte(seed))
return hex.EncodeToString(hash[:8])
}
// randomIntChoice 随机选择整数
func (fm *FingerprintManager) randomIntChoice(choices []int) int {
return choices[fm.rng.Intn(len(choices))]
// GetAccountKey derives an account key from clientID > refreshToken > random UUID.
func GetAccountKey(clientID, refreshToken string) string {
// 1. Prefer ClientID
if clientID != "" {
return GenerateAccountKey(clientID)
}
// 2. Fallback to RefreshToken
if refreshToken != "" {
return GenerateAccountKey(refreshToken)
}
// 3. Random fallback
return GenerateAccountKey(uuid.New().String())
}
// ApplyToRequest 将指纹信息应用到 HTTP 请求头
func (fp *Fingerprint) ApplyToRequest(req *http.Request) {
req.Header.Set("X-Kiro-SDK-Version", fp.SDKVersion)
req.Header.Set("X-Kiro-OS-Type", fp.OSType)
req.Header.Set("X-Kiro-OS-Version", fp.OSVersion)
req.Header.Set("X-Kiro-Node-Version", fp.NodeVersion)
req.Header.Set("X-Kiro-Version", fp.KiroVersion)
req.Header.Set("X-Kiro-Hash", fp.KiroHash)
req.Header.Set("Accept-Language", fp.AcceptLanguage)
req.Header.Set("X-Screen-Resolution", fp.ScreenResolution)
req.Header.Set("X-Color-Depth", fmt.Sprintf("%d", fp.ColorDepth))
req.Header.Set("X-Hardware-Concurrency", fmt.Sprintf("%d", fp.HardwareConcurrency))
req.Header.Set("X-Timezone-Offset", fmt.Sprintf("%d", fp.TimezoneOffset))
}
// RemoveFingerprint 移除 Token 关联的指纹
func (fm *FingerprintManager) RemoveFingerprint(tokenKey string) {
fm.mu.Lock()
defer fm.mu.Unlock()
delete(fm.fingerprints, tokenKey)
}
// Count 返回当前管理的指纹数量
func (fm *FingerprintManager) Count() int {
fm.mu.RLock()
defer fm.mu.RUnlock()
return len(fm.fingerprints)
}
// BuildUserAgent 构建 User-Agent 字符串 (Kiro IDE 风格)
// 格式: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
// BuildUserAgent format: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
func (fp *Fingerprint) BuildUserAgent() string {
return fmt.Sprintf(
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
fp.SDKVersion,
fp.StreamingSDKVersion,
fp.OSType,
fp.OSVersion,
fp.NodeVersion,
fp.SDKVersion,
fp.StreamingSDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}
// BuildAmzUserAgent 构建 X-Amz-User-Agent 字符串
// 格式: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
// BuildAmzUserAgent format: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
func (fp *Fingerprint) BuildAmzUserAgent() string {
return fmt.Sprintf(
"aws-sdk-js/%s KiroIDE-%s-%s",
fp.SDKVersion,
fp.StreamingSDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}
func SetOIDCHeaders(req *http.Request) {
fp := GlobalFingerprintManager().GetFingerprint("oidc-session")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion))
req.Header.Set("User-Agent", fmt.Sprintf(
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/%s#%s m/E KiroIDE",
fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, "sso-oidc", fp.OIDCSDKVersion))
req.Header.Set("amz-sdk-invocation-id", uuid.New().String())
req.Header.Set("amz-sdk-request", "attempt=1; max=4")
}
func setRuntimeHeaders(req *http.Request, accessToken string, accountKey string) {
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
machineID := fp.KiroHash
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s",
fp.RuntimeSDKVersion, fp.KiroVersion, machineID))
req.Header.Set("User-Agent", fmt.Sprintf(
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererruntime#%s m/N,E KiroIDE-%s-%s",
fp.RuntimeSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.RuntimeSDKVersion,
fp.KiroVersion, machineID))
req.Header.Set("amz-sdk-invocation-id", uuid.New().String())
req.Header.Set("amz-sdk-request", "attempt=1; max=1")
}

View File

@@ -2,6 +2,8 @@ package kiro
import (
"net/http"
"runtime"
"strings"
"sync"
"testing"
)
@@ -26,8 +28,14 @@ func TestGetFingerprint_NewToken(t *testing.T) {
if fp == nil {
t.Fatal("expected non-nil Fingerprint")
}
if fp.SDKVersion == "" {
t.Error("expected non-empty SDKVersion")
if fp.OIDCSDKVersion == "" {
t.Error("expected non-empty OIDCSDKVersion")
}
if fp.RuntimeSDKVersion == "" {
t.Error("expected non-empty RuntimeSDKVersion")
}
if fp.StreamingSDKVersion == "" {
t.Error("expected non-empty StreamingSDKVersion")
}
if fp.OSType == "" {
t.Error("expected non-empty OSType")
@@ -44,18 +52,6 @@ func TestGetFingerprint_NewToken(t *testing.T) {
if fp.KiroHash == "" {
t.Error("expected non-empty KiroHash")
}
if fp.AcceptLanguage == "" {
t.Error("expected non-empty AcceptLanguage")
}
if fp.ScreenResolution == "" {
t.Error("expected non-empty ScreenResolution")
}
if fp.ColorDepth == 0 {
t.Error("expected non-zero ColorDepth")
}
if fp.HardwareConcurrency == 0 {
t.Error("expected non-zero HardwareConcurrency")
}
}
func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
@@ -78,72 +74,18 @@ func TestGetFingerprint_DifferentTokens(t *testing.T) {
}
}
func TestRemoveFingerprint(t *testing.T) {
fm := NewFingerprintManager()
fm.GetFingerprint("token1")
if fm.Count() != 1 {
t.Fatalf("expected count 1, got %d", fm.Count())
}
fm.RemoveFingerprint("token1")
if fm.Count() != 0 {
t.Errorf("expected count 0, got %d", fm.Count())
}
}
func TestRemoveFingerprint_NonExistent(t *testing.T) {
fm := NewFingerprintManager()
fm.RemoveFingerprint("nonexistent")
if fm.Count() != 0 {
t.Errorf("expected count 0, got %d", fm.Count())
}
}
func TestCount(t *testing.T) {
fm := NewFingerprintManager()
if fm.Count() != 0 {
t.Errorf("expected count 0, got %d", fm.Count())
}
fm.GetFingerprint("token1")
fm.GetFingerprint("token2")
fm.GetFingerprint("token3")
if fm.Count() != 3 {
t.Errorf("expected count 3, got %d", fm.Count())
}
}
func TestApplyToRequest(t *testing.T) {
func TestBuildUserAgent(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
req, _ := http.NewRequest("GET", "http://example.com", nil)
fp.ApplyToRequest(req)
ua := fp.BuildUserAgent()
if ua == "" {
t.Error("expected non-empty User-Agent")
}
if req.Header.Get("X-Kiro-SDK-Version") != fp.SDKVersion {
t.Error("X-Kiro-SDK-Version header mismatch")
}
if req.Header.Get("X-Kiro-OS-Type") != fp.OSType {
t.Error("X-Kiro-OS-Type header mismatch")
}
if req.Header.Get("X-Kiro-OS-Version") != fp.OSVersion {
t.Error("X-Kiro-OS-Version header mismatch")
}
if req.Header.Get("X-Kiro-Node-Version") != fp.NodeVersion {
t.Error("X-Kiro-Node-Version header mismatch")
}
if req.Header.Get("X-Kiro-Version") != fp.KiroVersion {
t.Error("X-Kiro-Version header mismatch")
}
if req.Header.Get("X-Kiro-Hash") != fp.KiroHash {
t.Error("X-Kiro-Hash header mismatch")
}
if req.Header.Get("Accept-Language") != fp.AcceptLanguage {
t.Error("Accept-Language header mismatch")
}
if req.Header.Get("X-Screen-Resolution") != fp.ScreenResolution {
t.Error("X-Screen-Resolution header mismatch")
amzUA := fp.BuildAmzUserAgent()
if amzUA == "" {
t.Error("expected non-empty X-Amz-User-Agent")
}
}
@@ -166,6 +108,33 @@ func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) {
}
}
func TestGenerateFromConfig_OSTypeFromRuntimeGOOS(t *testing.T) {
fm := NewFingerprintManager()
// Set config with empty OSType to trigger runtime.GOOS fallback
fm.SetConfig(&FingerprintConfig{
OIDCSDKVersion: "3.738.0", // Set other fields to use config path
})
fp := fm.GetFingerprint("test-token")
// Expected OS type based on runtime.GOOS mapping
var expectedOS string
switch runtime.GOOS {
case "darwin":
expectedOS = "darwin"
case "windows":
expectedOS = "windows"
default:
expectedOS = "linux"
}
if fp.OSType != expectedOS {
t.Errorf("expected OSType '%s' from runtime.GOOS '%s', got '%s'",
expectedOS, runtime.GOOS, fp.OSType)
}
}
func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
fm := NewFingerprintManager()
const numGoroutines = 100
@@ -174,22 +143,18 @@ func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
for i := range numGoroutines {
go func(id int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
for j := range numOperations {
tokenKey := "token" + string(rune('a'+id%26))
switch j % 4 {
switch j % 2 {
case 0:
fm.GetFingerprint(tokenKey)
case 1:
fm.Count()
case 2:
fp := fm.GetFingerprint(tokenKey)
req, _ := http.NewRequest("GET", "http://example.com", nil)
fp.ApplyToRequest(req)
case 3:
fm.RemoveFingerprint(tokenKey)
_ = fp.BuildUserAgent()
_ = fp.BuildAmzUserAgent()
}
}
}(i)
@@ -198,16 +163,20 @@ func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
wg.Wait()
}
func TestKiroHashUniqueness(t *testing.T) {
func TestKiroHashStability(t *testing.T) {
fm := NewFingerprintManager()
hashes := make(map[string]bool)
for i := 0; i < 100; i++ {
fp := fm.GetFingerprint("token" + string(rune(i)))
if hashes[fp.KiroHash] {
t.Errorf("duplicate KiroHash detected: %s", fp.KiroHash)
}
hashes[fp.KiroHash] = true
// Same token should always return same hash
fp1 := fm.GetFingerprint("token1")
fp2 := fm.GetFingerprint("token1")
if fp1.KiroHash != fp2.KiroHash {
t.Errorf("same token should have same hash: %s vs %s", fp1.KiroHash, fp2.KiroHash)
}
// Different tokens should have different hashes
fp3 := fm.GetFingerprint("token2")
if fp1.KiroHash == fp3.KiroHash {
t.Errorf("different tokens should have different hashes")
}
}
@@ -220,8 +189,590 @@ func TestKiroHashFormat(t *testing.T) {
}
for _, c := range fp.KiroHash {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
t.Errorf("invalid hex character in KiroHash: %c", c)
}
}
}
func TestGlobalFingerprintManager(t *testing.T) {
fm1 := GlobalFingerprintManager()
fm2 := GlobalFingerprintManager()
if fm1 == nil {
t.Fatal("expected non-nil GlobalFingerprintManager")
}
if fm1 != fm2 {
t.Error("expected GlobalFingerprintManager to return same instance")
}
}
func TestSetOIDCHeaders(t *testing.T) {
req, _ := http.NewRequest("GET", "http://example.com", nil)
SetOIDCHeaders(req)
if req.Header.Get("Content-Type") != "application/json" {
t.Error("expected Content-Type header to be set")
}
amzUA := req.Header.Get("x-amz-user-agent")
if amzUA == "" {
t.Error("expected x-amz-user-agent header to be set")
}
if !strings.Contains(amzUA, "aws-sdk-js/") {
t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA)
}
if !strings.Contains(amzUA, "KiroIDE") {
t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA)
}
ua := req.Header.Get("User-Agent")
if ua == "" {
t.Error("expected User-Agent header to be set")
}
if !strings.Contains(ua, "api/sso-oidc") {
t.Errorf("User-Agent should contain api name: %s", ua)
}
if req.Header.Get("amz-sdk-invocation-id") == "" {
t.Error("expected amz-sdk-invocation-id header to be set")
}
if req.Header.Get("amz-sdk-request") != "attempt=1; max=4" {
t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request"))
}
}
func TestBuildURL(t *testing.T) {
tests := []struct {
name string
endpoint string
path string
queryParams map[string]string
want string
wantContains []string
}{
{
name: "no query params",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: nil,
want: "https://api.example.com/getUsageLimits",
},
{
name: "empty query params",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: map[string]string{},
want: "https://api.example.com/getUsageLimits",
},
{
name: "single query param",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: map[string]string{
"origin": "AI_EDITOR",
},
want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR",
},
{
name: "multiple query params",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: map[string]string{
"origin": "AI_EDITOR",
"resourceType": "AGENTIC_REQUEST",
"profileArn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEF",
},
wantContains: []string{
"https://api.example.com/getUsageLimits?",
"origin=AI_EDITOR",
"profileArn=arn%3Aaws%3Acodewhisperer%3Aus-east-1%3A123456789012%3Aprofile%2FABCDEF",
"resourceType=AGENTIC_REQUEST",
},
},
{
name: "omit empty params",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: map[string]string{
"origin": "AI_EDITOR",
"profileArn": "",
},
want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := buildURL(tt.endpoint, tt.path, tt.queryParams)
if tt.want != "" {
if got != tt.want {
t.Errorf("buildURL() = %v, want %v", got, tt.want)
}
}
if tt.wantContains != nil {
for _, substr := range tt.wantContains {
if !strings.Contains(got, substr) {
t.Errorf("buildURL() = %v, want to contain %v", got, substr)
}
}
}
})
}
}
func TestBuildUserAgentFormat(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
ua := fp.BuildUserAgent()
requiredParts := []string{
"aws-sdk-js/",
"ua/2.1",
"os/",
"lang/js",
"md/nodejs#",
"api/codewhispererstreaming#",
"m/E",
"KiroIDE-",
}
for _, part := range requiredParts {
if !strings.Contains(ua, part) {
t.Errorf("User-Agent missing required part %q: %s", part, ua)
}
}
}
func TestBuildAmzUserAgentFormat(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
amzUA := fp.BuildAmzUserAgent()
requiredParts := []string{
"aws-sdk-js/",
"KiroIDE-",
}
for _, part := range requiredParts {
if !strings.Contains(amzUA, part) {
t.Errorf("X-Amz-User-Agent missing required part %q: %s", part, amzUA)
}
}
// Amz-User-Agent should be shorter than User-Agent
ua := fp.BuildUserAgent()
if len(amzUA) >= len(ua) {
t.Error("X-Amz-User-Agent should be shorter than User-Agent")
}
}
func TestSetRuntimeHeaders(t *testing.T) {
req, _ := http.NewRequest("GET", "http://example.com", nil)
accessToken := "test-access-token-1234567890"
clientID := "test-client-id-12345"
accountKey := GenerateAccountKey(clientID)
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
machineID := fp.KiroHash
setRuntimeHeaders(req, accessToken, accountKey)
// Check Authorization header
if req.Header.Get("Authorization") != "Bearer "+accessToken {
t.Errorf("expected Authorization header 'Bearer %s', got '%s'", accessToken, req.Header.Get("Authorization"))
}
// Check x-amz-user-agent header
amzUA := req.Header.Get("x-amz-user-agent")
if amzUA == "" {
t.Error("expected x-amz-user-agent header to be set")
}
if !strings.Contains(amzUA, "aws-sdk-js/") {
t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA)
}
if !strings.Contains(amzUA, "KiroIDE-") {
t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA)
}
if !strings.Contains(amzUA, machineID) {
t.Errorf("x-amz-user-agent should contain machineID: %s", amzUA)
}
// Check User-Agent header
ua := req.Header.Get("User-Agent")
if ua == "" {
t.Error("expected User-Agent header to be set")
}
if !strings.Contains(ua, "api/codewhispererruntime#") {
t.Errorf("User-Agent should contain api/codewhispererruntime: %s", ua)
}
if !strings.Contains(ua, "m/N,E") {
t.Errorf("User-Agent should contain m/N,E: %s", ua)
}
// Check amz-sdk-invocation-id (should be a UUID)
invocationID := req.Header.Get("amz-sdk-invocation-id")
if invocationID == "" {
t.Error("expected amz-sdk-invocation-id header to be set")
}
if len(invocationID) != 36 {
t.Errorf("expected amz-sdk-invocation-id to be UUID (36 chars), got %d", len(invocationID))
}
// Check amz-sdk-request
if req.Header.Get("amz-sdk-request") != "attempt=1; max=1" {
t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request"))
}
}
func TestSDKVersionsAreValid(t *testing.T) {
// Verify all OIDC SDK versions match expected format (3.xxx.x)
for _, v := range oidcSDKVersions {
if !strings.HasPrefix(v, "3.") {
t.Errorf("OIDC SDK version should start with 3.: %s", v)
}
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("OIDC SDK version should have 3 parts: %s", v)
}
}
for _, v := range runtimeSDKVersions {
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("Runtime SDK version should have 3 parts: %s", v)
}
}
for _, v := range streamingSDKVersions {
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("Streaming SDK version should have 3 parts: %s", v)
}
}
}
func TestKiroVersionsAreValid(t *testing.T) {
// Verify all Kiro versions match expected format (0.x.xxx)
for _, v := range kiroVersions {
if !strings.HasPrefix(v, "0.") {
t.Errorf("Kiro version should start with 0.: %s", v)
}
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("Kiro version should have 3 parts: %s", v)
}
}
}
func TestNodeVersionsAreValid(t *testing.T) {
// Verify all Node versions match expected format (xx.xx.x)
for _, v := range nodeVersions {
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("Node version should have 3 parts: %s", v)
}
// Should be Node 20.x or 22.x
if !strings.HasPrefix(v, "20.") && !strings.HasPrefix(v, "22.") {
t.Errorf("Node version should be 20.x or 22.x LTS: %s", v)
}
}
}
func TestFingerprintManager_SetConfig(t *testing.T) {
fm := NewFingerprintManager()
// Without config, should generate random fingerprint
fp1 := fm.GetFingerprint("token1")
if fp1 == nil {
t.Fatal("expected non-nil fingerprint")
}
// Set config with all fields
cfg := &FingerprintConfig{
OIDCSDKVersion: "3.999.0",
RuntimeSDKVersion: "9.9.9",
StreamingSDKVersion: "8.8.8",
OSType: "darwin",
OSVersion: "99.0.0",
NodeVersion: "99.99.99",
KiroVersion: "9.9.999",
KiroHash: "customhash123",
}
fm.SetConfig(cfg)
// After setting config, should use config values
fp2 := fm.GetFingerprint("token2")
if fp2.OIDCSDKVersion != "3.999.0" {
t.Errorf("expected OIDCSDKVersion '3.999.0', got '%s'", fp2.OIDCSDKVersion)
}
if fp2.RuntimeSDKVersion != "9.9.9" {
t.Errorf("expected RuntimeSDKVersion '9.9.9', got '%s'", fp2.RuntimeSDKVersion)
}
if fp2.StreamingSDKVersion != "8.8.8" {
t.Errorf("expected StreamingSDKVersion '8.8.8', got '%s'", fp2.StreamingSDKVersion)
}
if fp2.OSType != "darwin" {
t.Errorf("expected OSType 'darwin', got '%s'", fp2.OSType)
}
if fp2.OSVersion != "99.0.0" {
t.Errorf("expected OSVersion '99.0.0', got '%s'", fp2.OSVersion)
}
if fp2.NodeVersion != "99.99.99" {
t.Errorf("expected NodeVersion '99.99.99', got '%s'", fp2.NodeVersion)
}
if fp2.KiroVersion != "9.9.999" {
t.Errorf("expected KiroVersion '9.9.999', got '%s'", fp2.KiroVersion)
}
if fp2.KiroHash != "customhash123" {
t.Errorf("expected KiroHash 'customhash123', got '%s'", fp2.KiroHash)
}
}
func TestFingerprintManager_SetConfig_PartialFields(t *testing.T) {
fm := NewFingerprintManager()
// Set config with only some fields
cfg := &FingerprintConfig{
KiroVersion: "1.2.345",
KiroHash: "myhash",
// Other fields empty - should use random
}
fm.SetConfig(cfg)
fp := fm.GetFingerprint("token1")
// Configured fields should use config values
if fp.KiroVersion != "1.2.345" {
t.Errorf("expected KiroVersion '1.2.345', got '%s'", fp.KiroVersion)
}
if fp.KiroHash != "myhash" {
t.Errorf("expected KiroHash 'myhash', got '%s'", fp.KiroHash)
}
// Empty fields should be randomly selected (non-empty)
if fp.OIDCSDKVersion == "" {
t.Error("expected non-empty OIDCSDKVersion")
}
if fp.OSType == "" {
t.Error("expected non-empty OSType")
}
if fp.NodeVersion == "" {
t.Error("expected non-empty NodeVersion")
}
}
func TestFingerprintManager_SetConfig_ClearsCache(t *testing.T) {
fm := NewFingerprintManager()
// Get fingerprint before config
fp1 := fm.GetFingerprint("token1")
originalHash := fp1.KiroHash
// Set config
cfg := &FingerprintConfig{
KiroHash: "newcustomhash",
}
fm.SetConfig(cfg)
// Same token should now return different fingerprint (cache cleared)
fp2 := fm.GetFingerprint("token1")
if fp2.KiroHash == originalHash {
t.Error("expected cache to be cleared after SetConfig")
}
if fp2.KiroHash != "newcustomhash" {
t.Errorf("expected KiroHash 'newcustomhash', got '%s'", fp2.KiroHash)
}
}
func TestGenerateAccountKey(t *testing.T) {
tests := []struct {
name string
seed string
check func(t *testing.T, result string)
}{
{
name: "Empty seed",
seed: "",
check: func(t *testing.T, result string) {
if result == "" {
t.Error("expected non-empty result for empty seed")
}
if len(result) != 16 {
t.Errorf("expected 16 char hex string, got %d chars", len(result))
}
},
},
{
name: "Simple seed",
seed: "test-client-id",
check: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char hex string, got %d chars", len(result))
}
// Verify it's valid hex
for _, c := range result {
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
t.Errorf("invalid hex character: %c", c)
}
}
},
},
{
name: "Same seed produces same result",
seed: "deterministic-seed",
check: func(t *testing.T, result string) {
result2 := GenerateAccountKey("deterministic-seed")
if result != result2 {
t.Errorf("same seed should produce same result: %s vs %s", result, result2)
}
},
},
{
name: "Different seeds produce different results",
seed: "seed-one",
check: func(t *testing.T, result string) {
result2 := GenerateAccountKey("seed-two")
if result == result2 {
t.Errorf("different seeds should produce different results: %s vs %s", result, result2)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenerateAccountKey(tt.seed)
tt.check(t, result)
})
}
}
func TestGetAccountKey(t *testing.T) {
tests := []struct {
name string
clientID string
refreshToken string
check func(t *testing.T, result string)
}{
{
name: "Priority 1: clientID when both provided",
clientID: "client-id-123",
refreshToken: "refresh-token-456",
check: func(t *testing.T, result string) {
expected := GenerateAccountKey("client-id-123")
if result != expected {
t.Errorf("expected clientID-based key %s, got %s", expected, result)
}
},
},
{
name: "Priority 2: refreshToken when clientID is empty",
clientID: "",
refreshToken: "refresh-token-789",
check: func(t *testing.T, result string) {
expected := GenerateAccountKey("refresh-token-789")
if result != expected {
t.Errorf("expected refreshToken-based key %s, got %s", expected, result)
}
},
},
{
name: "Priority 3: random when both empty",
clientID: "",
refreshToken: "",
check: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
// Should be different each time (random UUID)
result2 := GetAccountKey("", "")
if result == result2 {
t.Log("warning: random keys are the same (possible but unlikely)")
}
},
},
{
name: "clientID only",
clientID: "solo-client-id",
refreshToken: "",
check: func(t *testing.T, result string) {
expected := GenerateAccountKey("solo-client-id")
if result != expected {
t.Errorf("expected clientID-based key %s, got %s", expected, result)
}
},
},
{
name: "refreshToken only",
clientID: "",
refreshToken: "solo-refresh-token",
check: func(t *testing.T, result string) {
expected := GenerateAccountKey("solo-refresh-token")
if result != expected {
t.Errorf("expected refreshToken-based key %s, got %s", expected, result)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetAccountKey(tt.clientID, tt.refreshToken)
tt.check(t, result)
})
}
}
func TestGetAccountKey_Deterministic(t *testing.T) {
// Verify that GetAccountKey produces deterministic results for same inputs
clientID := "test-client-id-abc"
refreshToken := "test-refresh-token-xyz"
// Call multiple times with same inputs
results := make([]string, 10)
for i := range 10 {
results[i] = GetAccountKey(clientID, refreshToken)
}
// All results should be identical
for i := 1; i < 10; i++ {
if results[i] != results[0] {
t.Errorf("GetAccountKey should be deterministic: got %s and %s", results[0], results[i])
}
}
}
func TestFingerprintDeterministic(t *testing.T) {
// Verify that fingerprints are deterministic based on accountKey
fm := NewFingerprintManager()
accountKey := GenerateAccountKey("test-client-id")
// Get fingerprint multiple times
fp1 := fm.GetFingerprint(accountKey)
fp2 := fm.GetFingerprint(accountKey)
// Should be the same pointer (cached)
if fp1 != fp2 {
t.Error("expected same fingerprint pointer for same key")
}
// Create new manager and verify same values
fm2 := NewFingerprintManager()
fp3 := fm2.GetFingerprint(accountKey)
// Values should be identical (deterministic generation)
if fp1.KiroHash != fp3.KiroHash {
t.Errorf("KiroHash should be deterministic: %s vs %s", fp1.KiroHash, fp3.KiroHash)
}
if fp1.OSType != fp3.OSType {
t.Errorf("OSType should be deterministic: %s vs %s", fp1.OSType, fp3.OSType)
}
if fp1.OSVersion != fp3.OSVersion {
t.Errorf("OSVersion should be deterministic: %s vs %s", fp1.OSVersion, fp3.OSVersion)
}
if fp1.KiroVersion != fp3.KiroVersion {
t.Errorf("KiroVersion should be deterministic: %s vs %s", fp1.KiroVersion, fp3.KiroVersion)
}
if fp1.NodeVersion != fp3.NodeVersion {
t.Errorf("NodeVersion should be deterministic: %s vs %s", fp1.NodeVersion, fp3.NodeVersion)
}
}

View File

@@ -23,10 +23,10 @@ import (
const (
// Kiro auth endpoint
kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
// Default callback port
defaultCallbackPort = 9876
// Auth timeout
authTimeout = 10 * time.Minute
)
@@ -41,8 +41,10 @@ type KiroTokenResponse struct {
// KiroOAuth handles the OAuth flow for Kiro authentication.
type KiroOAuth struct {
httpClient *http.Client
cfg *config.Config
httpClient *http.Client
cfg *config.Config
machineID string
kiroVersion string
}
// NewKiroOAuth creates a new Kiro OAuth handler.
@@ -51,9 +53,12 @@ func NewKiroOAuth(cfg *config.Config) *KiroOAuth {
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
fp := GlobalFingerprintManager().GetFingerprint("login")
return &KiroOAuth{
httpClient: client,
cfg: cfg,
httpClient: client,
cfg: cfg,
machineID: fp.KiroHash,
kiroVersion: fp.KiroVersion,
}
}
@@ -190,7 +195,8 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api")
req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID))
req.Header.Set("Accept", "application/json, text/plain, */*")
resp, err := o.httpClient.Do(req)
if err != nil {
@@ -256,11 +262,8 @@ func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToke
}
req.Header.Set("Content-Type", "application/json")
// Use KiroIDE-style User-Agent to match official Kiro IDE behavior
// This helps avoid 403 errors from server-side User-Agent validation
userAgent := buildKiroUserAgent(tokenKey)
req.Header.Set("User-Agent", userAgent)
req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID))
req.Header.Set("Accept", "application/json, text/plain, */*")
resp, err := o.httpClient.Do(req)
if err != nil {
@@ -301,19 +304,6 @@ func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToke
}, nil
}
// buildKiroUserAgent builds a KiroIDE-style User-Agent string.
// If tokenKey is provided, uses fingerprint manager for consistent fingerprint.
// Otherwise generates a simple KiroIDE User-Agent.
func buildKiroUserAgent(tokenKey string) string {
if tokenKey != "" {
fm := NewFingerprintManager()
fp := fm.GetFingerprint(tokenKey)
return fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash[:16])
}
// Default KiroIDE User-Agent matching kiro-openai-gateway format
return "KiroIDE-0.7.45-cli-proxy-api"
}
// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
// This uses a custom protocol handler (kiro://) to receive the callback.
func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {

View File

@@ -35,35 +35,35 @@ const (
)
type webAuthSession struct {
stateID string
deviceCode string
userCode string
authURL string
verificationURI string
expiresIn int
interval int
status authSessionStatus
startedAt time.Time
completedAt time.Time
expiresAt time.Time
error string
tokenData *KiroTokenData
ssoClient *SSOOIDCClient
clientID string
clientSecret string
region string
cancelFunc context.CancelFunc
authMethod string // "google", "github", "builder-id", "idc"
startURL string // Used for IDC
codeVerifier string // Used for social auth PKCE
codeChallenge string // Used for social auth PKCE
stateID string
deviceCode string
userCode string
authURL string
verificationURI string
expiresIn int
interval int
status authSessionStatus
startedAt time.Time
completedAt time.Time
expiresAt time.Time
error string
tokenData *KiroTokenData
ssoClient *SSOOIDCClient
clientID string
clientSecret string
region string
cancelFunc context.CancelFunc
authMethod string // "google", "github", "builder-id", "idc"
startURL string // Used for IDC
codeVerifier string // Used for social auth PKCE
codeChallenge string // Used for social auth PKCE
}
type OAuthWebHandler struct {
cfg *config.Config
sessions map[string]*webAuthSession
mu sync.RWMutex
onTokenObtained func(*KiroTokenData)
cfg *config.Config
sessions map[string]*webAuthSession
mu sync.RWMutex
onTokenObtained func(*KiroTokenData)
}
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
@@ -104,7 +104,7 @@ func (h *OAuthWebHandler) handleSelect(c *gin.Context) {
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
method := c.Query("method")
if method == "" {
c.Redirect(http.StatusFound, "/v0/oauth/kiro")
return
@@ -138,7 +138,7 @@ func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) {
}
socialClient := NewSocialAuthClient(h.cfg)
var provider string
if method == "google" {
provider = string(ProviderGoogle)
@@ -373,22 +373,28 @@ func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSess
}
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken)
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken)
// Fetch profileArn for IDC
var profileArn string
if session.authMethod == "idc" {
profileArn = session.ssoClient.FetchProfileArn(ctx, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken)
}
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken)
tokenData := &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: profileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: session.authMethod,
Provider: "AWS",
ClientID: session.clientID,
ClientSecret: session.clientSecret,
Email: email,
Region: session.region,
StartURL: session.startURL,
}
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: profileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: session.authMethod,
Provider: "AWS",
ClientID: session.clientID,
ClientSecret: session.clientSecret,
Email: email,
Region: session.region,
StartURL: session.startURL,
}
h.mu.Lock()
session.status = statusSuccess
@@ -442,7 +448,7 @@ func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
fileName := GenerateTokenFileName(tokenData)
authFilePath := filepath.Join(authDir, fileName)
// Convert to storage format and save
storage := &KiroTokenStorage{
Type: "kiro",
@@ -459,12 +465,12 @@ func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
StartURL: tokenData.StartURL,
Email: tokenData.Email,
}
if err := storage.SaveTokenToFile(authFilePath); err != nil {
log.Errorf("OAuth Web: failed to save token to file: %v", err)
return
}
log.Infof("OAuth Web: token saved to %s", authFilePath)
}

View File

@@ -10,14 +10,14 @@ import (
log "github.com/sirupsen/logrus"
)
// RefreshManager 是后台刷新器的单例管理器
// RefreshManager is a singleton manager for background token refreshing.
type RefreshManager struct {
mu sync.Mutex
refresher *BackgroundRefresher
ctx context.Context
cancel context.CancelFunc
started bool
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData)
}
var (
@@ -25,7 +25,7 @@ var (
managerOnce sync.Once
)
// GetRefreshManager 获取全局刷新管理器实例
// GetRefreshManager returns the global RefreshManager singleton.
func GetRefreshManager() *RefreshManager {
managerOnce.Do(func() {
globalRefreshManager = &RefreshManager{}
@@ -33,9 +33,7 @@ func GetRefreshManager() *RefreshManager {
return globalRefreshManager
}
// Initialize 初始化后台刷新器
// baseDir: token 文件所在的目录
// cfg: 应用配置
// Initialize sets up the background refresher.
func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
m.mu.Lock()
defer m.mu.Unlock()
@@ -58,18 +56,16 @@ func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
baseDir = resolvedBaseDir
}
// 创建 token 存储库
repo := NewFileTokenRepository(baseDir)
// 创建后台刷新器,配置参数
opts := []RefresherOption{
WithInterval(time.Minute), // 每分钟检查一次
WithBatchSize(50), // 每批最多处理 50 个 token
WithConcurrency(10), // 最多 10 个并发刷新
WithConfig(cfg), // 设置 OAuth 和 SSO 客户端
WithInterval(time.Minute),
WithBatchSize(50),
WithConcurrency(10),
WithConfig(cfg),
}
// 如果已设置回调,传递给 BackgroundRefresher
// Pass callback to BackgroundRefresher if already set
if m.onTokenRefreshed != nil {
opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed))
}
@@ -80,7 +76,7 @@ func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
return nil
}
// Start 启动后台刷新
// Start begins background token refreshing.
func (m *RefreshManager) Start() {
m.mu.Lock()
defer m.mu.Unlock()
@@ -102,7 +98,7 @@ func (m *RefreshManager) Start() {
log.Info("refresh manager: background refresh started")
}
// Stop 停止后台刷新
// Stop halts background token refreshing.
func (m *RefreshManager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
@@ -123,14 +119,14 @@ func (m *RefreshManager) Stop() {
log.Info("refresh manager: background refresh stopped")
}
// IsRunning 检查后台刷新是否正在运行
// IsRunning reports whether background refreshing is active.
func (m *RefreshManager) IsRunning() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.started
}
// UpdateBaseDir 更新 token 目录(用于运行时配置更改)
// UpdateBaseDir changes the token directory at runtime.
func (m *RefreshManager) UpdateBaseDir(baseDir string) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -143,16 +139,15 @@ func (m *RefreshManager) UpdateBaseDir(baseDir string) {
}
}
// SetOnTokenRefreshed 设置 token 刷新成功后的回调函数
// 可以在任何时候调用,支持运行时更新回调
// callback: 回调函数,接收 tokenID文件名和新的 token 数据
// SetOnTokenRefreshed registers a callback invoked after a successful token refresh.
// Can be called at any time; supports runtime callback updates.
func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) {
m.mu.Lock()
defer m.mu.Unlock()
m.onTokenRefreshed = callback
// 如果 refresher 已经创建,使用并发安全的方式更新它的回调
// Update the refresher's callback in a thread-safe manner if already created
if m.refresher != nil {
m.refresher.callbackMu.Lock()
m.refresher.onTokenRefreshed = callback
@@ -162,8 +157,11 @@ func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, token
log.Debug("refresh manager: token refresh callback registered")
}
// InitializeAndStart 初始化并启动后台刷新(便捷方法)
// InitializeAndStart initializes and starts background refreshing (convenience method).
func InitializeAndStart(baseDir string, cfg *config.Config) {
// Initialize global fingerprint config
initGlobalFingerprintConfig(cfg)
manager := GetRefreshManager()
if err := manager.Initialize(baseDir, cfg); err != nil {
log.Errorf("refresh manager: initialization failed: %v", err)
@@ -172,7 +170,31 @@ func InitializeAndStart(baseDir string, cfg *config.Config) {
manager.Start()
}
// StopGlobalRefreshManager 停止全局刷新管理器
// initGlobalFingerprintConfig loads fingerprint settings from application config.
func initGlobalFingerprintConfig(cfg *config.Config) {
if cfg == nil || cfg.KiroFingerprint == nil {
return
}
fpCfg := cfg.KiroFingerprint
SetGlobalFingerprintConfig(&FingerprintConfig{
OIDCSDKVersion: fpCfg.OIDCSDKVersion,
RuntimeSDKVersion: fpCfg.RuntimeSDKVersion,
StreamingSDKVersion: fpCfg.StreamingSDKVersion,
OSType: fpCfg.OSType,
OSVersion: fpCfg.OSVersion,
NodeVersion: fpCfg.NodeVersion,
KiroVersion: fpCfg.KiroVersion,
KiroHash: fpCfg.KiroHash,
})
log.Debug("kiro: global fingerprint config loaded")
}
// InitFingerprintConfig initializes the global fingerprint config from application config.
func InitFingerprintConfig(cfg *config.Config) {
initGlobalFingerprintConfig(cfg)
}
// StopGlobalRefreshManager stops the global refresh manager.
func StopGlobalRefreshManager() {
if globalRefreshManager != nil {
globalRefreshManager.Stop()

View File

@@ -84,6 +84,8 @@ type SocialAuthClient struct {
httpClient *http.Client
cfg *config.Config
protocolHandler *ProtocolHandler
machineID string
kiroVersion string
}
// NewSocialAuthClient creates a new social auth client.
@@ -92,10 +94,13 @@ func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
fp := GlobalFingerprintManager().GetFingerprint("login")
return &SocialAuthClient{
httpClient: client,
cfg: cfg,
protocolHandler: NewProtocolHandler(),
machineID: fp.KiroHash,
kiroVersion: fp.KiroVersion,
}
}
@@ -229,7 +234,8 @@ func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequ
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api")
httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID))
httpReq.Header.Set("Accept", "application/json, text/plain, */*")
resp, err := c.httpClient.Do(httpReq)
if err != nil {
@@ -269,7 +275,8 @@ func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID))
httpReq.Header.Set("Accept", "application/json, text/plain, */*")
resp, err := c.httpClient.Do(httpReq)
if err != nil {
@@ -466,7 +473,7 @@ func forceDefaultProtocolHandler() {
if runtime.GOOS != "linux" {
return // Non-Linux platforms use different handler mechanisms
}
// Set our handler as default using xdg-mime
cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
if err := cmd.Run(); err != nil {

View File

@@ -14,6 +14,7 @@ import (
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
@@ -40,21 +41,13 @@ const (
// Authorization code flow callback
authCodeCallbackPath = "/oauth/callback"
authCodeCallbackPort = 19877
// User-Agent to match official Kiro IDE
kiroUserAgent = "KiroIDE"
// IDC token refresh headers (matching Kiro IDE behavior)
idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE"
)
// Sentinel errors for OIDC token polling
var (
ErrAuthorizationPending = errors.New("authorization_pending")
ErrSlowDown = errors.New("slow_down")
)
// SSOOIDCClient handles AWS SSO OIDC authentication.
type SSOOIDCClient struct {
httpClient *http.Client
cfg *config.Config
@@ -74,10 +67,10 @@ func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient {
// RegisterClientResponse from AWS SSO OIDC.
type RegisterClientResponse struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
ClientIDIssuedAt int64 `json:"clientIdIssuedAt"`
ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"`
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
ClientIDIssuedAt int64 `json:"clientIdIssuedAt"`
ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"`
}
// StartDeviceAuthResponse from AWS SSO OIDC.
@@ -174,8 +167,7 @@ func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region str
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -220,8 +212,7 @@ func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, cli
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -267,8 +258,7 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -311,8 +301,11 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli
return &result, nil
}
// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region.
// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific OIDC region.
func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) {
if region == "" {
region = defaultIDCRegion
}
endpoint := getOIDCEndpoint(region)
payload := map[string]string{
@@ -331,18 +324,7 @@ func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, cl
if err != nil {
return nil, err
}
// Set headers matching kiro2api's IDC token refresh
// These headers are required for successful IDC token refresh
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region))
req.Header.Set("Connection", "keep-alive")
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
req.Header.Set("Accept", "*/*")
req.Header.Set("Accept-Language", "*")
req.Header.Set("sec-fetch-mode", "cors")
req.Header.Set("User-Agent", "node")
req.Header.Set("Accept-Encoding", "br, gzip, deflate")
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -469,10 +451,10 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin
// Step 5: Get profile ARN from CodeWhisperer API
fmt.Println("Fetching profile information...")
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
profileArn := c.FetchProfileArn(ctx, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
// Fetch user email
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
if email != "" {
fmt.Printf(" Logged in as: %s\n", email)
}
@@ -502,12 +484,36 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin
return nil, fmt.Errorf("authorization timed out")
}
// IDCLoginOptions holds optional parameters for IDC login.
type IDCLoginOptions struct {
StartURL string // Pre-configured start URL (skips prompt if set)
Region string // OIDC region for login and token refresh (defaults to us-east-1)
UseDeviceCode bool // Use Device Code flow instead of Auth Code flow
}
// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login.
func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) {
// Options can be provided to pre-configure IDC parameters (startURL, region).
// If StartURL is provided in opts, IDC flow is used directly without prompting.
func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context, opts *IDCLoginOptions) (*KiroTokenData, error) {
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
fmt.Println("║ Kiro Authentication (AWS) ║")
fmt.Println("╚══════════════════════════════════════════════════════════╝")
// If IDC options with StartURL are provided, skip method selection and use IDC directly
if opts != nil && opts.StartURL != "" {
region := opts.Region
if region == "" {
region = defaultIDCRegion
}
fmt.Printf("\n Using IDC with Start URL: %s\n", opts.StartURL)
fmt.Printf(" Region: %s\n", region)
if opts.UseDeviceCode {
return c.LoginWithIDCAndOptions(ctx, opts.StartURL, region)
}
return c.LoginWithIDCAuthCode(ctx, opts.StartURL, region)
}
// Prompt for login method
options := []string{
"Use with Builder ID (personal AWS account)",
@@ -520,15 +526,41 @@ func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroToke
return c.LoginWithBuilderID(ctx)
}
// IDC flow - prompt for start URL and region
fmt.Println()
startURL := promptInput("? Enter Start URL", "")
if startURL == "" {
return nil, fmt.Errorf("start URL is required for IDC login")
// IDC flow - use pre-configured values or prompt
var startURL, region string
if opts != nil {
startURL = opts.StartURL
region = opts.Region
}
region := promptInput("? Enter Region", defaultIDCRegion)
fmt.Println()
// Use pre-configured startURL or prompt
if startURL == "" {
startURL = promptInput("? Enter Start URL", "")
if startURL == "" {
return nil, fmt.Errorf("start URL is required for IDC login")
}
} else {
fmt.Printf(" Using pre-configured Start URL: %s\n", startURL)
}
// Use pre-configured region or prompt
if region == "" {
region = promptInput("? Enter Region", defaultIDCRegion)
} else {
fmt.Printf(" Using pre-configured Region: %s\n", region)
}
if opts != nil && opts.UseDeviceCode {
return c.LoginWithIDCAndOptions(ctx, startURL, region)
}
return c.LoginWithIDCAuthCode(ctx, startURL, region)
}
// LoginWithIDCAndOptions performs IDC login with the specified region.
func (c *SSOOIDCClient) LoginWithIDCAndOptions(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
return c.LoginWithIDC(ctx, startURL, region)
}
@@ -550,8 +582,7 @@ func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResp
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -594,8 +625,7 @@ func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID,
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -639,8 +669,7 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -702,13 +731,7 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
if err != nil {
return nil, err
}
// Set headers matching Kiro IDE behavior for better compatibility
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Host", "oidc.us-east-1.amazonaws.com")
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
req.Header.Set("User-Agent", "node")
req.Header.Set("Accept", "*/*")
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -835,12 +858,8 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
log.Debugf("Failed to close browser: %v", err)
}
// Step 5: Get profile ARN from CodeWhisperer API
fmt.Println("Fetching profile information...")
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
if email != "" {
fmt.Printf(" Logged in as: %s\n", email)
}
@@ -850,7 +869,7 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: profileArn,
ProfileArn: "", // Builder ID has no profile
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "builder-id",
Provider: "AWS",
@@ -859,15 +878,15 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
Email: email,
Region: defaultIDCRegion,
}, nil
}
}
}
}
// Close browser on timeout for better UX
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser on timeout: %v", err)
}
return nil, fmt.Errorf("authorization timed out")
}
// Close browser on timeout for better UX
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser on timeout: %v", err)
}
return nil, fmt.Errorf("authorization timed out")
}
// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
// Falls back to JWT parsing if userinfo fails.
@@ -931,20 +950,64 @@ func (c *SSOOIDCClient) tryUserInfoEndpoint(ctx context.Context, accessToken str
return ""
}
// fetchProfileArn retrieves the profile ARN from CodeWhisperer API.
// This is needed for file naming since AWS SSO OIDC doesn't return profile info.
func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string {
// Try ListProfiles API first
profileArn := c.tryListProfiles(ctx, accessToken)
// FetchProfileArn fetches the profile ARN from ListAvailableProfiles API.
// This is used to get profileArn for imported accounts that may not have it.
func (c *SSOOIDCClient) FetchProfileArn(ctx context.Context, accessToken, clientID, refreshToken string) string {
profileArn := c.tryListAvailableProfiles(ctx, accessToken, clientID, refreshToken)
if profileArn != "" {
return profileArn
}
// Fallback: Try ListAvailableCustomizations
return c.tryListCustomizations(ctx, accessToken)
return c.tryListProfilesLegacy(ctx, accessToken)
}
func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) string {
func (c *SSOOIDCClient) tryListAvailableProfiles(ctx context.Context, accessToken, clientID, refreshToken string) string {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, GetKiroAPIEndpoint("")+"/ListAvailableProfiles", strings.NewReader("{}"))
if err != nil {
return ""
}
req.Header.Set("Content-Type", "application/json")
accountKey := GetAccountKey(clientID, refreshToken)
setRuntimeHeaders(req, accessToken, accountKey)
resp, err := c.httpClient.Do(req)
if err != nil {
log.Debugf("ListAvailableProfiles request failed: %v", err)
return ""
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
log.Debugf("ListAvailableProfiles failed (status %d): %s", resp.StatusCode, string(respBody))
return ""
}
log.Debugf("ListAvailableProfiles response: %s", string(respBody))
var result struct {
Profiles []struct {
Arn string `json:"arn"`
ProfileName string `json:"profileName"`
} `json:"profiles"`
NextToken *string `json:"nextToken"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
log.Debugf("ListAvailableProfiles parse error: %v", err)
return ""
}
if len(result.Profiles) > 0 {
log.Debugf("Found profile: %s (%s)", result.Profiles[0].ProfileName, result.Profiles[0].Arn)
return result.Profiles[0].Arn
}
return ""
}
func (c *SSOOIDCClient) tryListProfilesLegacy(ctx context.Context, accessToken string) string {
payload := map[string]interface{}{
"origin": "AI_EDITOR",
}
@@ -954,7 +1017,9 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string)
return ""
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
// Use the legacy CodeWhisperer endpoint for JSON-RPC style requests.
// The Q endpoint (q.{region}.amazonaws.com) does NOT support x-amz-target headers.
req, err := http.NewRequestWithContext(ctx, http.MethodPost, GetCodeWhispererLegacyEndpoint(""), strings.NewReader(string(body)))
if err != nil {
return ""
}
@@ -973,11 +1038,11 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string)
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
log.Debugf("ListProfiles failed (status %d): %s", resp.StatusCode, string(respBody))
log.Debugf("ListProfiles (legacy) failed (status %d): %s", resp.StatusCode, string(respBody))
return ""
}
log.Debugf("ListProfiles response: %s", string(respBody))
log.Debugf("ListProfiles (legacy) response: %s", string(respBody))
var result struct {
Profiles []struct {
@@ -1001,63 +1066,6 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string)
return ""
}
func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken string) string {
payload := map[string]interface{}{
"origin": "AI_EDITOR",
}
body, err := json.Marshal(payload)
if err != nil {
return ""
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
if err != nil {
return ""
}
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListAvailableCustomizations")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
log.Debugf("ListAvailableCustomizations failed (status %d): %s", resp.StatusCode, string(respBody))
return ""
}
log.Debugf("ListAvailableCustomizations response: %s", string(respBody))
var result struct {
Customizations []struct {
Arn string `json:"arn"`
} `json:"customizations"`
ProfileArn string `json:"profileArn"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return ""
}
if result.ProfileArn != "" {
return result.ProfileArn
}
if len(result.Customizations) > 0 {
return result.Customizations[0].Arn
}
return ""
}
// RegisterClientForAuthCode registers a new OIDC client for authorization code flow.
func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectURI string) (*RegisterClientResponse, error) {
payload := map[string]interface{}{
@@ -1078,8 +1086,7 @@ func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectU
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -1105,6 +1112,53 @@ func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectU
return &result, nil
}
func (c *SSOOIDCClient) RegisterClientForAuthCodeWithIDC(ctx context.Context, redirectURI, issuerUrl, region string) (*RegisterClientResponse, error) {
endpoint := getOIDCEndpoint(region)
payload := map[string]interface{}{
"clientName": "Kiro IDE",
"clientType": "public",
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
"grantTypes": []string{"authorization_code", "refresh_token"},
"redirectUris": []string{redirectURI},
"issuerUrl": issuerUrl,
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
log.Debugf("register client for auth code with IDC failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
}
var result RegisterClientResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
return &result, nil
}
// AuthCodeCallbackResult contains the result from authorization code callback.
type AuthCodeCallbackResult struct {
Code string
@@ -1128,6 +1182,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
port := listener.Addr().(*net.TCPAddr).Port
redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, authCodeCallbackPath)
resultChan := make(chan AuthCodeCallbackResult, 1)
doneChan := make(chan struct{})
server := &http.Server{
ReadHeaderTimeout: 10 * time.Second,
@@ -1147,6 +1202,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
<html><head><title>Login Failed</title></head>
<body><h1>Login Failed</h1><p>Error: %s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
resultChan <- AuthCodeCallbackResult{Error: errParam}
close(doneChan)
return
}
@@ -1156,6 +1212,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
<html><head><title>Login Failed</title></head>
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
resultChan <- AuthCodeCallbackResult{Error: "state mismatch"}
close(doneChan)
return
}
@@ -1164,6 +1221,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
<script>window.close();</script></body></html>`)
resultChan <- AuthCodeCallbackResult{Code: code, State: state}
close(doneChan)
})
server.Handler = mux
@@ -1178,7 +1236,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte
select {
case <-ctx.Done():
case <-time.After(10 * time.Minute):
case <-resultChan:
case <-doneChan:
}
_ = server.Shutdown(context.Background())
}()
@@ -1227,8 +1285,54 @@ func (c *SSOOIDCClient) CreateTokenWithAuthCode(ctx context.Context, clientID, c
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
log.Debugf("create token with auth code failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
}
var result CreateTokenResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
return &result, nil
}
func (c *SSOOIDCClient) CreateTokenWithAuthCodeAndRegion(ctx context.Context, clientID, clientSecret, code, codeVerifier, redirectURI, region string) (*CreateTokenResponse, error) {
endpoint := getOIDCEndpoint(region)
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"code": code,
"codeVerifier": codeVerifier,
"redirectUri": redirectURI,
"grantType": "authorization_code",
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
SetOIDCHeaders(req)
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -1352,12 +1456,118 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo
fmt.Println("\n✓ Authentication successful!")
// Step 8: Get profile ARN
fmt.Println("Fetching profile information...")
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
if email != "" {
fmt.Printf(" Logged in as: %s\n", email)
}
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: "", // Builder ID has no profile
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "builder-id",
Provider: "AWS",
ClientID: regResp.ClientID,
ClientSecret: regResp.ClientSecret,
Email: email,
Region: defaultIDCRegion,
}, nil
}
}
func (c *SSOOIDCClient) LoginWithIDCAuthCode(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
fmt.Println("║ Kiro Authentication (AWS IDC - Auth Code) ║")
fmt.Println("╚══════════════════════════════════════════════════════════╝")
if region == "" {
region = defaultIDCRegion
}
codeVerifier, codeChallenge, err := generatePKCEForAuthCode()
if err != nil {
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
}
state, err := generateStateForAuthCode()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
}
fmt.Println("\nStarting callback server...")
redirectURI, resultChan, err := c.startAuthCodeCallbackServer(ctx, state)
if err != nil {
return nil, fmt.Errorf("failed to start callback server: %w", err)
}
log.Debugf("Callback server started, redirect URI: %s", redirectURI)
fmt.Println("Registering client...")
regResp, err := c.RegisterClientForAuthCodeWithIDC(ctx, redirectURI, startURL, region)
if err != nil {
return nil, fmt.Errorf("failed to register client: %w", err)
}
log.Debugf("Client registered: %s", regResp.ClientID)
endpoint := getOIDCEndpoint(region)
scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations,codewhisperer:transformations,codewhisperer:taskassist"
authURL := buildAuthorizationURL(endpoint, regResp.ClientID, redirectURI, scopes, state, codeChallenge)
fmt.Println("\n════════════════════════════════════════════════════════════")
fmt.Println(" Opening browser for authentication...")
fmt.Println("════════════════════════════════════════════════════════════")
fmt.Printf("\n URL: %s\n\n", authURL)
if c.cfg != nil {
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
} else {
browser.SetIncognitoMode(true)
}
if err := browser.OpenURL(authURL); err != nil {
log.Warnf("Could not open browser automatically: %v", err)
fmt.Println(" ⚠ Could not open browser automatically.")
fmt.Println(" Please open the URL above in your browser manually.")
} else {
fmt.Println(" (Browser opened automatically)")
}
fmt.Println("\n Waiting for authorization callback...")
select {
case <-ctx.Done():
browser.CloseBrowser()
return nil, ctx.Err()
case <-time.After(10 * time.Minute):
browser.CloseBrowser()
return nil, fmt.Errorf("authorization timed out")
case result := <-resultChan:
if result.Error != "" {
browser.CloseBrowser()
return nil, fmt.Errorf("authorization failed: %s", result.Error)
}
fmt.Println("\n✓ Authorization received!")
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser: %v", err)
}
fmt.Println("Exchanging code for tokens...")
tokenResp, err := c.CreateTokenWithAuthCodeAndRegion(ctx, regResp.ClientID, regResp.ClientSecret, result.Code, codeVerifier, redirectURI, region)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
}
fmt.Println("\n✓ Authentication successful!")
fmt.Println("Fetching profile information...")
profileArn := c.FetchProfileArn(ctx, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
if email != "" {
fmt.Printf(" Logged in as: %s\n", email)
}
@@ -1369,12 +1579,25 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo
RefreshToken: tokenResp.RefreshToken,
ProfileArn: profileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "builder-id",
AuthMethod: "idc",
Provider: "AWS",
ClientID: regResp.ClientID,
ClientSecret: regResp.ClientSecret,
Email: email,
Region: defaultIDCRegion,
StartURL: startURL,
Region: region,
}, nil
}
}
func buildAuthorizationURL(endpoint, clientID, redirectURI, scopes, state, codeChallenge string) string {
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", clientID)
params.Set("redirect_uri", redirectURI)
params.Set("scopes", scopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
return fmt.Sprintf("%s/authorize?%s", endpoint, params.Encode())
}

View File

@@ -0,0 +1,261 @@
package kiro
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
type recordingRoundTripper struct {
lastReq *http.Request
}
func (rt *recordingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
rt.lastReq = req
body := `{"nextToken":null,"profiles":[{"arn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC","profileName":"test"}]}`
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(body)),
Header: make(http.Header),
}, nil
}
func TestTryListAvailableProfiles_UsesClientIDForAccountKey(t *testing.T) {
rt := &recordingRoundTripper{}
client := &SSOOIDCClient{
httpClient: &http.Client{Transport: rt},
}
profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "client-id-123", "refresh-token-456")
if profileArn == "" {
t.Fatal("expected profileArn, got empty result")
}
accountKey := GetAccountKey("client-id-123", "refresh-token-456")
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash)
got := rt.lastReq.Header.Get("X-Amz-User-Agent")
if got != expected {
t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected)
}
}
func TestTryListAvailableProfiles_UsesRefreshTokenWhenClientIDMissing(t *testing.T) {
rt := &recordingRoundTripper{}
client := &SSOOIDCClient{
httpClient: &http.Client{Transport: rt},
}
profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "", "refresh-token-789")
if profileArn == "" {
t.Fatal("expected profileArn, got empty result")
}
accountKey := GetAccountKey("", "refresh-token-789")
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash)
got := rt.lastReq.Header.Get("X-Amz-User-Agent")
if got != expected {
t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected)
}
}
func TestRegisterClientForAuthCodeWithIDC(t *testing.T) {
var capturedReq struct {
Method string
Path string
Headers http.Header
Body map[string]interface{}
}
mockResp := RegisterClientResponse{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
ClientIDIssuedAt: 1700000000,
ClientSecretExpiresAt: 1700086400,
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedReq.Method = r.Method
capturedReq.Path = r.URL.Path
capturedReq.Headers = r.Header.Clone()
bodyBytes, _ := io.ReadAll(r.Body)
json.Unmarshal(bodyBytes, &capturedReq.Body)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(mockResp)
}))
defer ts.Close()
// Extract host to build a region that resolves to our test server.
// Override getOIDCEndpoint by passing region="" and patching the endpoint.
// Since getOIDCEndpoint builds "https://oidc.{region}.amazonaws.com", we
// instead inject the test server URL directly via a custom HTTP client transport.
client := &SSOOIDCClient{
httpClient: ts.Client(),
}
// We need to route the request to our test server. Use a transport that rewrites the URL.
client.httpClient.Transport = &rewriteTransport{
base: ts.Client().Transport,
targetURL: ts.URL,
}
resp, err := client.RegisterClientForAuthCodeWithIDC(
context.Background(),
"http://127.0.0.1:19877/oauth/callback",
"https://my-idc-instance.awsapps.com/start",
"us-east-1",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify request method and path
if capturedReq.Method != http.MethodPost {
t.Errorf("method = %q, want POST", capturedReq.Method)
}
if capturedReq.Path != "/client/register" {
t.Errorf("path = %q, want /client/register", capturedReq.Path)
}
// Verify headers
if ct := capturedReq.Headers.Get("Content-Type"); ct != "application/json" {
t.Errorf("Content-Type = %q, want application/json", ct)
}
ua := capturedReq.Headers.Get("User-Agent")
if !strings.Contains(ua, "KiroIDE") {
t.Errorf("User-Agent %q does not contain KiroIDE", ua)
}
if !strings.Contains(ua, "sso-oidc") {
t.Errorf("User-Agent %q does not contain sso-oidc", ua)
}
xua := capturedReq.Headers.Get("X-Amz-User-Agent")
if !strings.Contains(xua, "KiroIDE") {
t.Errorf("x-amz-user-agent %q does not contain KiroIDE", xua)
}
// Verify body fields
if v, _ := capturedReq.Body["clientName"].(string); v != "Kiro IDE" {
t.Errorf("clientName = %q, want %q", v, "Kiro IDE")
}
if v, _ := capturedReq.Body["clientType"].(string); v != "public" {
t.Errorf("clientType = %q, want %q", v, "public")
}
if v, _ := capturedReq.Body["issuerUrl"].(string); v != "https://my-idc-instance.awsapps.com/start" {
t.Errorf("issuerUrl = %q, want %q", v, "https://my-idc-instance.awsapps.com/start")
}
// Verify scopes array
scopesRaw, ok := capturedReq.Body["scopes"].([]interface{})
if !ok || len(scopesRaw) != 5 {
t.Fatalf("scopes: got %v, want 5-element array", capturedReq.Body["scopes"])
}
expectedScopes := []string{
"codewhisperer:completions", "codewhisperer:analysis",
"codewhisperer:conversations", "codewhisperer:transformations",
"codewhisperer:taskassist",
}
for i, s := range expectedScopes {
if scopesRaw[i].(string) != s {
t.Errorf("scopes[%d] = %q, want %q", i, scopesRaw[i], s)
}
}
// Verify grantTypes
grantTypesRaw, ok := capturedReq.Body["grantTypes"].([]interface{})
if !ok || len(grantTypesRaw) != 2 {
t.Fatalf("grantTypes: got %v, want 2-element array", capturedReq.Body["grantTypes"])
}
if grantTypesRaw[0].(string) != "authorization_code" || grantTypesRaw[1].(string) != "refresh_token" {
t.Errorf("grantTypes = %v, want [authorization_code, refresh_token]", grantTypesRaw)
}
// Verify redirectUris
redirectRaw, ok := capturedReq.Body["redirectUris"].([]interface{})
if !ok || len(redirectRaw) != 1 {
t.Fatalf("redirectUris: got %v, want 1-element array", capturedReq.Body["redirectUris"])
}
if redirectRaw[0].(string) != "http://127.0.0.1:19877/oauth/callback" {
t.Errorf("redirectUris[0] = %q, want %q", redirectRaw[0], "http://127.0.0.1:19877/oauth/callback")
}
// Verify response parsing
if resp.ClientID != "test-client-id" {
t.Errorf("ClientID = %q, want %q", resp.ClientID, "test-client-id")
}
if resp.ClientSecret != "test-client-secret" {
t.Errorf("ClientSecret = %q, want %q", resp.ClientSecret, "test-client-secret")
}
if resp.ClientIDIssuedAt != 1700000000 {
t.Errorf("ClientIDIssuedAt = %d, want %d", resp.ClientIDIssuedAt, 1700000000)
}
if resp.ClientSecretExpiresAt != 1700086400 {
t.Errorf("ClientSecretExpiresAt = %d, want %d", resp.ClientSecretExpiresAt, 1700086400)
}
}
// rewriteTransport redirects all requests to the test server URL.
type rewriteTransport struct {
base http.RoundTripper
targetURL string
}
func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
target, _ := url.Parse(t.targetURL)
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
if t.base != nil {
return t.base.RoundTrip(req)
}
return http.DefaultTransport.RoundTrip(req)
}
func TestBuildAuthorizationURL(t *testing.T) {
scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations,codewhisperer:transformations,codewhisperer:taskassist"
endpoint := "https://oidc.us-east-1.amazonaws.com"
redirectURI := "http://127.0.0.1:19877/oauth/callback"
authURL := buildAuthorizationURL(endpoint, "test-client-id", redirectURI, scopes, "random-state", "test-challenge")
// Verify colons and commas in scopes are percent-encoded
if !strings.Contains(authURL, "codewhisperer%3Acompletions") {
t.Errorf("expected colons in scopes to be percent-encoded, got: %s", authURL)
}
if !strings.Contains(authURL, "completions%2Ccodewhisperer") {
t.Errorf("expected commas in scopes to be percent-encoded, got: %s", authURL)
}
// Parse back and verify all parameters round-trip correctly
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("failed to parse auth URL: %v", err)
}
if !strings.HasPrefix(authURL, endpoint+"/authorize?") {
t.Errorf("expected URL to start with %s/authorize?, got: %s", endpoint, authURL)
}
q := parsed.Query()
checks := map[string]string{
"response_type": "code",
"client_id": "test-client-id",
"redirect_uri": redirectURI,
"scopes": scopes,
"state": "random-state",
"code_challenge": "test-challenge",
"code_challenge_method": "S256",
}
for key, want := range checks {
if got := q.Get(key); got != want {
t.Errorf("%s = %q, want %q", key, got, want)
}
}
}

View File

@@ -29,7 +29,7 @@ type KiroTokenStorage struct {
ClientID string `json:"client_id,omitempty"`
// ClientSecret is the OAuth client secret (required for token refresh)
ClientSecret string `json:"client_secret,omitempty"`
// Region is the AWS region
// Region is the OIDC region for IDC login and token refresh
Region string `json:"region,omitempty"`
// StartURL is the AWS Identity Center start URL (for IDC auth)
StartURL string `json:"start_url,omitempty"`

View File

@@ -200,36 +200,22 @@ func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) {
}
// 解析各字段
if v, ok := metadata["access_token"].(string); ok {
token.AccessToken = v
}
if v, ok := metadata["refresh_token"].(string); ok {
token.RefreshToken = v
}
if v, ok := metadata["client_id"].(string); ok {
token.ClientID = v
}
if v, ok := metadata["client_secret"].(string); ok {
token.ClientSecret = v
}
if v, ok := metadata["region"].(string); ok {
token.Region = v
}
if v, ok := metadata["start_url"].(string); ok {
token.StartURL = v
}
if v, ok := metadata["provider"].(string); ok {
token.Provider = v
}
token.AccessToken, _ = metadata["access_token"].(string)
token.RefreshToken, _ = metadata["refresh_token"].(string)
token.ClientID, _ = metadata["client_id"].(string)
token.ClientSecret, _ = metadata["client_secret"].(string)
token.Region, _ = metadata["region"].(string)
token.StartURL, _ = metadata["start_url"].(string)
token.Provider, _ = metadata["provider"].(string)
// 解析时间字段
if v, ok := metadata["expires_at"].(string); ok {
if t, err := time.Parse(time.RFC3339, v); err == nil {
if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" {
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
token.ExpiresAt = t
}
}
if v, ok := metadata["last_refresh"].(string); ok {
if t, err := time.Parse(time.RFC3339, v); err == nil {
if lastRefreshStr, ok := metadata["last_refresh"].(string); ok && lastRefreshStr != "" {
if t, err := time.Parse(time.RFC3339, lastRefreshStr); err == nil {
token.LastVerified = t
}
}

View File

@@ -8,7 +8,6 @@ import (
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -51,14 +50,12 @@ type QuotaStatus struct {
// UsageChecker provides methods for checking token quota usage.
type UsageChecker struct {
httpClient *http.Client
endpoint string
}
// NewUsageChecker creates a new UsageChecker instance.
func NewUsageChecker(cfg *config.Config) *UsageChecker {
return &UsageChecker{
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
endpoint: awsKiroEndpoint,
}
}
@@ -66,7 +63,6 @@ func NewUsageChecker(cfg *config.Config) *UsageChecker {
func NewUsageCheckerWithClient(client *http.Client) *UsageChecker {
return &UsageChecker{
httpClient: client,
endpoint: awsKiroEndpoint,
}
}
@@ -80,26 +76,23 @@ func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData)
return nil, fmt.Errorf("access token is empty")
}
payload := map[string]interface{}{
queryParams := map[string]string{
"origin": "AI_EDITOR",
"profileArn": tokenData.ProfileArn,
"resourceType": "AGENTIC_REQUEST",
}
jsonBody, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
// Use endpoint from profileArn if available
endpoint := GetKiroAPIEndpointFromProfileArn(tokenData.ProfileArn)
url := buildURL(endpoint, pathGetUsageLimits, queryParams)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(string(jsonBody)))
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
req.Header.Set("x-amz-target", targetGetUsage)
req.Header.Set("Authorization", "Bearer "+tokenData.AccessToken)
req.Header.Set("Accept", "application/json")
accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken)
setRuntimeHeaders(req, tokenData.AccessToken, accountKey)
resp, err := c.httpClient.Do(req)
if err != nil {

View File

@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
Type string `json:"type"`
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -206,3 +206,52 @@ func DoKiroImport(cfg *config.Config, options *LoginOptions) {
}
fmt.Println("Kiro token import successful!")
}
func DoKiroIDCLogin(cfg *config.Config, options *LoginOptions, startURL, region, flow string) {
if options == nil {
options = &LoginOptions{}
}
if startURL == "" {
log.Errorf("Kiro IDC login requires --kiro-idc-start-url")
fmt.Println("\nUsage: --kiro-idc-login --kiro-idc-start-url https://d-xxx.awsapps.com/start")
return
}
manager := newAuthManager()
authenticator := sdkAuth.NewKiroAuthenticator()
metadata := map[string]string{
"start-url": startURL,
"region": region,
"flow": flow,
}
record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: metadata,
Prompt: options.Prompt,
})
if err != nil {
log.Errorf("Kiro IDC authentication failed: %v", err)
fmt.Println("\nTroubleshooting:")
fmt.Println("1. Make sure your IDC Start URL is correct")
fmt.Println("2. Complete the authorization in the browser")
fmt.Println("3. If auth code flow fails, try: --kiro-idc-flow device")
return
}
savedPath, err := manager.SaveAuth(record, cfg)
if err != nil {
log.Errorf("Failed to save auth: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("Kiro IDC authentication successful!")
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
@@ -27,11 +28,8 @@ import (
)
const (
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
geminiCLIVersion = "v1internal"
geminiCLIUserAgent = "google-api-nodejs-client/9.15.1"
geminiCLIApiClient = "gl-node/22.17.0"
geminiCLIClientMetadata = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
geminiCLIVersion = "v1internal"
)
type projectSelectionRequiredError struct{}
@@ -409,9 +407,7 @@ func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string
return fmt.Errorf("create request: %w", errRequest)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", geminiCLIUserAgent)
req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient)
req.Header.Set("Client-Metadata", geminiCLIClientMetadata)
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
resp, errDo := httpClient.Do(req)
if errDo != nil {
@@ -630,7 +626,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
return false, fmt.Errorf("failed to create request: %w", errRequest)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", geminiCLIUserAgent)
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
resp, errDo := httpClient.Do(req)
if errDo != nil {
return false, fmt.Errorf("failed to execute request: %w", errDo)
@@ -651,7 +647,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
return false, fmt.Errorf("failed to create request: %w", errRequest)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", geminiCLIUserAgent)
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
resp, errDo = httpClient.Do(req)
if errDo != nil {
return false, fmt.Errorf("failed to execute request: %w", errDo)

View File

@@ -0,0 +1,60 @@
package cmd
import (
"context"
"errors"
"fmt"
"os"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
const (
codexLoginModeMetadataKey = "codex_login_mode"
codexLoginModeDevice = "device"
)
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
// existing codex-login OAuth callback flow intact.
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{
codexLoginModeMetadataKey: codexLoginModeDevice,
},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
if err != nil {
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
log.Error(codex.GetUserFriendlyMessage(authErr))
if authErr.Type == codex.ErrPortInUse.Type {
os.Exit(codex.ErrPortInUse.Code)
}
return
}
fmt.Printf("Codex device authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("Codex device authentication successful!")
}

View File

@@ -69,6 +69,9 @@ type Config struct {
// RequestRetry defines the retry times when the request failed.
RequestRetry int `yaml:"request-retry" json:"request-retry"`
// MaxRetryCredentials defines the maximum number of credentials to try for a failed request.
// Set to 0 or a negative value to keep trying all available credentials (legacy behavior).
MaxRetryCredentials int `yaml:"max-retry-credentials" json:"max-retry-credentials"`
// MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential.
MaxRetryInterval int `yaml:"max-retry-interval" json:"max-retry-interval"`
@@ -87,6 +90,10 @@ type Config struct {
// KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations.
KiroKey []KiroKey `yaml:"kiro" json:"kiro"`
// KiroFingerprint defines a global fingerprint configuration for all Kiro requests.
// When set, all Kiro requests will use this fixed fingerprint instead of random generation.
KiroFingerprint *KiroFingerprintConfig `yaml:"kiro-fingerprint,omitempty" json:"kiro-fingerprint,omitempty"`
// KiroPreferredEndpoint sets the global default preferred endpoint for all Kiro providers.
// Values: "ide" (default, CodeWhisperer) or "cli" (Amazon Q).
KiroPreferredEndpoint string `yaml:"kiro-preferred-endpoint" json:"kiro-preferred-endpoint"`
@@ -477,6 +484,9 @@ type KiroKey struct {
// Region is the AWS region (default: us-east-1).
Region string `yaml:"region,omitempty" json:"region,omitempty"`
// StartURL is the IAM Identity Center (IDC) start URL for SSO login.
StartURL string `yaml:"start-url,omitempty" json:"start-url,omitempty"`
// ProxyURL optionally overrides the global proxy for this configuration.
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
@@ -489,6 +499,20 @@ type KiroKey struct {
PreferredEndpoint string `yaml:"preferred-endpoint,omitempty" json:"preferred-endpoint,omitempty"`
}
// KiroFingerprintConfig defines a global fingerprint configuration for Kiro requests.
// When configured, all Kiro requests will use this fixed fingerprint instead of random generation.
// Empty fields will fall back to random selection from built-in pools.
type KiroFingerprintConfig struct {
OIDCSDKVersion string `yaml:"oidc-sdk-version,omitempty" json:"oidc-sdk-version,omitempty"`
RuntimeSDKVersion string `yaml:"runtime-sdk-version,omitempty" json:"runtime-sdk-version,omitempty"`
StreamingSDKVersion string `yaml:"streaming-sdk-version,omitempty" json:"streaming-sdk-version,omitempty"`
OSType string `yaml:"os-type,omitempty" json:"os-type,omitempty"`
OSVersion string `yaml:"os-version,omitempty" json:"os-version,omitempty"`
NodeVersion string `yaml:"node-version,omitempty" json:"node-version,omitempty"`
KiroVersion string `yaml:"kiro-version,omitempty" json:"kiro-version,omitempty"`
KiroHash string `yaml:"kiro-hash,omitempty" json:"kiro-hash,omitempty"`
}
// OpenAICompatibility represents the configuration for OpenAI API compatibility
// with external providers, allowing model aliases to be routed through OpenAI API format.
type OpenAICompatibility struct {
@@ -555,16 +579,6 @@ func LoadConfig(configFile string) (*Config, error) {
// If optional is true and the file is missing, it returns an empty Config.
// If optional is true and the file is empty or invalid, it returns an empty Config.
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// NOTE: Startup oauth-model-alias migration is intentionally disabled.
// Reason: avoid mutating config.yaml during server startup.
// Re-enable the block below if automatic startup migration is needed again.
// if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
// // Log warning but don't fail - config loading should still work
// fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
// } else if migrated {
// fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
// }
// Read the entire configuration file into memory.
data, err := os.ReadFile(configFile)
if err != nil {
@@ -652,6 +666,10 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.ErrorLogsMaxFiles = 10
}
if cfg.MaxRetryCredentials < 0 {
cfg.MaxRetryCredentials = 0
}
// Sanitize Gemini API key configuration and migrate legacy entries.
cfg.SanitizeGeminiKeys()
@@ -1648,9 +1666,6 @@ func pruneMappingToGeneratedKeys(dstRoot, srcRoot *yaml.Node, key string) {
srcIdx := findMapKeyIndex(srcRoot, key)
if srcIdx < 0 {
// Keep an explicit empty mapping for oauth-model-alias when it was previously present.
//
// Rationale: LoadConfig runs MigrateOAuthModelAlias before unmarshalling. If the
// oauth-model-alias key is missing, migration will add the default antigravity aliases.
// When users delete the last channel from oauth-model-alias via the management API,
// we want that deletion to persist across hot reloads and restarts.
if key == "oauth-model-alias" {

View File

@@ -1,314 +0,0 @@
package config
import (
"os"
"strings"
"gopkg.in/yaml.v3"
)
// antigravityModelConversionTable maps old built-in aliases to actual model names
// for the antigravity channel during migration.
var antigravityModelConversionTable = map[string]string{
"gemini-2.5-computer-use-preview-10-2025": "rev19-uic3-1p",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-claude-sonnet-4-5": "claude-sonnet-4-5",
"gemini-claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
"gemini-claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
"gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
}
// defaultKiroAliases returns the default oauth-model-alias configuration
// for the kiro channel. Maps kiro-prefixed model names to standard Claude model
// names so that clients like Claude Code can use standard names directly.
func defaultKiroAliases() []OAuthModelAlias {
return []OAuthModelAlias{
// Sonnet 4.5
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
// Sonnet 4
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
// Opus 4.6
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
// Opus 4.5
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
// Haiku 4.5
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
}
}
// defaultGitHubCopilotAliases returns default oauth-model-alias entries that
// expose Claude hyphen-style IDs for GitHub Copilot Claude models.
// This keeps compatibility with clients (e.g. Claude Code) that use
// Anthropic-style model IDs like "claude-opus-4-6".
func defaultGitHubCopilotAliases() []OAuthModelAlias {
return []OAuthModelAlias{
{Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true},
{Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true},
{Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true},
{Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true},
{Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true},
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
}
}
// defaultAntigravityAliases returns the default oauth-model-alias configuration
// for the antigravity channel when neither field exists.
func defaultAntigravityAliases() []OAuthModelAlias {
return []OAuthModelAlias{
{Name: "rev19-uic3-1p", Alias: "gemini-2.5-computer-use-preview-10-2025"},
{Name: "gemini-3-pro-image", Alias: "gemini-3-pro-image-preview"},
{Name: "gemini-3-pro-high", Alias: "gemini-3-pro-preview"},
{Name: "gemini-3-flash", Alias: "gemini-3-flash-preview"},
{Name: "claude-sonnet-4-5", Alias: "gemini-claude-sonnet-4-5"},
{Name: "claude-sonnet-4-5-thinking", Alias: "gemini-claude-sonnet-4-5-thinking"},
{Name: "claude-opus-4-5-thinking", Alias: "gemini-claude-opus-4-5-thinking"},
{Name: "claude-opus-4-6-thinking", Alias: "gemini-claude-opus-4-6-thinking"},
}
}
// MigrateOAuthModelAlias checks for and performs migration from oauth-model-mappings
// to oauth-model-alias at startup. Returns true if migration was performed.
//
// Migration flow:
// 1. Check if oauth-model-alias exists -> skip migration
// 2. Check if oauth-model-mappings exists -> convert and migrate
// - For antigravity channel, convert old built-in aliases to actual model names
//
// 3. Neither exists -> add default antigravity config
func MigrateOAuthModelAlias(configFile string) (bool, error) {
data, err := os.ReadFile(configFile)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
if len(data) == 0 {
return false, nil
}
// Parse YAML into node tree to preserve structure
var root yaml.Node
if err := yaml.Unmarshal(data, &root); err != nil {
return false, nil
}
if root.Kind != yaml.DocumentNode || len(root.Content) == 0 {
return false, nil
}
rootMap := root.Content[0]
if rootMap == nil || rootMap.Kind != yaml.MappingNode {
return false, nil
}
// Check if oauth-model-alias already exists
if findMapKeyIndex(rootMap, "oauth-model-alias") >= 0 {
return false, nil
}
// Check if oauth-model-mappings exists
oldIdx := findMapKeyIndex(rootMap, "oauth-model-mappings")
if oldIdx >= 0 {
// Migrate from old field
return migrateFromOldField(configFile, &root, rootMap, oldIdx)
}
// Neither field exists - add default antigravity config
return addDefaultAntigravityConfig(configFile, &root, rootMap)
}
// migrateFromOldField converts oauth-model-mappings to oauth-model-alias
func migrateFromOldField(configFile string, root *yaml.Node, rootMap *yaml.Node, oldIdx int) (bool, error) {
if oldIdx+1 >= len(rootMap.Content) {
return false, nil
}
oldValue := rootMap.Content[oldIdx+1]
if oldValue == nil || oldValue.Kind != yaml.MappingNode {
return false, nil
}
// Parse the old aliases
oldAliases := parseOldAliasNode(oldValue)
if len(oldAliases) == 0 {
// Remove the old field and write
removeMapKeyByIndex(rootMap, oldIdx)
return writeYAMLNode(configFile, root)
}
// Convert model names for antigravity channel
newAliases := make(map[string][]OAuthModelAlias, len(oldAliases))
for channel, entries := range oldAliases {
converted := make([]OAuthModelAlias, 0, len(entries))
for _, entry := range entries {
newEntry := OAuthModelAlias{
Name: entry.Name,
Alias: entry.Alias,
Fork: entry.Fork,
}
// Convert model names for antigravity channel
if strings.EqualFold(channel, "antigravity") {
if actual, ok := antigravityModelConversionTable[entry.Name]; ok {
newEntry.Name = actual
}
}
converted = append(converted, newEntry)
}
newAliases[channel] = converted
}
// For antigravity channel, supplement missing default aliases
if antigravityEntries, exists := newAliases["antigravity"]; exists {
// Build a set of already configured model names (upstream names)
configuredModels := make(map[string]bool, len(antigravityEntries))
for _, entry := range antigravityEntries {
configuredModels[entry.Name] = true
}
// Add missing default aliases
for _, defaultAlias := range defaultAntigravityAliases() {
if !configuredModels[defaultAlias.Name] {
antigravityEntries = append(antigravityEntries, defaultAlias)
}
}
newAliases["antigravity"] = antigravityEntries
}
// Build new node
newNode := buildOAuthModelAliasNode(newAliases)
// Replace old key with new key and value
rootMap.Content[oldIdx].Value = "oauth-model-alias"
rootMap.Content[oldIdx+1] = newNode
return writeYAMLNode(configFile, root)
}
// addDefaultAntigravityConfig adds the default antigravity configuration
func addDefaultAntigravityConfig(configFile string, root *yaml.Node, rootMap *yaml.Node) (bool, error) {
defaults := map[string][]OAuthModelAlias{
"antigravity": defaultAntigravityAliases(),
}
newNode := buildOAuthModelAliasNode(defaults)
// Add new key-value pair
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "oauth-model-alias"}
rootMap.Content = append(rootMap.Content, keyNode, newNode)
return writeYAMLNode(configFile, root)
}
// parseOldAliasNode parses the old oauth-model-mappings node structure
func parseOldAliasNode(node *yaml.Node) map[string][]OAuthModelAlias {
if node == nil || node.Kind != yaml.MappingNode {
return nil
}
result := make(map[string][]OAuthModelAlias)
for i := 0; i+1 < len(node.Content); i += 2 {
channelNode := node.Content[i]
entriesNode := node.Content[i+1]
if channelNode == nil || entriesNode == nil {
continue
}
channel := strings.ToLower(strings.TrimSpace(channelNode.Value))
if channel == "" || entriesNode.Kind != yaml.SequenceNode {
continue
}
entries := make([]OAuthModelAlias, 0, len(entriesNode.Content))
for _, entryNode := range entriesNode.Content {
if entryNode == nil || entryNode.Kind != yaml.MappingNode {
continue
}
entry := parseAliasEntry(entryNode)
if entry.Name != "" && entry.Alias != "" {
entries = append(entries, entry)
}
}
if len(entries) > 0 {
result[channel] = entries
}
}
return result
}
// parseAliasEntry parses a single alias entry node
func parseAliasEntry(node *yaml.Node) OAuthModelAlias {
var entry OAuthModelAlias
for i := 0; i+1 < len(node.Content); i += 2 {
keyNode := node.Content[i]
valNode := node.Content[i+1]
if keyNode == nil || valNode == nil {
continue
}
switch strings.ToLower(strings.TrimSpace(keyNode.Value)) {
case "name":
entry.Name = strings.TrimSpace(valNode.Value)
case "alias":
entry.Alias = strings.TrimSpace(valNode.Value)
case "fork":
entry.Fork = strings.ToLower(strings.TrimSpace(valNode.Value)) == "true"
}
}
return entry
}
// buildOAuthModelAliasNode creates a YAML node for oauth-model-alias
func buildOAuthModelAliasNode(aliases map[string][]OAuthModelAlias) *yaml.Node {
node := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
for channel, entries := range aliases {
channelNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: channel}
entriesNode := &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"}
for _, entry := range entries {
entryNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
entryNode.Content = append(entryNode.Content,
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "name"},
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Name},
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "alias"},
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Alias},
)
if entry.Fork {
entryNode.Content = append(entryNode.Content,
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "fork"},
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: "true"},
)
}
entriesNode.Content = append(entriesNode.Content, entryNode)
}
node.Content = append(node.Content, channelNode, entriesNode)
}
return node
}
// removeMapKeyByIndex removes a key-value pair from a mapping node by index
func removeMapKeyByIndex(mapNode *yaml.Node, keyIdx int) {
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
return
}
if keyIdx < 0 || keyIdx+1 >= len(mapNode.Content) {
return
}
mapNode.Content = append(mapNode.Content[:keyIdx], mapNode.Content[keyIdx+2:]...)
}
// writeYAMLNode writes the YAML node tree back to file
func writeYAMLNode(configFile string, root *yaml.Node) (bool, error) {
f, err := os.Create(configFile)
if err != nil {
return false, err
}
defer f.Close()
enc := yaml.NewEncoder(f)
enc.SetIndent(2)
if err := enc.Encode(root); err != nil {
return false, err
}
if err := enc.Close(); err != nil {
return false, err
}
return true, nil
}

View File

@@ -1,245 +0,0 @@
package config
import (
"os"
"path/filepath"
"strings"
"testing"
"gopkg.in/yaml.v3"
)
func TestMigrateOAuthModelAlias_SkipsIfNewFieldExists(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
content := `oauth-model-alias:
gemini-cli:
- name: "gemini-2.5-pro"
alias: "g2.5p"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if migrated {
t.Fatal("expected no migration when oauth-model-alias already exists")
}
// Verify file unchanged
data, _ := os.ReadFile(configFile)
if !strings.Contains(string(data), "oauth-model-alias:") {
t.Fatal("file should still contain oauth-model-alias")
}
}
func TestMigrateOAuthModelAlias_MigratesOldField(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
content := `oauth-model-mappings:
gemini-cli:
- name: "gemini-2.5-pro"
alias: "g2.5p"
fork: true
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !migrated {
t.Fatal("expected migration to occur")
}
// Verify new field exists and old field removed
data, _ := os.ReadFile(configFile)
if strings.Contains(string(data), "oauth-model-mappings:") {
t.Fatal("old field should be removed")
}
if !strings.Contains(string(data), "oauth-model-alias:") {
t.Fatal("new field should exist")
}
// Parse and verify structure
var root yaml.Node
if err := yaml.Unmarshal(data, &root); err != nil {
t.Fatal(err)
}
}
func TestMigrateOAuthModelAlias_ConvertsAntigravityModels(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
// Use old model names that should be converted
content := `oauth-model-mappings:
antigravity:
- name: "gemini-2.5-computer-use-preview-10-2025"
alias: "computer-use"
- name: "gemini-3-pro-preview"
alias: "g3p"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !migrated {
t.Fatal("expected migration to occur")
}
// Verify model names were converted
data, _ := os.ReadFile(configFile)
content = string(data)
if !strings.Contains(content, "rev19-uic3-1p") {
t.Fatal("expected gemini-2.5-computer-use-preview-10-2025 to be converted to rev19-uic3-1p")
}
if !strings.Contains(content, "gemini-3-pro-high") {
t.Fatal("expected gemini-3-pro-preview to be converted to gemini-3-pro-high")
}
// Verify missing default aliases were supplemented
if !strings.Contains(content, "gemini-3-pro-image") {
t.Fatal("expected missing default alias gemini-3-pro-image to be added")
}
if !strings.Contains(content, "gemini-3-flash") {
t.Fatal("expected missing default alias gemini-3-flash to be added")
}
if !strings.Contains(content, "claude-sonnet-4-5") {
t.Fatal("expected missing default alias claude-sonnet-4-5 to be added")
}
if !strings.Contains(content, "claude-sonnet-4-5-thinking") {
t.Fatal("expected missing default alias claude-sonnet-4-5-thinking to be added")
}
if !strings.Contains(content, "claude-opus-4-5-thinking") {
t.Fatal("expected missing default alias claude-opus-4-5-thinking to be added")
}
if !strings.Contains(content, "claude-opus-4-6-thinking") {
t.Fatal("expected missing default alias claude-opus-4-6-thinking to be added")
}
}
func TestMigrateOAuthModelAlias_AddsDefaultIfNeitherExists(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
content := `debug: true
port: 8080
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !migrated {
t.Fatal("expected migration to add default config")
}
// Verify default antigravity config was added
data, _ := os.ReadFile(configFile)
content = string(data)
if !strings.Contains(content, "oauth-model-alias:") {
t.Fatal("expected oauth-model-alias to be added")
}
if !strings.Contains(content, "antigravity:") {
t.Fatal("expected antigravity channel to be added")
}
if !strings.Contains(content, "rev19-uic3-1p") {
t.Fatal("expected default antigravity aliases to include rev19-uic3-1p")
}
}
func TestMigrateOAuthModelAlias_PreservesOtherConfig(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
content := `debug: true
port: 8080
oauth-model-mappings:
gemini-cli:
- name: "test"
alias: "t"
api-keys:
- "key1"
- "key2"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !migrated {
t.Fatal("expected migration to occur")
}
// Verify other config preserved
data, _ := os.ReadFile(configFile)
content = string(data)
if !strings.Contains(content, "debug: true") {
t.Fatal("expected debug field to be preserved")
}
if !strings.Contains(content, "port: 8080") {
t.Fatal("expected port field to be preserved")
}
if !strings.Contains(content, "api-keys:") {
t.Fatal("expected api-keys field to be preserved")
}
}
func TestMigrateOAuthModelAlias_NonexistentFile(t *testing.T) {
t.Parallel()
migrated, err := MigrateOAuthModelAlias("/nonexistent/path/config.yaml")
if err != nil {
t.Fatalf("unexpected error for nonexistent file: %v", err)
}
if migrated {
t.Fatal("expected no migration for nonexistent file")
}
}
func TestMigrateOAuthModelAlias_EmptyFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(configFile, []byte(""), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if migrated {
t.Fatal("expected no migration for empty file")
}
}

View File

@@ -1 +1 @@
[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude.","cache_control":{"type":"ephemeral"}}]
[{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK.","cache_control":{"type":"ephemeral","ttl":"1h"}}]

View File

@@ -1,6 +1,7 @@
package misc
import (
"encoding/json"
"fmt"
"path/filepath"
"strings"
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
func LogCredentialSeparator() {
log.Debug(credentialSeparator)
}
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
var data map[string]any
// Fast path: if source is already a map, just copy it to avoid mutation of original
if srcMap, ok := source.(map[string]any); ok {
data = make(map[string]any, len(srcMap)+len(metadata))
for k, v := range srcMap {
data[k] = v
}
} else {
// Slow path: marshal to JSON and back to map to respect JSON tags
temp, err := json.Marshal(source)
if err != nil {
return nil, fmt.Errorf("failed to marshal source: %w", err)
}
if err := json.Unmarshal(temp, &data); err != nil {
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
}
}
// Merge extra metadata
if metadata != nil {
if data == nil {
data = make(map[string]any)
}
for k, v := range metadata {
data[k] = v
}
}
return data, nil
}

View File

@@ -4,10 +4,98 @@
package misc
import (
"fmt"
"net/http"
"runtime"
"strings"
)
const (
// GeminiCLIVersion is the version string reported in the User-Agent for upstream requests.
GeminiCLIVersion = "0.31.0"
// GeminiCLIApiClientHeader is the value for the X-Goog-Api-Client header sent to the Gemini CLI upstream.
GeminiCLIApiClientHeader = "google-genai-sdk/1.41.0 gl-node/v22.19.0"
)
// geminiCLIOS maps Go runtime OS names to the Node.js-style platform strings used by Gemini CLI.
func geminiCLIOS() string {
switch runtime.GOOS {
case "windows":
return "win32"
default:
return runtime.GOOS
}
}
// geminiCLIArch maps Go runtime architecture names to the Node.js-style arch strings used by Gemini CLI.
func geminiCLIArch() string {
switch runtime.GOARCH {
case "amd64":
return "x64"
case "386":
return "x86"
default:
return runtime.GOARCH
}
}
// GeminiCLIUserAgent returns a User-Agent string that matches the Gemini CLI format.
// The model parameter is included in the UA; pass "" or "unknown" when the model is not applicable.
func GeminiCLIUserAgent(model string) string {
if model == "" {
model = "unknown"
}
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch())
}
// ScrubProxyAndFingerprintHeaders removes all headers that could reveal
// proxy infrastructure, client identity, or browser fingerprints from an
// outgoing request. This ensures requests to upstream services look like they
// originate directly from a native client rather than a third-party client
// behind a reverse proxy.
func ScrubProxyAndFingerprintHeaders(req *http.Request) {
if req == nil {
return
}
// --- Proxy tracing headers ---
req.Header.Del("X-Forwarded-For")
req.Header.Del("X-Forwarded-Host")
req.Header.Del("X-Forwarded-Proto")
req.Header.Del("X-Forwarded-Port")
req.Header.Del("X-Real-IP")
req.Header.Del("Forwarded")
req.Header.Del("Via")
// --- Client identity headers ---
req.Header.Del("X-Title")
req.Header.Del("X-Stainless-Lang")
req.Header.Del("X-Stainless-Package-Version")
req.Header.Del("X-Stainless-Os")
req.Header.Del("X-Stainless-Arch")
req.Header.Del("X-Stainless-Runtime")
req.Header.Del("X-Stainless-Runtime-Version")
req.Header.Del("Http-Referer")
req.Header.Del("Referer")
// --- Browser / Chromium fingerprint headers ---
// These are sent by Electron-based clients (e.g. CherryStudio) using the
// Fetch API, but NOT by Node.js https module (which Antigravity uses).
req.Header.Del("Sec-Ch-Ua")
req.Header.Del("Sec-Ch-Ua-Mobile")
req.Header.Del("Sec-Ch-Ua-Platform")
req.Header.Del("Sec-Fetch-Mode")
req.Header.Del("Sec-Fetch-Site")
req.Header.Del("Sec-Fetch-Dest")
req.Header.Del("Priority")
// --- Encoding negotiation ---
// Antigravity (Node.js) sends "gzip, deflate, br" by default;
// Electron-based clients may add "zstd" which is a fingerprint mismatch.
req.Header.Del("Accept-Encoding")
}
// EnsureHeader ensures that a header exists in the target header map by checking
// multiple sources in order of priority: source headers, existing target headers,
// and finally the default value. It only sets the header if it's not already present

View File

@@ -696,6 +696,42 @@ func GetKiroModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-deepseek-3-2-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro DeepSeek 3.2 (Agentic)",
Description: "DeepSeek 3.2 optimized for coding agents (chunked writes)",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-minimax-m2-1-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro MiniMax M2.1 (Agentic)",
Description: "MiniMax M2.1 optimized for coding agents (chunked writes)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-qwen3-coder-next-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Qwen3 Coder Next (Agentic)",
Description: "Qwen3 Coder Next optimized for coding agents (chunked writes)",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
}
}

View File

@@ -916,19 +916,12 @@ func GetIFlowModels() []*ModelInfo {
Created int64
Thinking *ThinkingSupport
}{
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "glm-5", DisplayName: "GLM-5", Description: "Zhipu GLM 5 general model", Created: 1770768000, Thinking: iFlowThinkingSupport},
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
@@ -937,11 +930,7 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "minimax-m2.5", DisplayName: "MiniMax-M2.5", Description: "MiniMax M2.5", Created: 1770825600, Thinking: iFlowThinkingSupport},
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
}
models := make([]*ModelInfo, 0, len(entries))
for _, entry := range entries {
@@ -970,21 +959,17 @@ type AntigravityModelConfig struct {
// Keys use upstream model names returned by the Antigravity models endpoint.
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
return map[string]*AntigravityModelConfig{
// "rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}},
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"gpt-oss-120b-medium": {},
"tab_flash_lite_preview": {},
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3.1-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3.1-flash-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}},
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-6": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"gpt-oss-120b-medium": {},
}
}

View File

@@ -47,8 +47,10 @@ type ModelInfo struct {
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
// SupportedParameters lists supported parameters
SupportedParameters []string `json:"supported_parameters,omitempty"`
// SupportedEndpoints lists supported API endpoints (e.g., "/chat/completions", "/responses").
SupportedEndpoints []string `json:"supported_endpoints,omitempty"`
// SupportedInputModalities lists supported input modalities (e.g., TEXT, IMAGE, VIDEO, AUDIO)
SupportedInputModalities []string `json:"supportedInputModalities,omitempty"`
// SupportedOutputModalities lists supported output modalities (e.g., TEXT, IMAGE)
SupportedOutputModalities []string `json:"supportedOutputModalities,omitempty"`
// Thinking holds provider-specific reasoning/thinking budget capabilities.
// This is optional and currently used for Gemini thinking budget normalization.
@@ -501,8 +503,11 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
if len(model.SupportedParameters) > 0 {
copyModel.SupportedParameters = append([]string(nil), model.SupportedParameters...)
}
if len(model.SupportedEndpoints) > 0 {
copyModel.SupportedEndpoints = append([]string(nil), model.SupportedEndpoints...)
if len(model.SupportedInputModalities) > 0 {
copyModel.SupportedInputModalities = append([]string(nil), model.SupportedInputModalities...)
}
if len(model.SupportedOutputModalities) > 0 {
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
}
return &copyModel
}
@@ -1089,6 +1094,12 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
if len(model.SupportedGenerationMethods) > 0 {
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
}
if len(model.SupportedInputModalities) > 0 {
result["supportedInputModalities"] = model.SupportedInputModalities
}
if len(model.SupportedOutputModalities) > 0 {
result["supportedOutputModalities"] = model.SupportedOutputModalities
}
return result
default:

View File

@@ -8,6 +8,7 @@ import (
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
"encoding/binary"
"encoding/json"
"errors"
@@ -45,17 +46,87 @@ const (
antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64"
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second
systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
// systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
)
var (
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
// from any antigravity auth. Empty fetches never overwrite this cache.
antigravityPrimaryModelsCache struct {
mu sync.RWMutex
models []*registry.ModelInfo
}
)
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
if len(models) == 0 {
return nil
}
out := make([]*registry.ModelInfo, 0, len(models))
for _, model := range models {
if model == nil || strings.TrimSpace(model.ID) == "" {
continue
}
out = append(out, cloneAntigravityModelInfo(model))
}
if len(out) == 0 {
return nil
}
return out
}
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
if model == nil {
return nil
}
clone := *model
if len(model.SupportedGenerationMethods) > 0 {
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
}
if len(model.SupportedParameters) > 0 {
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
}
if model.Thinking != nil {
thinkingClone := *model.Thinking
if len(model.Thinking.Levels) > 0 {
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
}
clone.Thinking = &thinkingClone
}
return &clone
}
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
cloned := cloneAntigravityModels(models)
if len(cloned) == 0 {
return false
}
antigravityPrimaryModelsCache.mu.Lock()
antigravityPrimaryModelsCache.models = cloned
antigravityPrimaryModelsCache.mu.Unlock()
return true
}
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
antigravityPrimaryModelsCache.mu.RLock()
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
antigravityPrimaryModelsCache.mu.RUnlock()
return cloned
}
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
models := loadAntigravityPrimaryModels()
if len(models) > 0 {
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
}
return models
}
// AntigravityExecutor proxies requests to the antigravity upstream.
type AntigravityExecutor struct {
cfg *config.Config
@@ -72,6 +143,62 @@ func NewAntigravityExecutor(cfg *config.Config) *AntigravityExecutor {
return &AntigravityExecutor{cfg: cfg}
}
// antigravityTransport is a singleton HTTP/1.1 transport shared by all Antigravity requests.
// It is initialized once via antigravityTransportOnce to avoid leaking a new connection pool
// (and the goroutines managing it) on every request.
var (
antigravityTransport *http.Transport
antigravityTransportOnce sync.Once
)
func cloneTransportWithHTTP11(base *http.Transport) *http.Transport {
if base == nil {
return nil
}
clone := base.Clone()
clone.ForceAttemptHTTP2 = false
// Wipe TLSNextProto to prevent implicit HTTP/2 upgrade.
clone.TLSNextProto = make(map[string]func(authority string, c *tls.Conn) http.RoundTripper)
if clone.TLSClientConfig == nil {
clone.TLSClientConfig = &tls.Config{}
} else {
clone.TLSClientConfig = clone.TLSClientConfig.Clone()
}
// Actively advertise only HTTP/1.1 in the ALPN handshake.
clone.TLSClientConfig.NextProtos = []string{"http/1.1"}
return clone
}
// initAntigravityTransport creates the shared HTTP/1.1 transport exactly once.
func initAntigravityTransport() {
base, ok := http.DefaultTransport.(*http.Transport)
if !ok {
base = &http.Transport{}
}
antigravityTransport = cloneTransportWithHTTP11(base)
}
// newAntigravityHTTPClient creates an HTTP client specifically for Antigravity,
// enforcing HTTP/1.1 by disabling HTTP/2 to perfectly mimic Node.js https defaults.
// The underlying Transport is a singleton to avoid leaking connection pools.
func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
antigravityTransportOnce.Do(initAntigravityTransport)
client := newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
// If no transport is set, use the shared HTTP/1.1 transport.
if client.Transport == nil {
client.Transport = antigravityTransport
return client
}
// Preserve proxy settings from proxy-aware transports while forcing HTTP/1.1.
if transport, ok := client.Transport.(*http.Transport); ok {
client.Transport = cloneTransportWithHTTP11(transport)
}
return client
}
// Identifier returns the executor identifier.
func (e *AntigravityExecutor) Identifier() string { return antigravityAuthType }
@@ -92,6 +219,8 @@ func (e *AntigravityExecutor) PrepareRequest(req *http.Request, auth *cliproxyau
}
// HttpRequest injects Antigravity credentials into the request and executes it.
// It uses a whitelist approach: all incoming headers are stripped and only
// the minimum set required by the Antigravity protocol is explicitly set.
func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("antigravity executor: request is nil")
@@ -100,10 +229,29 @@ func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyaut
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
// --- Whitelist: save only the headers we need from the original request ---
contentType := httpReq.Header.Get("Content-Type")
// Wipe ALL incoming headers
for k := range httpReq.Header {
delete(httpReq.Header, k)
}
// --- Set only the headers Antigravity actually sends ---
if contentType != "" {
httpReq.Header.Set("Content-Type", contentType)
}
// Content-Length is managed automatically by Go's http.Client from the Body
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
httpReq.Close = true // sends Connection: close
// Inject Authorization: Bearer <token>
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
@@ -115,7 +263,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
baseModel := thinking.ParseSuffix(req.Model).ModelName
isClaude := strings.Contains(strings.ToLower(baseModel), "claude")
if isClaude || strings.Contains(baseModel, "gemini-3-pro") {
if isClaude || strings.Contains(baseModel, "gemini-3-pro") || strings.Contains(baseModel, "gemini-3.1-flash-image") {
return e.executeClaudeNonStream(ctx, auth, req, opts)
}
@@ -150,7 +298,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
attempts := antigravityRetryAttempts(auth, e.cfg)
@@ -292,7 +440,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
attempts := antigravityRetryAttempts(auth, e.cfg)
@@ -684,7 +832,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
attempts := antigravityRetryAttempts(auth, e.cfg)
@@ -886,7 +1034,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
payload = deleteJSONField(payload, "request.safetySettings")
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
var authID, authLabel, authType, authValue string
if auth != nil {
@@ -917,10 +1065,10 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
if errReq != nil {
return cliproxyexecutor.Response{}, errReq
}
httpReq.Close = true
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+token)
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
httpReq.Header.Set("Accept", "application/json")
if host := resolveHost(base); host != "" {
httpReq.Host = host
}
@@ -1006,28 +1154,34 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
exec := &AntigravityExecutor{cfg: cfg}
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
if errToken != nil {
log.Warnf("antigravity executor: fetch models failed for %s: token error: %v", auth.ID, errToken)
return nil
}
if token == "" {
log.Warnf("antigravity executor: fetch models failed for %s: got empty token", auth.ID)
return nil
}
if errToken != nil || token == "" {
return fallbackAntigravityPrimaryModels()
}
if updatedAuth != nil {
auth = updatedAuth
}
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
httpClient := newAntigravityHTTPClient(ctx, cfg, auth, 0)
for idx, baseURL := range baseURLs {
modelsURL := baseURL + antigravityModelsPath
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
if errReq != nil {
log.Warnf("antigravity executor: fetch models failed for %s: create request error: %v", auth.ID, errReq)
return nil
var payload []byte
if auth != nil && auth.Metadata != nil {
if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" {
payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid)))
}
}
if len(payload) == 0 {
payload = []byte(`{}`)
}
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload))
if errReq != nil {
return fallbackAntigravityPrimaryModels()
}
httpReq.Close = true
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+token)
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
@@ -1038,15 +1192,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
log.Warnf("antigravity executor: fetch models failed for %s: context canceled: %v", auth.ID, errDo)
return nil
return fallbackAntigravityPrimaryModels()
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Warnf("antigravity executor: fetch models failed for %s: request error: %v", auth.ID, errDo)
return nil
return fallbackAntigravityPrimaryModels()
}
bodyBytes, errRead := io.ReadAll(httpResp.Body)
@@ -1058,22 +1210,27 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Warnf("antigravity executor: fetch models failed for %s: read body error: %v", auth.ID, errRead)
return nil
return fallbackAntigravityPrimaryModels()
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Warnf("antigravity executor: fetch models failed for %s: unexpected status %d, body: %s", auth.ID, httpResp.StatusCode, string(bodyBytes))
return nil
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
log.Warnf("antigravity executor: fetch models failed for %s: no models field in response, body: %s", auth.ID, string(bodyBytes))
return nil
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
now := time.Now().Unix()
@@ -1085,7 +1242,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
continue
}
switch modelID {
case "chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-3-pro-low", "gemini-2.5-pro":
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
continue
}
modelCfg := modelConfig[modelID]
@@ -1107,6 +1264,29 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
OwnedBy: antigravityAuthType,
Type: antigravityAuthType,
}
// Build input modalities from upstream capability flags.
inputModalities := []string{"TEXT"}
if modelData.Get("supportsImages").Bool() {
inputModalities = append(inputModalities, "IMAGE")
}
if modelData.Get("supportsVideo").Bool() {
inputModalities = append(inputModalities, "VIDEO")
}
modelInfo.SupportedInputModalities = inputModalities
modelInfo.SupportedOutputModalities = []string{"TEXT"}
// Token limits from upstream.
if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
modelInfo.InputTokenLimit = int(maxTok)
}
if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
modelInfo.OutputTokenLimit = int(maxOut)
}
// Supported generation methods (Gemini v1beta convention).
modelInfo.SupportedGenerationMethods = []string{"generateContent", "countTokens"}
// Look up Thinking support from static config using upstream model name.
if modelCfg != nil {
if modelCfg.Thinking != nil {
@@ -1118,9 +1298,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
}
models = append(models, modelInfo)
}
if len(models) == 0 {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
return fallbackAntigravityPrimaryModels()
}
storeAntigravityPrimaryModels(models)
return models
}
return nil
return fallbackAntigravityPrimaryModels()
}
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
@@ -1165,10 +1354,11 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
return auth, errReq
}
httpReq.Header.Set("Host", "oauth2.googleapis.com")
httpReq.Header.Set("User-Agent", defaultAntigravityAgent)
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Real Antigravity uses Go's default User-Agent for OAuth token refresh
httpReq.Header.Set("User-Agent", "Go-http-client/2.0")
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
return auth, errDo
@@ -1239,7 +1429,7 @@ func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, au
return nil
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient)
if errFetch != nil {
return errFetch
@@ -1293,7 +1483,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
payload = geminiToAntigravity(modelName, payload, projectID)
payload, _ = sjson.SetBytes(payload, "model", modelName)
useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high")
useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro") || strings.Contains(modelName, "gemini-3.1-pro")
payloadStr := string(payload)
paths := make([]string, 0)
util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths)
@@ -1307,18 +1497,18 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
payloadStr = util.CleanJSONSchemaForGemini(payloadStr)
}
if useAntigravitySchema {
systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
// if useAntigravitySchema {
// systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
for _, partResult := range systemInstructionPartsResult.Array() {
payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
}
}
}
// if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
// for _, partResult := range systemInstructionPartsResult.Array() {
// payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
// }
// }
// }
if strings.Contains(modelName, "claude") {
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
@@ -1330,14 +1520,10 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
if errReq != nil {
return nil, errReq
}
httpReq.Close = true
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+token)
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
if stream {
httpReq.Header.Set("Accept", "text/event-stream")
} else {
httpReq.Header.Set("Accept", "application/json")
}
if host := resolveHost(base); host != "" {
httpReq.Host = host
}
@@ -1549,7 +1735,16 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
template, _ := sjson.Set(string(payload), "model", modelName)
template, _ = sjson.Set(template, "userAgent", "antigravity")
template, _ = sjson.Set(template, "requestType", "agent")
isImageModel := strings.Contains(modelName, "image")
var reqType string
if isImageModel {
reqType = "image_gen"
} else {
reqType = "agent"
}
template, _ = sjson.Set(template, "requestType", reqType)
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
if projectID != "" {
@@ -1557,8 +1752,13 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
} else {
template, _ = sjson.Set(template, "project", generateProjectID())
}
template, _ = sjson.Set(template, "requestId", generateRequestID())
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
if isImageModel {
template, _ = sjson.Set(template, "requestId", generateImageGenRequestID())
} else {
template, _ = sjson.Set(template, "requestId", generateRequestID())
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
}
template, _ = sjson.Delete(template, "request.safetySettings")
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() {
@@ -1572,6 +1772,10 @@ func generateRequestID() string {
return "agent-" + uuid.NewString()
}
func generateImageGenRequestID() string {
return fmt.Sprintf("image_gen/%d/%s/12", time.Now().UnixMilli(), uuid.NewString())
}
func generateSessionID() string {
randSourceMutex.Lock()
n := randSource.Int63n(9_000_000_000_000_000_000)

View File

@@ -0,0 +1,90 @@
package executor
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
func resetAntigravityPrimaryModelsCacheForTest() {
antigravityPrimaryModelsCache.mu.Lock()
antigravityPrimaryModelsCache.models = nil
antigravityPrimaryModelsCache.mu.Unlock()
}
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
resetAntigravityPrimaryModelsCacheForTest()
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
seed := []*registry.ModelInfo{
{ID: "claude-sonnet-4-5"},
{ID: "gemini-2.5-pro"},
}
if updated := storeAntigravityPrimaryModels(seed); !updated {
t.Fatal("expected non-empty model list to update primary cache")
}
if updated := storeAntigravityPrimaryModels(nil); updated {
t.Fatal("expected nil model list not to overwrite primary cache")
}
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
t.Fatal("expected empty model list not to overwrite primary cache")
}
got := loadAntigravityPrimaryModels()
if len(got) != 2 {
t.Fatalf("expected cached model count 2, got %d", len(got))
}
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
}
}
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
resetAntigravityPrimaryModelsCacheForTest()
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
ID: "gpt-5",
DisplayName: "GPT-5",
SupportedGenerationMethods: []string{"generateContent"},
SupportedParameters: []string{"temperature"},
Thinking: &registry.ThinkingSupport{
Levels: []string{"high"},
},
}}); !updated {
t.Fatal("expected model cache update")
}
got := loadAntigravityPrimaryModels()
if len(got) != 1 {
t.Fatalf("expected one cached model, got %d", len(got))
}
got[0].ID = "mutated-id"
if len(got[0].SupportedGenerationMethods) > 0 {
got[0].SupportedGenerationMethods[0] = "mutated-method"
}
if len(got[0].SupportedParameters) > 0 {
got[0].SupportedParameters[0] = "mutated-parameter"
}
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
got[0].Thinking.Levels[0] = "mutated-level"
}
again := loadAntigravityPrimaryModels()
if len(again) != 1 {
t.Fatalf("expected one cached model after mutation, got %d", len(again))
}
if again[0].ID != "gpt-5" {
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
}
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
}
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
}
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
}
}

View File

@@ -6,9 +6,14 @@ import (
"compress/flate"
"compress/gzip"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/textproto"
"runtime"
"strings"
"time"
@@ -36,7 +41,9 @@ type ClaudeExecutor struct {
cfg *config.Config
}
const claudeToolPrefix = "proxy_"
// claudeToolPrefix is empty to match real Claude Code behavior (no tool name prefix).
// Previously "proxy_" was used but this is a detectable fingerprint difference.
const claudeToolPrefix = ""
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
@@ -130,6 +137,15 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
body = ensureCacheControl(body)
}
// Enforce Anthropic's cache_control block limit (max 4 breakpoints per request).
// Cloaking and ensureCacheControl may push the total over 4 when the client
// (e.g. Amp CLI) already sends multiple cache_control blocks.
body = enforceCacheControlLimit(body, 4)
// Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05.
// A 1h-TTL block must not appear after a 5m-TTL block in evaluation order (tools→system→messages).
body = normalizeCacheControlTTL(body)
// Extract betas from body and convert to header
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
@@ -171,11 +187,29 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API).
errBody := httpResp.Body
if ce := httpResp.Header.Get("Content-Encoding"); ce != "" {
var decErr error
errBody, decErr = decodeResponseBody(httpResp.Body, ce)
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr)
logWithRequestID(ctx).Warn(msg)
return resp, statusErr{code: httpResp.StatusCode, msg: msg}
}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
logWithRequestID(ctx).Warn(msg)
b = []byte(msg)
}
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
if errClose := httpResp.Body.Close(); errClose != nil {
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
return resp, err
@@ -271,6 +305,12 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
body = ensureCacheControl(body)
}
// Enforce Anthropic's cache_control block limit (max 4 breakpoints per request).
body = enforceCacheControlLimit(body, 4)
// Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05.
body = normalizeCacheControlTTL(body)
// Extract betas from body and convert to header
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
@@ -312,10 +352,28 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API).
errBody := httpResp.Body
if ce := httpResp.Header.Get("Content-Encoding"); ce != "" {
var decErr error
errBody, decErr = decodeResponseBody(httpResp.Body, ce)
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr)
logWithRequestID(ctx).Warn(msg)
return nil, statusErr{code: httpResp.StatusCode, msg: msg}
}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
logWithRequestID(ctx).Warn(msg)
b = []byte(msg)
}
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
@@ -420,6 +478,10 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
body = checkSystemInstructions(body)
}
// Keep count_tokens requests compatible with Anthropic cache-control constraints too.
body = enforceCacheControlLimit(body, 4)
body = normalizeCacheControlTTL(body)
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
@@ -459,9 +521,27 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
}
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
b, _ := io.ReadAll(resp.Body)
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API).
errBody := resp.Body
if ce := resp.Header.Get("Content-Encoding"); ce != "" {
var decErr error
errBody, decErr = decodeResponseBody(resp.Body, ce)
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr)
logWithRequestID(ctx).Warn(msg)
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg}
}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
logWithRequestID(ctx).Warn(msg)
b = []byte(msg)
}
appendAPIResponseChunk(ctx, e.cfg, b)
if errClose := resp.Body.Close(); errClose != nil {
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(b)}
@@ -696,23 +776,29 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
ginHeaders = ginCtx.Request.Header
}
promptCachingBeta := "prompt-caching-2024-07-31"
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14," + promptCachingBeta
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
baseBetas = val
if !strings.Contains(val, "oauth") {
baseBetas += ",oauth-2025-04-20"
}
}
if !strings.Contains(baseBetas, promptCachingBeta) {
baseBetas += "," + promptCachingBeta
hasClaude1MHeader := false
if ginHeaders != nil {
if _, ok := ginHeaders[textproto.CanonicalMIMEHeaderKey("X-CPA-CLAUDE-1M")]; ok {
hasClaude1MHeader = true
}
}
// Merge extra betas from request body
if len(extraBetas) > 0 {
// Merge extra betas from request body and request flags.
if len(extraBetas) > 0 || hasClaude1MHeader {
existingSet := make(map[string]bool)
for _, b := range strings.Split(baseBetas, ",") {
existingSet[strings.TrimSpace(b)] = true
betaName := strings.TrimSpace(b)
if betaName != "" {
existingSet[betaName] = true
}
}
for _, beta := range extraBetas {
beta = strings.TrimSpace(beta)
@@ -721,14 +807,16 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
existingSet[beta] = true
}
}
if hasClaude1MHeader && !existingSet["context-1m-2025-08-07"] {
baseBetas += ",context-1m-2025-08-07"
}
}
r.Header.Set("Anthropic-Beta", baseBetas)
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
// Values below match Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (captured 2026-02-17).
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
@@ -737,7 +825,18 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.44 (external, sdk-cli)"))
// For User-Agent, only forward the client's header if it's already a Claude Code client.
// Non-Claude-Code clients (e.g. curl, OpenAI SDKs) get the default Claude Code User-Agent
// to avoid leaking the real client identity during cloaking.
clientUA := ""
if ginHeaders != nil {
clientUA = ginHeaders.Get("User-Agent")
}
if isClaudeCodeClient(clientUA) {
r.Header.Set("User-Agent", clientUA)
} else {
r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)"))
}
r.Header.Set("Connection", "keep-alive")
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
if stream {
@@ -771,22 +870,7 @@ func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
}
func checkSystemInstructions(payload []byte) []byte {
system := gjson.GetBytes(payload, "system")
claudeCodeInstructions := `[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."}]`
if system.IsArray() {
if gjson.GetBytes(payload, "system.0.text").String() != "You are Claude Code, Anthropic's official CLI for Claude." {
system.ForEach(func(_, part gjson.Result) bool {
if part.Get("type").String() == "text" {
claudeCodeInstructions, _ = sjson.SetRaw(claudeCodeInstructions, "-1", part.Raw)
}
return true
})
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
}
} else {
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
}
return payload
return checkSystemInstructionsWithMode(payload, false)
}
func isClaudeOAuthToken(apiKey string) bool {
@@ -1060,33 +1144,73 @@ func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
return payload
}
// checkSystemInstructionsWithMode injects Claude Code system prompt.
// In strict mode, it replaces all user system messages.
// In non-strict mode (default), it prepends to existing system messages.
// generateBillingHeader creates the x-anthropic-billing-header text block that
// real Claude Code prepends to every system prompt array.
// Format: x-anthropic-billing-header: cc_version=<ver>.<build>; cc_entrypoint=cli; cch=<hash>;
func generateBillingHeader(payload []byte) string {
// Generate a deterministic cch hash from the payload content (system + messages + tools).
// Real Claude Code uses a 5-char hex hash that varies per request.
h := sha256.Sum256(payload)
cch := hex.EncodeToString(h[:])[:5]
// Build hash: 3-char hex, matches the pattern seen in real requests (e.g. "a43")
buildBytes := make([]byte, 2)
_, _ = rand.Read(buildBytes)
buildHash := hex.EncodeToString(buildBytes)[:3]
return fmt.Sprintf("x-anthropic-billing-header: cc_version=2.1.63.%s; cc_entrypoint=cli; cch=%s;", buildHash, cch)
}
// checkSystemInstructionsWithMode injects Claude Code-style system blocks:
//
// system[0]: billing header (no cache_control)
// system[1]: agent identifier (no cache_control)
// system[2..]: user system messages (cache_control added when missing)
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
system := gjson.GetBytes(payload, "system")
claudeCodeInstructions := `[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."}]`
billingText := generateBillingHeader(payload)
billingBlock := fmt.Sprintf(`{"type":"text","text":"%s"}`, billingText)
// No cache_control on the agent block. It is a cloaking artifact with zero cache
// value (the last system block is what actually triggers caching of all system content).
// Including any cache_control here creates an intra-system TTL ordering violation
// when the client's system blocks use ttl='1h' (prompt-caching-scope-2026-01-05 beta
// forbids 1h blocks after 5m blocks, and a no-TTL block defaults to 5m).
agentBlock := `{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK."}`
if strictMode {
// Strict mode: replace all system messages with Claude Code prompt only
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
// Strict mode: billing header + agent identifier only
result := "[" + billingBlock + "," + agentBlock + "]"
payload, _ = sjson.SetRawBytes(payload, "system", []byte(result))
return payload
}
// Non-strict mode (default): prepend Claude Code prompt to existing system messages
if system.IsArray() {
if gjson.GetBytes(payload, "system.0.text").String() != "You are Claude Code, Anthropic's official CLI for Claude." {
system.ForEach(func(_, part gjson.Result) bool {
if part.Get("type").String() == "text" {
claudeCodeInstructions, _ = sjson.SetRaw(claudeCodeInstructions, "-1", part.Raw)
}
return true
})
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
}
} else {
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
// Non-strict mode: billing header + agent identifier + user system messages
// Skip if already injected
firstText := gjson.GetBytes(payload, "system.0.text").String()
if strings.HasPrefix(firstText, "x-anthropic-billing-header:") {
return payload
}
result := "[" + billingBlock + "," + agentBlock
if system.IsArray() {
system.ForEach(func(_, part gjson.Result) bool {
if part.Get("type").String() == "text" {
// Add cache_control to user system messages if not present.
// Do NOT add ttl — let it inherit the default (5m) to avoid
// TTL ordering violations with the prompt-caching-scope-2026-01-05 beta.
partJSON := part.Raw
if !part.Get("cache_control").Exists() {
partJSON, _ = sjson.Set(partJSON, "cache_control.type", "ephemeral")
}
result += "," + partJSON
}
return true
})
}
result += "]"
payload, _ = sjson.SetRawBytes(payload, "system", []byte(result))
return payload
}
@@ -1224,6 +1348,313 @@ func countCacheControls(payload []byte) int {
return count
}
func parsePayloadObject(payload []byte) (map[string]any, bool) {
if len(payload) == 0 {
return nil, false
}
var root map[string]any
if err := json.Unmarshal(payload, &root); err != nil {
return nil, false
}
return root, true
}
func marshalPayloadObject(original []byte, root map[string]any) []byte {
if root == nil {
return original
}
out, err := json.Marshal(root)
if err != nil {
return original
}
return out
}
func asObject(v any) (map[string]any, bool) {
obj, ok := v.(map[string]any)
return obj, ok
}
func asArray(v any) ([]any, bool) {
arr, ok := v.([]any)
return arr, ok
}
func countCacheControlsMap(root map[string]any) int {
count := 0
if system, ok := asArray(root["system"]); ok {
for _, item := range system {
if obj, ok := asObject(item); ok {
if _, exists := obj["cache_control"]; exists {
count++
}
}
}
}
if tools, ok := asArray(root["tools"]); ok {
for _, item := range tools {
if obj, ok := asObject(item); ok {
if _, exists := obj["cache_control"]; exists {
count++
}
}
}
}
if messages, ok := asArray(root["messages"]); ok {
for _, msg := range messages {
msgObj, ok := asObject(msg)
if !ok {
continue
}
content, ok := asArray(msgObj["content"])
if !ok {
continue
}
for _, item := range content {
if obj, ok := asObject(item); ok {
if _, exists := obj["cache_control"]; exists {
count++
}
}
}
}
}
return count
}
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) {
ccRaw, exists := obj["cache_control"]
if !exists {
return
}
cc, ok := asObject(ccRaw)
if !ok {
*seen5m = true
return
}
ttlRaw, ttlExists := cc["ttl"]
ttl, ttlIsString := ttlRaw.(string)
if !ttlExists || !ttlIsString || ttl != "1h" {
*seen5m = true
return
}
if *seen5m {
delete(cc, "ttl")
}
}
func findLastCacheControlIndex(arr []any) int {
last := -1
for idx, item := range arr {
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists {
last = idx
}
}
return last
}
func stripCacheControlExceptIndex(arr []any, preserveIdx int, excess *int) {
for idx, item := range arr {
if *excess <= 0 {
return
}
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists && idx != preserveIdx {
delete(obj, "cache_control")
*excess--
}
}
}
func stripAllCacheControl(arr []any, excess *int) {
for _, item := range arr {
if *excess <= 0 {
return
}
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists {
delete(obj, "cache_control")
*excess--
}
}
}
func stripMessageCacheControl(messages []any, excess *int) {
for _, msg := range messages {
if *excess <= 0 {
return
}
msgObj, ok := asObject(msg)
if !ok {
continue
}
content, ok := asArray(msgObj["content"])
if !ok {
continue
}
for _, item := range content {
if *excess <= 0 {
return
}
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists {
delete(obj, "cache_control")
*excess--
}
}
}
}
// normalizeCacheControlTTL ensures cache_control TTL values don't violate the
// prompt-caching-scope-2026-01-05 ordering constraint: a 1h-TTL block must not
// appear after a 5m-TTL block anywhere in the evaluation order.
//
// Anthropic evaluates blocks in order: tools → system (index 0..N) → messages.
// Within each section, blocks are evaluated in array order. A 5m (default) block
// followed by a 1h block at ANY later position is an error — including within
// the same section (e.g. system[1]=5m then system[3]=1h).
//
// Strategy: walk all cache_control blocks in evaluation order. Once a 5m block
// is seen, strip ttl from ALL subsequent 1h blocks (downgrading them to 5m).
func normalizeCacheControlTTL(payload []byte) []byte {
root, ok := parsePayloadObject(payload)
if !ok {
return payload
}
seen5m := false
if tools, ok := asArray(root["tools"]); ok {
for _, tool := range tools {
if obj, ok := asObject(tool); ok {
normalizeTTLForBlock(obj, &seen5m)
}
}
}
if system, ok := asArray(root["system"]); ok {
for _, item := range system {
if obj, ok := asObject(item); ok {
normalizeTTLForBlock(obj, &seen5m)
}
}
}
if messages, ok := asArray(root["messages"]); ok {
for _, msg := range messages {
msgObj, ok := asObject(msg)
if !ok {
continue
}
content, ok := asArray(msgObj["content"])
if !ok {
continue
}
for _, item := range content {
if obj, ok := asObject(item); ok {
normalizeTTLForBlock(obj, &seen5m)
}
}
}
}
return marshalPayloadObject(payload, root)
}
// enforceCacheControlLimit removes excess cache_control blocks from a payload
// so the total does not exceed the Anthropic API limit (currently 4).
//
// Anthropic evaluates cache breakpoints in order: tools → system → messages.
// The most valuable breakpoints are:
// 1. Last tool — caches ALL tool definitions
// 2. Last system block — caches ALL system content
// 3. Recent messages — cache conversation context
//
// Removal priority (strip lowest-value first):
//
// Phase 1: system blocks earliest-first, preserving the last one.
// Phase 2: tool blocks earliest-first, preserving the last one.
// Phase 3: message content blocks earliest-first.
// Phase 4: remaining system blocks (last system).
// Phase 5: remaining tool blocks (last tool).
func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte {
root, ok := parsePayloadObject(payload)
if !ok {
return payload
}
total := countCacheControlsMap(root)
if total <= maxBlocks {
return payload
}
excess := total - maxBlocks
var system []any
if arr, ok := asArray(root["system"]); ok {
system = arr
}
var tools []any
if arr, ok := asArray(root["tools"]); ok {
tools = arr
}
var messages []any
if arr, ok := asArray(root["messages"]); ok {
messages = arr
}
if len(system) > 0 {
stripCacheControlExceptIndex(system, findLastCacheControlIndex(system), &excess)
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
}
if len(tools) > 0 {
stripCacheControlExceptIndex(tools, findLastCacheControlIndex(tools), &excess)
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
}
if len(messages) > 0 {
stripMessageCacheControl(messages, &excess)
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
}
if len(system) > 0 {
stripAllCacheControl(system, &excess)
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
}
if len(tools) > 0 {
stripAllCacheControl(tools, &excess)
}
return marshalPayloadObject(payload, root)
}
// injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching.
// Per Anthropic docs: "Place cache_control on the second-to-last User message to let the model reuse the earlier cache."
// This enables caching of conversation history, which is especially beneficial for long multi-turn conversations.

View File

@@ -6,6 +6,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -348,3 +349,237 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
}
}
func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) {
payload := []byte(`{
"tools": [{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],
"system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}],
"messages": [{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
}`)
out := normalizeCacheControlTTL(payload)
if got := gjson.GetBytes(out, "tools.0.cache_control.ttl").String(); got != "1h" {
t.Fatalf("tools.0.cache_control.ttl = %q, want %q", got, "1h")
}
if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() {
t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block")
}
}
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
payload := []byte(`{
"tools": [
{"name":"t1","cache_control":{"type":"ephemeral"}},
{"name":"t2","cache_control":{"type":"ephemeral"}}
],
"system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}],
"messages": [
{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}}]},
{"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]}
]
}`)
out := enforceCacheControlLimit(payload, 4)
if got := countCacheControls(out); got != 4 {
t.Fatalf("cache_control count = %d, want 4", got)
}
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
t.Fatalf("tools.0.cache_control should be removed first (non-last tool)")
}
if !gjson.GetBytes(out, "tools.1.cache_control").Exists() {
t.Fatalf("tools.1.cache_control (last tool) should be preserved")
}
if !gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists() || !gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists() {
t.Fatalf("message cache_control blocks should be preserved when non-last tool removal is enough")
}
}
func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) {
payload := []byte(`{
"tools": [
{"name":"t1","cache_control":{"type":"ephemeral"}},
{"name":"t2","cache_control":{"type":"ephemeral"}},
{"name":"t3","cache_control":{"type":"ephemeral"}},
{"name":"t4","cache_control":{"type":"ephemeral"}},
{"name":"t5","cache_control":{"type":"ephemeral"}}
]
}`)
out := enforceCacheControlLimit(payload, 4)
if got := countCacheControls(out); got != 4 {
t.Fatalf("cache_control count = %d, want 4", got)
}
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
t.Fatalf("tools.0.cache_control should be removed to satisfy max=4")
}
if !gjson.GetBytes(out, "tools.4.cache_control").Exists() {
t.Fatalf("last tool cache_control should be preserved when possible")
}
}
func TestClaudeExecutor_CountTokens_AppliesCacheControlGuards(t *testing.T) {
var seenBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
seenBody = bytes.Clone(body)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"input_tokens":42}`))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{
"tools": [
{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}},
{"name":"t2","cache_control":{"type":"ephemeral"}}
],
"system": [
{"type":"text","text":"s1","cache_control":{"type":"ephemeral","ttl":"1h"}},
{"type":"text","text":"s2","cache_control":{"type":"ephemeral","ttl":"1h"}}
],
"messages": [
{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]},
{"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral","ttl":"1h"}}]}
]
}`)
_, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-haiku-20241022",
Payload: payload,
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
if err != nil {
t.Fatalf("CountTokens error: %v", err)
}
if len(seenBody) == 0 {
t.Fatal("expected count_tokens request body to be captured")
}
if got := countCacheControls(seenBody); got > 4 {
t.Fatalf("count_tokens body has %d cache_control blocks, want <= 4", got)
}
if hasTTLOrderingViolation(seenBody) {
t.Fatalf("count_tokens body still has ttl ordering violations: %s", string(seenBody))
}
}
func hasTTLOrderingViolation(payload []byte) bool {
seen5m := false
violates := false
checkCC := func(cc gjson.Result) {
if !cc.Exists() || violates {
return
}
ttl := cc.Get("ttl").String()
if ttl != "1h" {
seen5m = true
return
}
if seen5m {
violates = true
}
}
tools := gjson.GetBytes(payload, "tools")
if tools.IsArray() {
tools.ForEach(func(_, tool gjson.Result) bool {
checkCC(tool.Get("cache_control"))
return !violates
})
}
system := gjson.GetBytes(payload, "system")
if system.IsArray() {
system.ForEach(func(_, item gjson.Result) bool {
checkCC(item.Get("cache_control"))
return !violates
})
}
messages := gjson.GetBytes(payload, "messages")
if messages.IsArray() {
messages.ForEach(func(_, msg gjson.Result) bool {
content := msg.Get("content")
if content.IsArray() {
content.ForEach(func(_, item gjson.Result) bool {
checkCC(item.Get("cache_control"))
return !violates
})
}
return !violates
})
}
return violates
}
func TestClaudeExecutor_Execute_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
return err
})
}
func TestClaudeExecutor_ExecuteStream_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
_, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
return err
})
}
func TestClaudeExecutor_CountTokens_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
_, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
return err
})
}
func testClaudeExecutorInvalidCompressedErrorBody(
t *testing.T,
invoke func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error,
) {
t.Helper()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "gzip")
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("not-a-valid-gzip-stream"))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
err := invoke(executor, auth, payload)
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "failed to decode error response body") {
t.Fatalf("expected decode failure message, got: %v", err)
}
if statusProvider, ok := err.(interface{ StatusCode() int }); !ok || statusProvider.StatusCode() != http.StatusBadRequest {
t.Fatalf("expected status code 400, got: %v", err)
}
}

View File

@@ -9,17 +9,18 @@ import (
"github.com/google/uuid"
)
// userIDPattern matches Claude Code format: user_[64-hex]_account__session_[uuid-v4]
var userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
// userIDPattern matches Claude Code format: user_[64-hex]_account_[uuid]_session_[uuid]
var userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}_session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
// generateFakeUserID generates a fake user ID in Claude Code format.
// Format: user_[64-hex-chars]_account__session_[UUID-v4]
// Format: user_[64-hex-chars]_account_[UUID-v4]_session_[UUID-v4]
func generateFakeUserID() string {
hexBytes := make([]byte, 32)
_, _ = rand.Read(hexBytes)
hexPart := hex.EncodeToString(hexBytes)
uuidPart := uuid.New().String()
return "user_" + hexPart + "_account__session_" + uuidPart
accountUUID := uuid.New().String()
sessionUUID := uuid.New().String()
return "user_" + hexPart + "_account_" + accountUUID + "_session_" + sessionUUID
}
// isValidUserID checks if a user ID matches Claude Code format.

View File

@@ -156,7 +156,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
err = newCodexStatusErr(httpResp.StatusCode, b)
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
@@ -260,7 +260,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
err = newCodexStatusErr(httpResp.StatusCode, b)
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
@@ -358,7 +358,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
}
appendAPIResponseChunk(ctx, e.cfg, data)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
err = newCodexStatusErr(httpResp.StatusCode, data)
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
@@ -673,6 +673,35 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
util.ApplyCustomHeadersFromAttrs(r, attrs)
}
func newCodexStatusErr(statusCode int, body []byte) statusErr {
err := statusErr{code: statusCode, msg: string(body)}
if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
err.retryAfter = retryAfter
}
return err
}
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
return nil
}
if strings.TrimSpace(gjson.GetBytes(errorBody, "error.type").String()) != "usage_limit_reached" {
return nil
}
if resetsAt := gjson.GetBytes(errorBody, "error.resets_at").Int(); resetsAt > 0 {
resetAtTime := time.Unix(resetsAt, 0)
if resetAtTime.After(now) {
retryAfter := resetAtTime.Sub(now)
return &retryAfter
}
}
if resetsInSeconds := gjson.GetBytes(errorBody, "error.resets_in_seconds").Int(); resetsInSeconds > 0 {
retryAfter := time.Duration(resetsInSeconds) * time.Second
return &retryAfter
}
return nil
}
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
if a == nil {
return "", ""

View File

@@ -0,0 +1,65 @@
package executor
import (
"net/http"
"strconv"
"testing"
"time"
)
func TestParseCodexRetryAfter(t *testing.T) {
now := time.Unix(1_700_000_000, 0)
t.Run("resets_in_seconds", func(t *testing.T) {
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":123}}`)
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
if retryAfter == nil {
t.Fatalf("expected retryAfter, got nil")
}
if *retryAfter != 123*time.Second {
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 123*time.Second)
}
})
t.Run("prefers resets_at", func(t *testing.T) {
resetAt := now.Add(5 * time.Minute).Unix()
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":1}}`)
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
if retryAfter == nil {
t.Fatalf("expected retryAfter, got nil")
}
if *retryAfter != 5*time.Minute {
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 5*time.Minute)
}
})
t.Run("fallback when resets_at is past", func(t *testing.T) {
resetAt := now.Add(-1 * time.Minute).Unix()
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":77}}`)
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
if retryAfter == nil {
t.Fatalf("expected retryAfter, got nil")
}
if *retryAfter != 77*time.Second {
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 77*time.Second)
}
})
t.Run("non-429 status code", func(t *testing.T) {
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":30}}`)
if got := parseCodexRetryAfter(http.StatusBadRequest, body, now); got != nil {
t.Fatalf("expected nil for non-429, got %v", *got)
}
})
t.Run("non usage_limit_reached error type", func(t *testing.T) {
body := []byte(`{"error":{"type":"server_error","resets_in_seconds":30}}`)
if got := parseCodexRetryAfter(http.StatusTooManyRequests, body, now); got != nil {
t.Fatalf("expected nil for non-usage_limit_reached, got %v", *got)
}
})
}
func itoa(v int64) string {
return strconv.FormatInt(v, 10)
}

View File

@@ -16,7 +16,6 @@ import (
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
@@ -81,7 +80,7 @@ func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
}
req.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(req)
applyGeminiCLIHeaders(req, "unknown")
return nil
}
@@ -189,7 +188,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
}
reqHTTP.Header.Set("Content-Type", "application/json")
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(reqHTTP)
applyGeminiCLIHeaders(reqHTTP, attemptModel)
reqHTTP.Header.Set("Accept", "application/json")
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
@@ -334,7 +333,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
}
reqHTTP.Header.Set("Content-Type", "application/json")
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(reqHTTP)
applyGeminiCLIHeaders(reqHTTP, attemptModel)
reqHTTP.Header.Set("Accept", "text/event-stream")
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
@@ -515,7 +514,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
}
reqHTTP.Header.Set("Content-Type", "application/json")
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(reqHTTP)
applyGeminiCLIHeaders(reqHTTP, baseModel)
reqHTTP.Header.Set("Accept", "application/json")
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
@@ -738,21 +737,11 @@ func stringValue(m map[string]any, key string) string {
}
// applyGeminiCLIHeaders sets required headers for the Gemini CLI upstream.
func applyGeminiCLIHeaders(r *http.Request) {
var ginHeaders http.Header
if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
ginHeaders = ginCtx.Request.Header
}
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "google-api-nodejs-client/9.15.1")
misc.EnsureHeader(r.Header, ginHeaders, "X-Goog-Api-Client", "gl-node/22.17.0")
misc.EnsureHeader(r.Header, ginHeaders, "Client-Metadata", geminiCLIClientMetadata())
}
// geminiCLIClientMetadata returns a compact metadata string required by upstream.
func geminiCLIClientMetadata() string {
// Keep parity with CLI client defaults
return "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
// User-Agent is always forced to the GeminiCLI format regardless of the client's value,
// so that upstream identifies the request as a native GeminiCLI client.
func applyGeminiCLIHeaders(r *http.Request, model string) {
r.Header.Set("User-Agent", misc.GeminiCLIUserAgent(model))
r.Header.Set("X-Goog-Api-Client", misc.GeminiCLIApiClientHeader)
}
// cliPreviewFallbackOrder returns preview model candidates for a base model.

View File

@@ -49,15 +49,8 @@ const (
ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable
ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed
// kiroUserAgent matches Amazon Q CLI style for User-Agent header
kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0"
// kiroFullUserAgent is the complete x-amz-user-agent header (Amazon Q CLI style)
kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
// Kiro IDE style headers for IDC auth
kiroIDEUserAgent = "aws-sdk-js/1.0.27 ua/2.1 os/win32#10.0.19044 lang/js md/nodejs#22.21.1 api/codewhispererstreaming#1.0.27 m/E"
kiroIDEAmzUserAgent = "aws-sdk-js/1.0.27"
kiroIDEAgentModeVibe = "vibe"
// kiroIDEAgentMode is the agent mode header value for Kiro IDE requests
kiroIDEAgentMode = "vibe"
// Socket retry configuration constants
// Maximum number of retry attempts for socket/network errors
@@ -87,20 +80,13 @@ var (
usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first
)
// Global FingerprintManager for dynamic User-Agent generation per token
// Each token gets a unique fingerprint on first use, which is cached for subsequent requests
var (
globalFingerprintManager *kiroauth.FingerprintManager
globalFingerprintManagerOnce sync.Once
)
// getGlobalFingerprintManager returns the global FingerprintManager instance
func getGlobalFingerprintManager() *kiroauth.FingerprintManager {
globalFingerprintManagerOnce.Do(func() {
globalFingerprintManager = kiroauth.NewFingerprintManager()
log.Infof("kiro: initialized global FingerprintManager for dynamic UA generation")
})
return globalFingerprintManager
// endpointAliases maps user preference values to canonical endpoint names.
var endpointAliases = map[string]string{
"codewhisperer": "codewhisperer",
"ide": "codewhisperer",
"amazonq": "amazonq",
"q": "amazonq",
"cli": "amazonq",
}
// retryConfig holds configuration for socket retry logic.
@@ -433,87 +419,41 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
return kiroEndpointConfigs
}
// Determine API region using shared resolution logic
region := resolveKiroAPIRegion(auth)
log.Debugf("kiro: using region %s", region)
// Build endpoint configs for the specified region
endpointConfigs := buildKiroEndpointConfigs(region)
// For IDC auth, use Q endpoint with AI_EDITOR origin
// IDC tokens work with Q endpoint using Bearer auth
// The difference is only in how tokens are refreshed (OIDC with clientId/clientSecret for IDC)
// NOT in how API calls are made - both Social and IDC use the same endpoint/origin
if auth.Metadata != nil {
authMethod, _ := auth.Metadata["auth_method"].(string)
if strings.ToLower(authMethod) == "idc" {
log.Debugf("kiro: IDC auth, using Q endpoint (region: %s)", region)
return endpointConfigs
}
}
// Check for preference
var preference string
if auth.Metadata != nil {
if p, ok := auth.Metadata["preferred_endpoint"].(string); ok {
preference = p
}
}
// Check attributes as fallback (e.g. from HTTP headers)
if preference == "" && auth.Attributes != nil {
preference = auth.Attributes["preferred_endpoint"]
}
configs := buildKiroEndpointConfigs(region)
preference := getAuthValue(auth, "preferred_endpoint")
if preference == "" {
return endpointConfigs
return configs
}
preference = strings.ToLower(strings.TrimSpace(preference))
targetName, ok := endpointAliases[preference]
if !ok {
return configs
}
// Create new slice to avoid modifying global state
var sorted []kiroEndpointConfig
var remaining []kiroEndpointConfig
for _, cfg := range endpointConfigs {
name := strings.ToLower(cfg.Name)
// Check for matches
// CodeWhisperer aliases: codewhisperer, ide
// AmazonQ aliases: amazonq, q, cli
isMatch := false
if (preference == "codewhisperer" || preference == "ide") && name == "codewhisperer" {
isMatch = true
} else if (preference == "amazonq" || preference == "q" || preference == "cli") && name == "amazonq" {
isMatch = true
}
if isMatch {
sorted = append(sorted, cfg)
var preferred, others []kiroEndpointConfig
for _, cfg := range configs {
if strings.ToLower(cfg.Name) == targetName {
preferred = append(preferred, cfg)
} else {
remaining = append(remaining, cfg)
others = append(others, cfg)
}
}
// If preference didn't match anything, return default
if len(sorted) == 0 {
return endpointConfigs
if len(preferred) == 0 {
return configs
}
// Combine: preferred first, then others
return append(sorted, remaining...)
return append(preferred, others...)
}
// KiroExecutor handles requests to AWS CodeWhisperer (Kiro) API.
type KiroExecutor struct {
cfg *config.Config
refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions
}
// isIDCAuth checks if the auth uses IDC (Identity Center) authentication method.
func isIDCAuth(auth *cliproxyauth.Auth) bool {
if auth == nil || auth.Metadata == nil {
return false
}
authMethod, _ := auth.Metadata["auth_method"].(string)
return strings.ToLower(authMethod) == "idc"
cfg *config.Config
refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions
profileArnMu sync.Mutex // Serializes profileArn fetches to prevent concurrent map writes
}
// buildKiroPayloadForFormat builds the Kiro API payload based on the source format.
@@ -546,27 +486,22 @@ func NewKiroExecutor(cfg *config.Config) *KiroExecutor {
// Identifier returns the unique identifier for this executor.
func (e *KiroExecutor) Identifier() string { return "kiro" }
// applyDynamicFingerprint applies token-specific fingerprint headers to the request
// For IDC auth, uses dynamic fingerprint-based User-Agent
// For other auth types, uses static Amazon Q CLI style headers
// applyDynamicFingerprint applies account-specific fingerprint headers to the request.
func applyDynamicFingerprint(req *http.Request, auth *cliproxyauth.Auth) {
if isIDCAuth(auth) {
// Get token-specific fingerprint for dynamic UA generation
tokenKey := getTokenKey(auth)
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
accountKey := getAccountKey(auth)
fp := kiroauth.GlobalFingerprintManager().GetFingerprint(accountKey)
// Use fingerprint-generated dynamic User-Agent
req.Header.Set("User-Agent", fp.BuildUserAgent())
req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent())
req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe)
req.Header.Set("User-Agent", fp.BuildUserAgent())
req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent())
req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentMode)
req.Header.Set("x-amzn-codewhisperer-optout", "true")
log.Debugf("kiro: using dynamic fingerprint for token %s (SDK:%s, OS:%s/%s, Kiro:%s)",
tokenKey[:8]+"...", fp.SDKVersion, fp.OSType, fp.OSVersion, fp.KiroVersion)
} else {
// Use static Amazon Q CLI style headers for non-IDC auth
req.Header.Set("User-Agent", kiroUserAgent)
req.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
keyPrefix := accountKey
if len(keyPrefix) > 8 {
keyPrefix = keyPrefix[:8]
}
log.Debugf("kiro: using dynamic fingerprint for account %s (SDK:%s, OS:%s/%s, Kiro:%s)",
keyPrefix+"...", fp.StreamingSDKVersion, fp.OSType, fp.OSVersion, fp.KiroVersion)
}
// PrepareRequest prepares the HTTP request before execution.
@@ -609,17 +544,51 @@ func (e *KiroExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
return httpClient.Do(httpReq)
}
// getTokenKey returns a unique key for rate limiting based on auth credentials.
// Uses auth ID if available, otherwise falls back to a hash of the access token.
func getTokenKey(auth *cliproxyauth.Auth) string {
// getAccountKey returns a stable account key for fingerprint lookup and rate limiting.
// Fallback order:
// 1) client_id / refresh_token (best account identity)
// 2) auth.ID (stable local auth record)
// 3) profile_arn (stable AWS profile identity)
// 4) access_token (least preferred but deterministic)
// 5) fixed anonymous seed
func getAccountKey(auth *cliproxyauth.Auth) string {
var clientID, refreshToken, profileArn string
if auth != nil && auth.Metadata != nil {
clientID, _ = auth.Metadata["client_id"].(string)
refreshToken, _ = auth.Metadata["refresh_token"].(string)
profileArn, _ = auth.Metadata["profile_arn"].(string)
}
if clientID != "" || refreshToken != "" {
return kiroauth.GetAccountKey(clientID, refreshToken)
}
if auth != nil && auth.ID != "" {
return auth.ID
return kiroauth.GenerateAccountKey(auth.ID)
}
accessToken, _ := kiroCredentials(auth)
if len(accessToken) > 16 {
return accessToken[:16]
if profileArn != "" {
return kiroauth.GenerateAccountKey(profileArn)
}
return accessToken
if accessToken, _ := kiroCredentials(auth); accessToken != "" {
return kiroauth.GenerateAccountKey(accessToken)
}
return kiroauth.GenerateAccountKey("kiro-anonymous")
}
// getAuthValue looks up a value by key in auth Metadata, then Attributes.
func getAuthValue(auth *cliproxyauth.Auth, key string) string {
if auth == nil {
return ""
}
if auth.Metadata != nil {
if v, ok := auth.Metadata[key].(string); ok && v != "" {
return strings.ToLower(strings.TrimSpace(v))
}
}
if auth.Attributes != nil {
if v := auth.Attributes[key]; v != "" {
return strings.ToLower(strings.TrimSpace(v))
}
}
return ""
}
// Execute sends the request to Kiro API and returns the response.
@@ -631,7 +600,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
}
// Rate limiting: get token key for tracking
tokenKey := getTokenKey(auth)
tokenKey := getAccountKey(auth)
rateLimiter := kiroauth.GetGlobalRateLimiter()
cooldownMgr := kiroauth.GetGlobalCooldownManager()
@@ -693,6 +662,13 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
kiroModelID := e.mapModelToKiro(req.Model)
// Fetch profileArn if missing (for imported accounts from Kiro IDE)
if profileArn == "" {
if fetched := e.fetchAndSaveProfileArn(ctx, auth, accessToken); fetched != "" {
profileArn = fetched
}
}
// Determine agentic mode and effective profile ARN using helper functions
isAgentic, isChatOnly := determineAgenticMode(req.Model)
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
@@ -749,7 +725,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
}
// Kiro-specific headers
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe)
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentMode)
httpReq.Header.Set("x-amzn-codewhisperer-optout", "true")
// Apply dynamic fingerprint-based headers
@@ -1060,7 +1036,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
}
// Rate limiting: get token key for tracking
tokenKey := getTokenKey(auth)
tokenKey := getAccountKey(auth)
rateLimiter := kiroauth.GetGlobalRateLimiter()
cooldownMgr := kiroauth.GetGlobalCooldownManager()
@@ -1126,6 +1102,13 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
kiroModelID := e.mapModelToKiro(req.Model)
// Fetch profileArn if missing (for imported accounts from Kiro IDE)
if profileArn == "" {
if fetched := e.fetchAndSaveProfileArn(ctx, auth, accessToken); fetched != "" {
profileArn = fetched
}
}
// Determine agentic mode and effective profile ARN using helper functions
isAgentic, isChatOnly := determineAgenticMode(req.Model)
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
@@ -1185,7 +1168,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
}
// Kiro-specific headers
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe)
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentMode)
httpReq.Header.Set("x-amzn-codewhisperer-optout", "true")
// Apply dynamic fingerprint-based headers
@@ -1647,62 +1630,23 @@ func determineAgenticMode(model string) (isAgentic, isChatOnly bool) {
return isAgentic, isChatOnly
}
// getEffectiveProfileArn determines if profileArn should be included based on auth method.
// profileArn is only needed for social auth (Google OAuth), not for AWS SSO OIDC (Builder ID/IDC).
//
// Detection logic (matching kiro-openai-gateway):
// 1. Check auth_method field: "builder-id" or "idc"
// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens)
// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature)
func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string {
if auth != nil && auth.Metadata != nil {
// Check 1: auth_method field (from CLIProxyAPI tokens)
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
return "" // AWS SSO OIDC - don't include profileArn
}
// Check 2: auth_type field (from kiro-cli tokens)
if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" {
return "" // AWS SSO OIDC - don't include profileArn
}
// Check 3: client_id + client_secret presence (AWS SSO OIDC signature)
_, hasClientID := auth.Metadata["client_id"].(string)
_, hasClientSecret := auth.Metadata["client_secret"].(string)
if hasClientID && hasClientSecret {
return "" // AWS SSO OIDC - don't include profileArn
}
}
return profileArn
}
// getEffectiveProfileArnWithWarning determines if profileArn should be included based on auth method,
// and logs a warning if profileArn is missing for non-builder-id auth.
// This consolidates the auth_method check that was previously done separately.
//
// AWS SSO OIDC (Builder ID/IDC) users don't need profileArn - sending it causes 403 errors.
// Only Kiro Desktop (social auth like Google/GitHub) users need profileArn.
//
// Detection logic (matching kiro-openai-gateway):
// 1. Check auth_method field: "builder-id" or "idc"
// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens)
// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature)
// getEffectiveProfileArnWithWarning suppresses profileArn for builder-id and AWS SSO OIDC auth.
// Builder-id users (auth_method == "builder-id") and AWS SSO OIDC users (auth_type == "aws_sso_oidc")
// don't need profileArn — sending it causes 403 errors.
// For all other auth methods (e.g. social auth), profileArn is returned as-is,
// with a warning logged if it is empty.
func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string {
if auth != nil && auth.Metadata != nil {
// Check 1: auth_method field (from CLIProxyAPI tokens)
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
return "" // AWS SSO OIDC - don't include profileArn
// Check 1: auth_method field, skip for builder-id only
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" {
return ""
}
// Check 2: auth_type field (from kiro-cli tokens)
if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" {
return "" // AWS SSO OIDC - don't include profileArn
}
// Check 3: client_id + client_secret presence (AWS SSO OIDC signature, like kiro-openai-gateway)
_, hasClientID := auth.Metadata["client_id"].(string)
_, hasClientSecret := auth.Metadata["client_secret"].(string)
if hasClientID && hasClientSecret {
return "" // AWS SSO OIDC - don't include profileArn
}
}
// For social auth (Kiro Desktop), profileArn is required
// For social auth and IDC, profileArn is required
if profileArn == "" {
log.Warnf("kiro: profile ARN not found in auth, API calls may fail")
}
@@ -3999,6 +3943,51 @@ func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error {
return nil
}
// fetchAndSaveProfileArn fetches profileArn from API if missing, updates auth and persists to file.
func (e *KiroExecutor) fetchAndSaveProfileArn(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) string {
if auth == nil || auth.Metadata == nil {
return ""
}
// Skip for Builder ID - they don't have profiles
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" {
log.Debugf("kiro executor: skipping profileArn fetch for builder-id auth")
return ""
}
e.profileArnMu.Lock()
defer e.profileArnMu.Unlock()
// Double-check: another goroutine may have already fetched and saved the profileArn
if arn, ok := auth.Metadata["profile_arn"].(string); ok && arn != "" {
return arn
}
clientID, _ := auth.Metadata["client_id"].(string)
refreshToken, _ := auth.Metadata["refresh_token"].(string)
ssoClient := kiroauth.NewSSOOIDCClient(e.cfg)
profileArn := ssoClient.FetchProfileArn(ctx, accessToken, clientID, refreshToken)
if profileArn == "" {
log.Debugf("kiro executor: FetchProfileArn returned no profiles")
return ""
}
auth.Metadata["profile_arn"] = profileArn
if auth.Attributes == nil {
auth.Attributes = make(map[string]string)
}
auth.Attributes["profile_arn"] = profileArn
if err := e.persistRefreshedAuth(auth); err != nil {
log.Warnf("kiro executor: failed to persist profileArn: %v", err)
} else {
log.Infof("kiro executor: fetched and saved profileArn: %s", profileArn)
}
return profileArn
}
// reloadAuthFromFile 从文件重新加载 auth 数据(方案 B: Fallback 机制)
// 当内存中的 token 已过期时,尝试从文件读取最新的 token
// 这解决了后台刷新器已更新文件但内存中 Auth 对象尚未同步的时间差问题
@@ -4728,7 +4717,7 @@ func (e *KiroExecutor) callKiroAndBuffer(
isAgentic, isChatOnly := determineAgenticMode(req.Model)
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
tokenKey := getTokenKey(auth)
tokenKey := getAccountKey(auth)
kiroStream, err := e.executeStreamWithRetry(
ctx, auth, req, opts, accessToken, effectiveProfileArn,
@@ -4770,7 +4759,7 @@ func (e *KiroExecutor) callKiroDirectStream(
isAgentic, isChatOnly := determineAgenticMode(req.Model)
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
tokenKey := getTokenKey(auth)
tokenKey := getAccountKey(auth)
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
var streamErr error
@@ -4819,7 +4808,7 @@ func (e *KiroExecutor) executeNonStreamFallback(
kiroModelID := e.mapModelToKiro(req.Model)
isAgentic, isChatOnly := determineAgenticMode(req.Model)
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
tokenKey := getTokenKey(auth)
tokenKey := getAccountKey(auth)
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
var err error

View File

@@ -0,0 +1,423 @@
package executor
import (
"fmt"
"testing"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestBuildKiroEndpointConfigs(t *testing.T) {
tests := []struct {
name string
region string
expectedURL string
expectedOrigin string
expectedName string
}{
{
name: "Empty region - defaults to us-east-1",
region: "",
expectedURL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
{
name: "us-east-1",
region: "us-east-1",
expectedURL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
{
name: "ap-southeast-1",
region: "ap-southeast-1",
expectedURL: "https://q.ap-southeast-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
{
name: "eu-west-1",
region: "eu-west-1",
expectedURL: "https://q.eu-west-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configs := buildKiroEndpointConfigs(tt.region)
if len(configs) != 2 {
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
}
// Check primary endpoint (AmazonQ)
primary := configs[0]
if primary.URL != tt.expectedURL {
t.Errorf("primary URL = %q, want %q", primary.URL, tt.expectedURL)
}
if primary.Origin != tt.expectedOrigin {
t.Errorf("primary Origin = %q, want %q", primary.Origin, tt.expectedOrigin)
}
if primary.Name != tt.expectedName {
t.Errorf("primary Name = %q, want %q", primary.Name, tt.expectedName)
}
if primary.AmzTarget != "" {
t.Errorf("primary AmzTarget should be empty, got %q", primary.AmzTarget)
}
// Check fallback endpoint (CodeWhisperer)
fallback := configs[1]
if fallback.Name != "CodeWhisperer" {
t.Errorf("fallback Name = %q, want %q", fallback.Name, "CodeWhisperer")
}
// CodeWhisperer fallback uses the same region as Q endpoint
expectedRegion := tt.region
if expectedRegion == "" {
expectedRegion = kiroDefaultRegion
}
expectedFallbackURL := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com/generateAssistantResponse", expectedRegion)
if fallback.URL != expectedFallbackURL {
t.Errorf("fallback URL = %q, want %q", fallback.URL, expectedFallbackURL)
}
if fallback.AmzTarget == "" {
t.Error("fallback AmzTarget should NOT be empty")
}
})
}
}
func TestGetKiroEndpointConfigs_NilAuth(t *testing.T) {
configs := getKiroEndpointConfigs(nil)
if len(configs) != 2 {
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
}
// Should return default us-east-1 configs
if configs[0].Name != "AmazonQ" {
t.Errorf("first config Name = %q, want %q", configs[0].Name, "AmazonQ")
}
expectedURL := "https://q.us-east-1.amazonaws.com/generateAssistantResponse"
if configs[0].URL != expectedURL {
t.Errorf("first config URL = %q, want %q", configs[0].URL, expectedURL)
}
}
func TestGetKiroEndpointConfigs_WithRegionFromProfileArn(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{
"profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
},
}
configs := getKiroEndpointConfigs(auth)
if len(configs) != 2 {
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
}
expectedURL := "https://q.ap-southeast-1.amazonaws.com/generateAssistantResponse"
if configs[0].URL != expectedURL {
t.Errorf("primary URL = %q, want %q", configs[0].URL, expectedURL)
}
}
func TestGetKiroEndpointConfigs_WithApiRegionOverride(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{
"api_region": "eu-central-1",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
},
}
configs := getKiroEndpointConfigs(auth)
// api_region should take precedence over profile_arn
expectedURL := "https://q.eu-central-1.amazonaws.com/generateAssistantResponse"
if configs[0].URL != expectedURL {
t.Errorf("primary URL = %q, want %q", configs[0].URL, expectedURL)
}
}
func TestGetKiroEndpointConfigs_PreferredEndpoint(t *testing.T) {
tests := []struct {
name string
preference string
expectedFirstName string
}{
{
name: "Prefer codewhisperer",
preference: "codewhisperer",
expectedFirstName: "CodeWhisperer",
},
{
name: "Prefer ide (alias for codewhisperer)",
preference: "ide",
expectedFirstName: "CodeWhisperer",
},
{
name: "Prefer amazonq",
preference: "amazonq",
expectedFirstName: "AmazonQ",
},
{
name: "Prefer q (alias for amazonq)",
preference: "q",
expectedFirstName: "AmazonQ",
},
{
name: "Prefer cli (alias for amazonq)",
preference: "cli",
expectedFirstName: "AmazonQ",
},
{
name: "Unknown preference - no reordering",
preference: "unknown",
expectedFirstName: "AmazonQ",
},
{
name: "Empty preference - no reordering",
preference: "",
expectedFirstName: "AmazonQ",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{
"preferred_endpoint": tt.preference,
},
}
configs := getKiroEndpointConfigs(auth)
if configs[0].Name != tt.expectedFirstName {
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, tt.expectedFirstName)
}
})
}
}
func TestGetKiroEndpointConfigs_PreferredEndpointFromAttributes(t *testing.T) {
// Test that preferred_endpoint can also come from Attributes
auth := &cliproxyauth.Auth{
Metadata: map[string]any{},
Attributes: map[string]string{"preferred_endpoint": "codewhisperer"},
}
configs := getKiroEndpointConfigs(auth)
if configs[0].Name != "CodeWhisperer" {
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, "CodeWhisperer")
}
}
func TestGetKiroEndpointConfigs_MetadataTakesPrecedenceOverAttributes(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{"preferred_endpoint": "amazonq"},
Attributes: map[string]string{"preferred_endpoint": "codewhisperer"},
}
configs := getKiroEndpointConfigs(auth)
// Metadata should take precedence
if configs[0].Name != "AmazonQ" {
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, "AmazonQ")
}
}
func TestGetAuthValue(t *testing.T) {
tests := []struct {
name string
auth *cliproxyauth.Auth
key string
expected string
}{
{
name: "From metadata",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": "metadata_value"},
},
key: "test_key",
expected: "metadata_value",
},
{
name: "From attributes (fallback)",
auth: &cliproxyauth.Auth{
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
{
name: "Metadata takes precedence",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": "metadata_value"},
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "metadata_value",
},
{
name: "Key not found",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"other_key": "value"},
Attributes: map[string]string{"another_key": "value"},
},
key: "test_key",
expected: "",
},
{
name: "Nil metadata",
auth: &cliproxyauth.Auth{
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
{
name: "Both nil",
auth: &cliproxyauth.Auth{},
key: "test_key",
expected: "",
},
{
name: "Value is trimmed and lowercased",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": " UPPER_VALUE "},
},
key: "test_key",
expected: "upper_value",
},
{
name: "Empty string value in metadata - falls back to attributes",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": ""},
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
{
name: "Non-string value in metadata - falls back to attributes",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": 123},
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getAuthValue(tt.auth, tt.key)
if result != tt.expected {
t.Errorf("getAuthValue() = %q, want %q", result, tt.expected)
}
})
}
}
func TestGetAccountKey(t *testing.T) {
tests := []struct {
name string
auth *cliproxyauth.Auth
checkFn func(t *testing.T, result string)
}{
{
name: "From client_id",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{
"client_id": "test-client-id-123",
"refresh_token": "test-refresh-token-456",
},
},
checkFn: func(t *testing.T, result string) {
expected := kiroauth.GetAccountKey("test-client-id-123", "test-refresh-token-456")
if result != expected {
t.Errorf("expected %s, got %s", expected, result)
}
},
},
{
name: "From refresh_token only",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{
"refresh_token": "test-refresh-token-789",
},
},
checkFn: func(t *testing.T, result string) {
expected := kiroauth.GetAccountKey("", "test-refresh-token-789")
if result != expected {
t.Errorf("expected %s, got %s", expected, result)
}
},
},
{
name: "Nil auth",
auth: nil,
checkFn: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
},
},
{
name: "Nil metadata",
auth: &cliproxyauth.Auth{},
checkFn: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
},
},
{
name: "Empty metadata",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{},
},
checkFn: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getAccountKey(tt.auth)
tt.checkFn(t, result)
})
}
}
func TestEndpointAliases(t *testing.T) {
// Verify all expected aliases are defined
expectedAliases := map[string]string{
"codewhisperer": "codewhisperer",
"ide": "codewhisperer",
"amazonq": "amazonq",
"q": "amazonq",
"cli": "amazonq",
}
for alias, target := range expectedAliases {
if actual, ok := endpointAliases[alias]; !ok {
t.Errorf("missing alias %q", alias)
} else if actual != target {
t.Errorf("alias %q = %q, want %q", alias, actual, target)
}
}
// Verify no unexpected aliases
if len(endpointAliases) != len(expectedAliases) {
t.Errorf("unexpected number of aliases: got %d, want %d", len(endpointAliases), len(expectedAliases))
}
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"strings"
"sync"
"time"
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
@@ -22,9 +23,151 @@ import (
)
const (
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
qwenRateLimitWindow = time.Minute // sliding window duration
)
// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls.
var qwenBeijingLoc = func() *time.Location {
loc, err := time.LoadLocation("Asia/Shanghai")
if err != nil || loc == nil {
log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err)
return time.FixedZone("CST", 8*3600)
}
return loc
}()
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
var qwenQuotaCodes = map[string]struct{}{
"insufficient_quota": {},
"quota_exceeded": {},
}
// qwenRateLimiter tracks request timestamps per credential for rate limiting.
// Qwen has a limit of 60 requests per minute per account.
var qwenRateLimiter = struct {
sync.Mutex
requests map[string][]time.Time // authID -> request timestamps
}{
requests: make(map[string][]time.Time),
}
// redactAuthID returns a redacted version of the auth ID for safe logging.
// Keeps a small prefix/suffix to allow correlation across events.
func redactAuthID(id string) string {
if id == "" {
return ""
}
if len(id) <= 8 {
return id
}
return id[:4] + "..." + id[len(id)-4:]
}
// checkQwenRateLimit checks if the credential has exceeded the rate limit.
// Returns nil if allowed, or a statusErr with retryAfter if rate limited.
func checkQwenRateLimit(authID string) error {
if authID == "" {
// Empty authID should not bypass rate limiting in production
// Use debug level to avoid log spam for certain auth flows
log.Debug("qwen rate limit check: empty authID, skipping rate limit")
return nil
}
now := time.Now()
windowStart := now.Add(-qwenRateLimitWindow)
qwenRateLimiter.Lock()
defer qwenRateLimiter.Unlock()
// Get and filter timestamps within the window
timestamps := qwenRateLimiter.requests[authID]
var validTimestamps []time.Time
for _, ts := range timestamps {
if ts.After(windowStart) {
validTimestamps = append(validTimestamps, ts)
}
}
// Always prune expired entries to prevent memory leak
// Delete empty entries, otherwise update with pruned slice
if len(validTimestamps) == 0 {
delete(qwenRateLimiter.requests, authID)
}
// Check if rate limit exceeded
if len(validTimestamps) >= qwenRateLimitPerMin {
// Calculate when the oldest request will expire
oldestInWindow := validTimestamps[0]
retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now)
if retryAfter < time.Second {
retryAfter = time.Second
}
retryAfterSec := int(retryAfter.Seconds())
return statusErr{
code: http.StatusTooManyRequests,
msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec),
retryAfter: &retryAfter,
}
}
// Record this request and update the map with pruned timestamps
validTimestamps = append(validTimestamps, now)
qwenRateLimiter.requests[authID] = validTimestamps
return nil
}
// isQwenQuotaError checks if the error response indicates a quota exceeded error.
// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted.
func isQwenQuotaError(body []byte) bool {
code := strings.ToLower(gjson.GetBytes(body, "error.code").String())
errType := strings.ToLower(gjson.GetBytes(body, "error.type").String())
// Primary check: exact match on error.code or error.type (most reliable)
if _, ok := qwenQuotaCodes[code]; ok {
return true
}
if _, ok := qwenQuotaCodes[errType]; ok {
return true
}
// Fallback: check message only if code/type don't match (less reliable)
msg := strings.ToLower(gjson.GetBytes(body, "error.message").String())
if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") ||
strings.Contains(msg, "free allocated quota exceeded") {
return true
}
return false
}
// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429.
// Returns the appropriate status code and retryAfter duration for statusErr.
// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives.
func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) {
errCode = httpCode
// Only check quota errors for expected status codes to avoid false positives
// Qwen returns 403 for quota errors, 429 for rate limits
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
cooldown := timeUntilNextDay()
retryAfter = &cooldown
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
}
return errCode, retryAfter
}
// timeUntilNextDay returns duration until midnight Beijing time (UTC+8).
// Qwen's daily quota resets at 00:00 Beijing time.
func timeUntilNextDay() time.Duration {
now := time.Now()
nowLocal := now.In(qwenBeijingLoc)
tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc)
return tomorrow.Sub(now)
}
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
// If access token is unavailable, it falls back to legacy via ClientAdapter.
type QwenExecutor struct {
@@ -67,6 +210,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
if opts.Alt == "responses/compact" {
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
// Check rate limit before proceeding
var authID string
if auth != nil {
authID = auth.ID
}
if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return resp, err
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, baseURL := qwenCreds(auth)
@@ -102,9 +256,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err
}
applyQwenHeaders(httpReq, token, false)
var authID, authLabel, authType, authValue string
var authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
@@ -135,8 +288,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
@@ -158,6 +313,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
// Check rate limit before proceeding
var authID string
if auth != nil {
authID = auth.ID
}
if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return nil, err
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, baseURL := qwenCreds(auth)
@@ -200,9 +366,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, err
}
applyQwenHeaders(httpReq, token, true)
var authID, authLabel, authType, authValue string
var authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
@@ -228,11 +393,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("qwen executor: close response body error: %v", errClose)
}
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)

View File

@@ -1,8 +1,7 @@
// Package kimi implements thinking configuration for Kimi (Moonshot AI) models.
//
// Kimi models use the OpenAI-compatible reasoning_effort format with discrete levels
// (low/medium/high). The provider strips any existing thinking config and applies
// the unified ThinkingConfig in OpenAI format.
// Kimi models use the OpenAI-compatible reasoning_effort format for enabled thinking
// levels, but use thinking.type=disabled when thinking is explicitly turned off.
package kimi
import (
@@ -17,8 +16,8 @@ import (
// Applier implements thinking.ProviderApplier for Kimi models.
//
// Kimi-specific behavior:
// - Output format: reasoning_effort (string: low/medium/high)
// - Uses OpenAI-compatible format
// - Enabled thinking: reasoning_effort (string levels)
// - Disabled thinking: thinking.type="disabled"
// - Supports budget-to-level conversion
type Applier struct{}
@@ -35,11 +34,19 @@ func init() {
// Apply applies thinking configuration to Kimi request body.
//
// Expected output format:
// Expected output format (enabled):
//
// {
// "reasoning_effort": "high"
// }
//
// Expected output format (disabled):
//
// {
// "thinking": {
// "type": "disabled"
// }
// }
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
if thinking.IsUserDefinedModel(modelInfo) {
return applyCompatibleKimi(body, config)
@@ -60,8 +67,13 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
}
effort = string(config.Level)
case thinking.ModeNone:
// Kimi uses "none" to disable thinking
effort = string(thinking.LevelNone)
// Respect clamped fallback level for models that cannot disable thinking.
if config.Level != "" && config.Level != thinking.LevelNone {
effort = string(config.Level)
break
}
// Kimi requires explicit disabled thinking object.
return applyDisabledThinking(body)
case thinking.ModeBudget:
// Convert budget to level using threshold mapping
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
@@ -79,12 +91,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
if effort == "" {
return body, nil
}
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
if err != nil {
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
}
return result, nil
return applyReasoningEffort(body, effort)
}
// applyCompatibleKimi applies thinking config for user-defined Kimi models.
@@ -101,7 +108,9 @@ func applyCompatibleKimi(body []byte, config thinking.ThinkingConfig) ([]byte, e
}
effort = string(config.Level)
case thinking.ModeNone:
effort = string(thinking.LevelNone)
if config.Level == "" || config.Level == thinking.LevelNone {
return applyDisabledThinking(body)
}
if config.Level != "" {
effort = string(config.Level)
}
@@ -118,9 +127,33 @@ func applyCompatibleKimi(body []byte, config thinking.ThinkingConfig) ([]byte, e
return body, nil
}
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
if err != nil {
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
return applyReasoningEffort(body, effort)
}
func applyReasoningEffort(body []byte, effort string) ([]byte, error) {
result, errDeleteThinking := sjson.DeleteBytes(body, "thinking")
if errDeleteThinking != nil {
return body, fmt.Errorf("kimi thinking: failed to clear thinking object: %w", errDeleteThinking)
}
result, errSetEffort := sjson.SetBytes(result, "reasoning_effort", effort)
if errSetEffort != nil {
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", errSetEffort)
}
return result, nil
}
func applyDisabledThinking(body []byte) ([]byte, error) {
result, errDeleteThinking := sjson.DeleteBytes(body, "thinking")
if errDeleteThinking != nil {
return body, fmt.Errorf("kimi thinking: failed to clear thinking object: %w", errDeleteThinking)
}
result, errDeleteEffort := sjson.DeleteBytes(result, "reasoning_effort")
if errDeleteEffort != nil {
return body, fmt.Errorf("kimi thinking: failed to clear reasoning_effort: %w", errDeleteEffort)
}
result, errSetType := sjson.SetBytes(result, "thinking.type", "disabled")
if errSetType != nil {
return body, fmt.Errorf("kimi thinking: failed to set thinking.type: %w", errSetType)
}
return result, nil
}

View File

@@ -0,0 +1,72 @@
package kimi
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
)
func TestApply_ModeNone_UsesDisabledThinking(t *testing.T) {
applier := NewApplier()
modelInfo := &registry.ModelInfo{
ID: "kimi-k2.5",
Thinking: &registry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
}
body := []byte(`{"model":"kimi-k2.5","reasoning_effort":"none","thinking":{"type":"enabled","budget_tokens":2048}}`)
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeNone}, modelInfo)
if errApply != nil {
t.Fatalf("Apply() error = %v", errApply)
}
if got := gjson.GetBytes(out, "thinking.type").String(); got != "disabled" {
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "disabled", string(out))
}
if gjson.GetBytes(out, "thinking.budget_tokens").Exists() {
t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out))
}
if gjson.GetBytes(out, "reasoning_effort").Exists() {
t.Fatalf("reasoning_effort should be removed in ModeNone, body=%s", string(out))
}
}
func TestApply_ModeLevel_UsesReasoningEffort(t *testing.T) {
applier := NewApplier()
modelInfo := &registry.ModelInfo{
ID: "kimi-k2.5",
Thinking: &registry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
}
body := []byte(`{"model":"kimi-k2.5","thinking":{"type":"disabled"}}`)
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeLevel, Level: thinking.LevelHigh}, modelInfo)
if errApply != nil {
t.Fatalf("Apply() error = %v", errApply)
}
if got := gjson.GetBytes(out, "reasoning_effort").String(); got != "high" {
t.Fatalf("reasoning_effort = %q, want %q, body=%s", got, "high", string(out))
}
if gjson.GetBytes(out, "thinking").Exists() {
t.Fatalf("thinking should be removed when reasoning_effort is used, body=%s", string(out))
}
}
func TestApply_UserDefinedModeNone_UsesDisabledThinking(t *testing.T) {
applier := NewApplier()
modelInfo := &registry.ModelInfo{
ID: "custom-kimi-model",
UserDefined: true,
}
body := []byte(`{"model":"custom-kimi-model","reasoning_effort":"none"}`)
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeNone}, modelInfo)
if errApply != nil {
t.Fatalf("Apply() error = %v", errApply)
}
if got := gjson.GetBytes(out, "thinking.type").String(); got != "disabled" {
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "disabled", string(out))
}
if gjson.GetBytes(out, "reasoning_effort").Exists() {
t.Fatalf("reasoning_effort should be removed in ModeNone, body=%s", string(out))
}
}

View File

@@ -10,53 +10,10 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// validReasoningEffortLevels contains the standard values accepted by the
// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal,
// auto) are NOT in this set and must be clamped before use.
var validReasoningEffortLevels = map[string]struct{}{
"none": {},
"low": {},
"medium": {},
"high": {},
}
// clampReasoningEffort maps any thinking level string to a value that is safe
// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are
// mapped to the nearest standard equivalent.
//
// Mapping rules:
// - none / low / medium / high → returned as-is (already valid)
// - xhigh → "high" (nearest lower standard level)
// - minimal → "low" (nearest higher standard level)
// - auto → "medium" (reasonable default)
// - anything else → "medium" (safe default)
func clampReasoningEffort(level string) string {
if _, ok := validReasoningEffortLevels[level]; ok {
return level
}
var clamped string
switch level {
case string(thinking.LevelXHigh):
clamped = string(thinking.LevelHigh)
case string(thinking.LevelMinimal):
clamped = string(thinking.LevelLow)
case string(thinking.LevelAuto):
clamped = string(thinking.LevelMedium)
default:
clamped = string(thinking.LevelMedium)
}
log.WithFields(log.Fields{
"original": level,
"clamped": clamped,
}).Debug("openai: reasoning_effort clamped to nearest valid standard value")
return clamped
}
// Applier implements thinking.ProviderApplier for OpenAI models.
//
// OpenAI-specific behavior:
@@ -101,7 +58,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
}
if config.Mode == thinking.ModeLevel {
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level)))
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
return result, nil
}
@@ -122,7 +79,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
return result, nil
}
@@ -157,7 +114,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
return result, nil
}

View File

@@ -37,6 +37,11 @@ func StripThinkingConfig(body []byte, provider string) []byte {
paths = []string{"request.generationConfig.thinkingConfig"}
case "openai":
paths = []string{"reasoning_effort"}
case "kimi":
paths = []string{
"reasoning_effort",
"thinking",
}
case "codex":
paths = []string{"reasoning.effort"}
case "iflow":

View File

@@ -223,14 +223,65 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
} else if functionResponseResult.IsArray() {
frResults := functionResponseResult.Array()
if len(frResults) == 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", frResults[0].Raw)
nonImageCount := 0
lastNonImageRaw := ""
filteredJSON := "[]"
imagePartsJSON := "[]"
for _, fr := range frResults {
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
inlineDataJSON := `{}`
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
}
if data := fr.Get("source.data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
}
imagePartJSON := `{}`
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
continue
}
nonImageCount++
lastNonImageRaw = fr.Raw
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
}
if nonImageCount == 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
} else if nonImageCount > 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
} else {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
}
// Place image data inside functionResponse.parts as inlineData
// instead of as sibling parts in the outer content, to avoid
// base64 data bloating the text context.
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
}
} else if functionResponseResult.IsObject() {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
inlineDataJSON := `{}`
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
}
if data := functionResponseResult.Get("source.data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
}
imagePartJSON := `{}`
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON := "[]"
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
} else {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
}
} else if functionResponseResult.Raw != "" {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
} else {
@@ -248,7 +299,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if sourceResult.Get("type").String() == "base64" {
inlineDataJSON := `{}`
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mime_type", mimeType)
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
}
if data := sourceResult.Get("data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
@@ -349,7 +400,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
hasTools := toolDeclCount > 0
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
thinkingType := thinkingResult.Get("type").String()
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && (thinkingType == "enabled" || thinkingType == "adaptive")
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && (thinkingType == "enabled" || thinkingType == "adaptive" || thinkingType == "auto")
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
if hasTools && hasThinking && isClaudeThinking {
@@ -389,8 +440,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
}
case "adaptive":
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
case "adaptive", "auto":
// Keep adaptive/auto as a high level sentinel; ApplyThinking resolves it
// to model-specific max capability.
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)

View File

@@ -413,8 +413,8 @@ func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) {
if !inlineData.Exists() {
t.Error("inlineData should exist")
}
if inlineData.Get("mime_type").String() != "image/png" {
t.Error("mime_type mismatch")
if inlineData.Get("mimeType").String() != "image/png" {
t.Error("mimeType mismatch")
}
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
t.Error("data mismatch")
@@ -740,6 +740,429 @@ func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultWithImage(t *testing.T) {
// tool_result with array content containing text + image should place
// image data inside functionResponse.parts as inlineData, not as a
// sibling part in the outer content (to avoid base64 context bloat).
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Read-123-456",
"content": [
{
"type": "text",
"text": "File content here"
},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
// Image should be inside functionResponse.parts, not as outer sibling part
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// Text content should be in response.result
resultText := funcResp.Get("response.result.text").String()
if resultText != "File content here" {
t.Errorf("Expected response.result.text = 'File content here', got '%s'", resultText)
}
// Image should be in functionResponse.parts[0].inlineData
inlineData := funcResp.Get("parts.0.inlineData")
if !inlineData.Exists() {
t.Fatal("functionResponse.parts[0].inlineData should exist")
}
if inlineData.Get("mimeType").String() != "image/png" {
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineData.Get("mimeType").String())
}
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
t.Error("data mismatch")
}
// Image should NOT be in outer parts (only functionResponse part should exist)
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
t.Errorf("Expected only 1 outer part (functionResponse), got %d", len(outerParts.Array()))
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultWithSingleImage(t *testing.T) {
// tool_result with single image object as content should place
// image data inside functionResponse.parts, not as outer sibling part.
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Read-789-012",
"content": {
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "/9j/4AAQSkZJRgABAQ=="
}
}
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// response.result should be empty (image only)
if funcResp.Get("response.result").String() != "" {
t.Errorf("Expected empty response.result for image-only content, got '%s'", funcResp.Get("response.result").String())
}
// Image should be in functionResponse.parts[0].inlineData
inlineData := funcResp.Get("parts.0.inlineData")
if !inlineData.Exists() {
t.Fatal("functionResponse.parts[0].inlineData should exist")
}
if inlineData.Get("mimeType").String() != "image/jpeg" {
t.Errorf("Expected mimeType 'image/jpeg', got '%s'", inlineData.Get("mimeType").String())
}
// Image should NOT be in outer parts
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
t.Errorf("Expected only 1 outer part, got %d", len(outerParts.Array()))
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultWithMultipleImagesAndTexts(t *testing.T) {
// tool_result with array content: 2 text items + 2 images
// All images go into functionResponse.parts, texts into response.result array
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Multi-001",
"content": [
{"type": "text", "text": "First text"},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/png", "data": "AAAA"}
},
{"type": "text", "text": "Second text"},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/jpeg", "data": "BBBB"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// Multiple text items => response.result is an array
resultArr := funcResp.Get("response.result")
if !resultArr.IsArray() {
t.Fatalf("Expected response.result to be an array, got: %s", resultArr.Raw)
}
results := resultArr.Array()
if len(results) != 2 {
t.Fatalf("Expected 2 result items, got %d", len(results))
}
// Both images should be in functionResponse.parts
imgParts := funcResp.Get("parts").Array()
if len(imgParts) != 2 {
t.Fatalf("Expected 2 image parts in functionResponse.parts, got %d", len(imgParts))
}
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
t.Errorf("Expected first image mimeType 'image/png', got '%s'", imgParts[0].Get("inlineData.mimeType").String())
}
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
t.Errorf("Expected first image data 'AAAA', got '%s'", imgParts[0].Get("inlineData.data").String())
}
if imgParts[1].Get("inlineData.mimeType").String() != "image/jpeg" {
t.Errorf("Expected second image mimeType 'image/jpeg', got '%s'", imgParts[1].Get("inlineData.mimeType").String())
}
if imgParts[1].Get("inlineData.data").String() != "BBBB" {
t.Errorf("Expected second image data 'BBBB', got '%s'", imgParts[1].Get("inlineData.data").String())
}
// Only 1 outer part (the functionResponse itself)
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(outerParts) != 1 {
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultWithOnlyMultipleImages(t *testing.T) {
// tool_result with only images (no text) — response.result should be empty string
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "ImgOnly-001",
"content": [
{
"type": "image",
"source": {"type": "base64", "media_type": "image/png", "data": "PNG1"}
},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/gif", "data": "GIF1"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// No text => response.result should be empty string
if funcResp.Get("response.result").String() != "" {
t.Errorf("Expected empty response.result, got '%s'", funcResp.Get("response.result").String())
}
// Both images in functionResponse.parts
imgParts := funcResp.Get("parts").Array()
if len(imgParts) != 2 {
t.Fatalf("Expected 2 image parts, got %d", len(imgParts))
}
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
t.Error("first image mimeType mismatch")
}
if imgParts[1].Get("inlineData.mimeType").String() != "image/gif" {
t.Error("second image mimeType mismatch")
}
// Only 1 outer part
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(outerParts) != 1 {
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultImageNotBase64(t *testing.T) {
// image with source.type != "base64" should be treated as non-image (falls through)
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "NotB64-001",
"content": [
{"type": "text", "text": "some output"},
{
"type": "image",
"source": {"type": "url", "url": "https://example.com/img.png"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// Non-base64 image is treated as non-image, so it goes into the filtered results
// along with the text item. Since there are 2 non-image items, result is array.
resultArr := funcResp.Get("response.result")
if !resultArr.IsArray() {
t.Fatalf("Expected response.result to be an array (2 non-image items), got: %s", resultArr.Raw)
}
results := resultArr.Array()
if len(results) != 2 {
t.Fatalf("Expected 2 result items, got %d", len(results))
}
// No functionResponse.parts (no base64 images collected)
if funcResp.Get("parts").Exists() {
t.Error("functionResponse.parts should NOT exist when no base64 images")
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingData(t *testing.T) {
// image with source.type=base64 but missing data field
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "NoData-001",
"content": [
{"type": "text", "text": "output"},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/png"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// The image is still classified as base64 image (type check passes),
// but data field is missing => inlineData has mimeType but no data
imgParts := funcResp.Get("parts").Array()
if len(imgParts) != 1 {
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
}
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
t.Error("mimeType should still be set")
}
if imgParts[0].Get("inlineData.data").Exists() {
t.Error("data should not exist when source.data is missing")
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingMediaType(t *testing.T) {
// image with source.type=base64 but missing media_type field
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "NoMime-001",
"content": [
{"type": "text", "text": "output"},
{
"type": "image",
"source": {"type": "base64", "data": "AAAA"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// The image is still classified as base64 image,
// but media_type is missing => inlineData has data but no mimeType
imgParts := funcResp.Get("parts").Array()
if len(imgParts) != 1 {
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
}
if imgParts[0].Get("inlineData.mimeType").Exists() {
t.Error("mimeType should not exist when media_type is missing")
}
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
t.Error("data should still be set")
}
}
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
// When tools + thinking but no system instruction, should create one with hint
inputJSON := []byte(`{

View File

@@ -93,3 +93,81 @@ func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
}
}
}
func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
// When functionResponse contains a "parts" field with inlineData (from Claude
// translator's image embedding), fixCLIToolResponse should preserve it as-is.
// parseFunctionResponseRaw returns response.Raw for valid JSON objects,
// so extra fields like "parts" survive the pipeline.
input := `{
"model": "claude-opus-4-6-thinking",
"request": {
"contents": [
{
"role": "model",
"parts": [
{
"functionCall": {"name": "screenshot", "args": {}}
}
]
},
{
"role": "function",
"parts": [
{
"functionResponse": {
"id": "tool-001",
"name": "screenshot",
"response": {"result": "Screenshot taken"},
"parts": [
{"inlineData": {"mimeType": "image/png", "data": "iVBOR"}}
]
}
}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
// Find the function response content (role=function)
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
// The functionResponse should be preserved with its parts field
funcResp := funcContent.Get("parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist in output")
}
// Verify the parts field with inlineData is preserved
inlineParts := funcResp.Get("parts").Array()
if len(inlineParts) != 1 {
t.Fatalf("Expected 1 inlineData part in functionResponse.parts, got %d", len(inlineParts))
}
if inlineParts[0].Get("inlineData.mimeType").String() != "image/png" {
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineParts[0].Get("inlineData.mimeType").String())
}
if inlineParts[0].Get("inlineData.data").String() != "iVBOR" {
t.Errorf("Expected data 'iVBOR', got '%s'", inlineParts[0].Get("inlineData.data").String())
}
// Verify response.result is also preserved
if funcResp.Get("response.result").String() != "Screenshot taken" {
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
}
}

View File

@@ -34,6 +34,11 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Model
out, _ = sjson.SetBytes(out, "model", modelName)
// Let user-provided generationConfig pass through
if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() {
out, _ = sjson.SetRawBytes(out, "request.generationConfig", []byte(genConfig.Raw))
}
// Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig.
// Inline translation-only mapping; capability checks happen later in ApplyThinking.
re := gjson.GetBytes(rawJSON, "reasoning_effort")
@@ -187,7 +192,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
@@ -201,7 +206,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
ext = sp[len(sp)-1]
}
if mimeType, ok := misc.MimeTypes[ext]; ok {
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mimeType)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
p++
} else {
@@ -235,7 +240,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++

View File

@@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -199,6 +199,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
}
}
case "file":
fileData := part.Get("file.file_data").String()
if strings.HasPrefix(fileData, "data:") {
semicolonIdx := strings.Index(fileData, ";")
commaIdx := strings.Index(fileData, ",")
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
data := fileData[commaIdx+1:]
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
docPart, _ = sjson.Set(docPart, "source.data", data)
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
}
}
}
return true
})

View File

@@ -155,6 +155,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
var textAggregate strings.Builder
var partsJSON []string
hasImage := false
hasFile := false
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
ptype := part.Get("type").String()
@@ -207,6 +208,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
hasImage = true
}
}
case "input_file":
fileData := part.Get("file_data").String()
if fileData != "" {
mediaType := "application/octet-stream"
data := fileData
if strings.HasPrefix(fileData, "data:") {
trimmed := strings.TrimPrefix(fileData, "data:")
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
if len(mediaAndData) == 2 {
if mediaAndData[0] != "" {
mediaType = mediaAndData[0]
}
data = mediaAndData[1]
}
}
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.Set(contentPart, "source.data", data)
partsJSON = append(partsJSON, contentPart)
if role == "" {
role = "user"
}
hasFile = true
}
}
return true
})
@@ -228,7 +253,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if len(partsJSON) > 0 {
msg := `{"role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
if len(partsJSON) == 1 && !hasImage {
if len(partsJSON) == 1 && !hasImage && !hasFile {
// Preserve legacy behavior for single text content
msg, _ = sjson.Delete(msg, "content")
textPart := gjson.Parse(partsJSON[0])

View File

@@ -46,15 +46,23 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
if systemsResult.IsArray() {
systemResults := systemsResult.Array()
message := `{"type":"message","role":"developer","content":[]}`
contentIndex := 0
for i := 0; i < len(systemResults); i++ {
systemResult := systemResults[i]
systemTypeResult := systemResult.Get("type")
if systemTypeResult.String() == "text" {
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", i), "input_text")
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", i), systemResult.Get("text").String())
text := systemResult.Get("text").String()
if strings.HasPrefix(text, "x-anthropic-billing-header: ") {
continue
}
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
contentIndex++
}
}
template, _ = sjson.SetRaw(template, "input.-1", message)
if contentIndex > 0 {
template, _ = sjson.SetRaw(template, "input.-1", message)
}
}
// Process messages and transform their contents to appropriate formats.
@@ -222,8 +230,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
reasoningEffort = effort
}
}
case "adaptive":
// Claude adaptive means "enable with max capacity"; keep it as highest level
case "adaptive", "auto":
// Claude adaptive/auto means "enable with max capacity"; keep it as highest level
// and let ApplyThinking normalize per target model capability.
reasoningEffort = string(thinking.LevelXHigh)
case "disabled":

View File

@@ -22,8 +22,8 @@ var (
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
type ConvertCodexResponseToClaudeParams struct {
HasToolCall bool
BlockIndex int
HasToolCall bool
BlockIndex int
HasReceivedArgumentsDelta bool
}

View File

@@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
msg, _ = sjson.SetRaw(msg, "content.-1", part)
}
case "file":
// Files are not specified in examples; skip for now
if role == "user" {
fileData := it.Get("file.file_data").String()
filename := it.Get("file.filename").String()
if fileData != "" {
part := `{}`
part, _ = sjson.Set(part, "type", "input_file")
part, _ = sjson.Set(part, "file_data", fileData)
if filename != "" {
part, _ = sjson.Set(part, "filename", filename)
}
msg, _ = sjson.SetRaw(msg, "content.-1", part)
}
}
}
}
}

View File

@@ -26,6 +26,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
// Delete the user field as it is not supported by the Codex upstream.
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
@@ -36,6 +38,23 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
return rawJSON
}
// applyResponsesCompactionCompatibility handles OpenAI Responses context_management.compaction
// for Codex upstream compatibility.
//
// Codex /responses currently rejects context_management with:
// {"detail":"Unsupported parameter: context_management"}.
//
// Compatibility strategy:
// 1) Remove context_management before forwarding to Codex upstream.
func applyResponsesCompactionCompatibility(rawJSON []byte) []byte {
if !gjson.GetBytes(rawJSON, "context_management").Exists() {
return rawJSON
}
rawJSON, _ = sjson.DeleteBytes(rawJSON, "context_management")
return rawJSON
}
// convertSystemRoleToDeveloper traverses the input array and converts any message items
// with role "system" to role "developer". This is necessary because Codex API does not
// accept "system" role in the input array.

View File

@@ -264,19 +264,57 @@ func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) {
}
}
func TestUserFieldDeletion(t *testing.T) {
func TestUserFieldDeletion(t *testing.T) {
inputJSON := []byte(`{
"model": "gpt-5.2",
"user": "test-user",
"input": [{"role": "user", "content": "Hello"}]
}`)
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
outputStr := string(output)
// Verify user field is deleted
userField := gjson.Get(outputStr, "user")
}`)
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
outputStr := string(output)
// Verify user field is deleted
userField := gjson.Get(outputStr, "user")
if userField.Exists() {
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
}
}
func TestContextManagementCompactionCompatibility(t *testing.T) {
inputJSON := []byte(`{
"model": "gpt-5.2",
"context_management": [
{
"type": "compaction",
"compact_threshold": 12000
}
],
"input": [{"role":"user","content":"hello"}]
}`)
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
outputStr := string(output)
if gjson.Get(outputStr, "context_management").Exists() {
t.Fatalf("context_management should be removed for Codex compatibility")
}
if gjson.Get(outputStr, "truncation").Exists() {
t.Fatalf("truncation should be removed for Codex compatibility")
}
}
func TestTruncationRemovedForCodexCompatibility(t *testing.T) {
inputJSON := []byte(`{
"model": "gpt-5.2",
"truncation": "disabled",
"input": [{"role":"user","content":"hello"}]
}`)
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
outputStr := string(output)
if gjson.Get(outputStr, "truncation").Exists() {
t.Fatalf("truncation should be removed for Codex compatibility")
}
}

View File

@@ -180,8 +180,8 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
}
case "adaptive":
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
case "adaptive", "auto":
// Keep adaptive/auto as a high level sentinel; ApplyThinking resolves it
// to model-specific max capability.
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)

View File

@@ -34,6 +34,11 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
// Model
out, _ = sjson.SetBytes(out, "model", modelName)
// Let user-provided generationConfig pass through
if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() {
out, _ = sjson.SetRawBytes(out, "request.generationConfig", []byte(genConfig.Raw))
}
// Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig.
// Inline translation-only mapping; capability checks happen later in ApplyThinking.
re := gjson.GetBytes(rawJSON, "reasoning_effort")

View File

@@ -100,7 +100,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -161,8 +161,8 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
}
case "adaptive":
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
case "adaptive", "auto":
// Keep adaptive/auto as a high level sentinel; ApplyThinking resolves it
// to model-specific max capability.
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", "high")
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)

View File

@@ -34,6 +34,11 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
// Model
out, _ = sjson.SetBytes(out, "model", modelName)
// Let user-provided generationConfig pass through
if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() {
out, _ = sjson.SetRawBytes(out, "generationConfig", []byte(genConfig.Raw))
}
// Apply thinking configuration: convert OpenAI reasoning_effort to Gemini thinkingConfig.
// Inline translation-only mapping; capability checks happen later in ApplyThinking.
re := gjson.GetBytes(rawJSON, "reasoning_effort")

View File

@@ -100,9 +100,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
@@ -297,7 +297,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -531,8 +531,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
// usage mapping
if um := root.Get("usageMetadata"); um.Exists() {
// input tokens = prompt + thoughts
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
// input tokens = prompt only (thoughts go to output)
input := um.Get("promptTokenCount").Int()
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
// cached token details: align with OpenAI "cached_tokens" semantics.
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
@@ -737,8 +737,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
// usage mapping
if um := root.Get("usageMetadata"); um.Exists() {
// input tokens = prompt + thoughts
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
// input tokens = prompt only (thoughts go to output)
input := um.Get("promptTokenCount").Int()
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
// cached token details: align with OpenAI "cached_tokens" semantics.
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())

View File

@@ -38,10 +38,12 @@ type KiroInferenceConfig struct {
// KiroConversationState holds the conversation context
type KiroConversationState struct {
ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - must be first field
ConversationID string `json:"conversationId"`
CurrentMessage KiroCurrentMessage `json:"currentMessage"`
History []KiroHistoryMessage `json:"history,omitempty"`
AgentContinuationID string `json:"agentContinuationId,omitempty"`
AgentTaskType string `json:"agentTaskType,omitempty"`
ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL"
ConversationID string `json:"conversationId"`
CurrentMessage KiroCurrentMessage `json:"currentMessage"`
History []KiroHistoryMessage `json:"history,omitempty"`
}
// KiroCurrentMessage wraps the current user message
@@ -293,10 +295,18 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA
}
}
// Session IDs: extract from messages[].additional_kwargs (LangChain format) or random
conversationID := extractMetadataFromMessages(messages, "conversationId")
continuationID := extractMetadataFromMessages(messages, "continuationId")
if conversationID == "" {
conversationID = uuid.New().String()
}
payload := KiroPayload{
ConversationState: KiroConversationState{
AgentTaskType: "vibe",
ChatTriggerType: "MANUAL",
ConversationID: uuid.New().String(),
ConversationID: conversationID,
CurrentMessage: currentMessage,
History: history,
},
@@ -304,6 +314,11 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA
InferenceConfig: inferenceConfig,
}
// Only set AgentContinuationID if client provided
if continuationID != "" {
payload.ConversationState.AgentContinuationID = continuationID
}
result, err := json.Marshal(payload)
if err != nil {
log.Debugf("kiro: failed to marshal payload: %v", err)
@@ -329,6 +344,18 @@ func normalizeOrigin(origin string) string {
}
}
// extractMetadataFromMessages extracts metadata from messages[].additional_kwargs (LangChain format).
// Searches from the last message backwards, returns empty string if not found.
func extractMetadataFromMessages(messages gjson.Result, key string) string {
arr := messages.Array()
for i := len(arr) - 1; i >= 0; i-- {
if val := arr[i].Get("additional_kwargs." + key); val.Exists() && val.String() != "" {
return val.String()
}
}
return ""
}
// extractSystemPrompt extracts system prompt from Claude request
func extractSystemPrompt(claudeBody []byte) string {
systemField := gjson.GetBytes(claudeBody, "system")

View File

@@ -53,19 +53,25 @@ var KnownCommandTools = map[string]bool{
"execute_python": true,
}
// RequiredFieldsByTool maps tool names to their required fields.
// If any of these fields are missing, the tool input is considered truncated.
var RequiredFieldsByTool = map[string][]string{
"Write": {"file_path", "content"},
"write_to_file": {"path", "content"},
"fsWrite": {"path", "content"},
"create_file": {"path", "content"},
"edit_file": {"path"},
"apply_diff": {"path", "diff"},
"str_replace_editor": {"path", "old_str", "new_str"},
"Bash": {"command"},
"execute": {"command"},
"run_command": {"command"},
// RequiredFieldsByTool maps tool names to their required field groups.
// Each outer element is a required group; each inner slice lists alternative field names (OR logic).
// A group is satisfied when ANY one of its alternatives exists in the parsed input.
// All groups must be satisfied for the tool input to be considered valid.
//
// Example:
// {{"cmd", "command"}} means the tool needs EITHER "cmd" OR "command".
// {{"file_path"}, {"content"}} means the tool needs BOTH "file_path" AND "content".
var RequiredFieldsByTool = map[string][][]string{
"Write": {{"file_path"}, {"content"}},
"write_to_file": {{"path"}, {"content"}},
"fsWrite": {{"path"}, {"content"}},
"create_file": {{"path"}, {"content"}},
"edit_file": {{"path"}},
"apply_diff": {{"path"}, {"diff"}},
"str_replace_editor": {{"path"}, {"old_str"}, {"new_str"}},
"Bash": {{"cmd", "command"}},
"execute": {{"command"}},
"run_command": {{"command"}},
}
// DetectTruncation checks if the tool use input appears to be truncated.
@@ -104,9 +110,9 @@ func DetectTruncation(toolName, toolUseID, rawInput string, parsedInput map[stri
// Scenario 3: JSON parsed but critical fields are missing
if parsedInput != nil {
requiredFields, hasRequirements := RequiredFieldsByTool[toolName]
requiredGroups, hasRequirements := RequiredFieldsByTool[toolName]
if hasRequirements {
missingFields := findMissingRequiredFields(parsedInput, requiredFields)
missingFields := findMissingRequiredFields(parsedInput, requiredGroups)
if len(missingFields) > 0 {
info.IsTruncated = true
info.TruncationType = TruncationTypeMissingFields
@@ -253,12 +259,21 @@ func extractParsedFieldNames(parsed map[string]interface{}) map[string]string {
return fields
}
// findMissingRequiredFields checks which required fields are missing from the parsed input.
func findMissingRequiredFields(parsed map[string]interface{}, required []string) []string {
// findMissingRequiredFields checks which required field groups are unsatisfied.
// Each group is a slice of alternative field names; the group is satisfied when ANY alternative exists.
// Returns the list of unsatisfied groups (represented by their alternatives joined with "/").
func findMissingRequiredFields(parsed map[string]interface{}, requiredGroups [][]string) []string {
var missing []string
for _, field := range required {
if _, exists := parsed[field]; !exists {
missing = append(missing, field)
for _, group := range requiredGroups {
satisfied := false
for _, field := range group {
if _, exists := parsed[field]; exists {
satisfied = true
break
}
}
if !satisfied {
missing = append(missing, strings.Join(group, "/"))
}
}
return missing

View File

@@ -36,10 +36,12 @@ type KiroInferenceConfig struct {
// KiroConversationState holds the conversation context
type KiroConversationState struct {
ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL"
ConversationID string `json:"conversationId"`
CurrentMessage KiroCurrentMessage `json:"currentMessage"`
History []KiroHistoryMessage `json:"history,omitempty"`
AgentContinuationID string `json:"agentContinuationId,omitempty"`
AgentTaskType string `json:"agentTaskType,omitempty"`
ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL"
ConversationID string `json:"conversationId"`
CurrentMessage KiroCurrentMessage `json:"currentMessage"`
History []KiroHistoryMessage `json:"history,omitempty"`
}
// KiroCurrentMessage wraps the current user message
@@ -297,10 +299,18 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s
}
}
// Session IDs: extract from messages[].additional_kwargs (LangChain format) or random
conversationID := extractMetadataFromMessages(messages, "conversationId")
continuationID := extractMetadataFromMessages(messages, "continuationId")
if conversationID == "" {
conversationID = uuid.New().String()
}
payload := KiroPayload{
ConversationState: KiroConversationState{
AgentTaskType: "vibe",
ChatTriggerType: "MANUAL",
ConversationID: uuid.New().String(),
ConversationID: conversationID,
CurrentMessage: currentMessage,
History: history,
},
@@ -308,6 +318,11 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s
InferenceConfig: inferenceConfig,
}
// Only set AgentContinuationID if client provided
if continuationID != "" {
payload.ConversationState.AgentContinuationID = continuationID
}
result, err := json.Marshal(payload)
if err != nil {
log.Debugf("kiro-openai: failed to marshal payload: %v", err)
@@ -333,6 +348,18 @@ func normalizeOrigin(origin string) string {
}
}
// extractMetadataFromMessages extracts metadata from messages[].additional_kwargs (LangChain format).
// Searches from the last message backwards, returns empty string if not found.
func extractMetadataFromMessages(messages gjson.Result, key string) string {
arr := messages.Array()
for i := len(arr) - 1; i >= 0; i-- {
if val := arr[i].Get("additional_kwargs." + key); val.Exists() && val.String() != "" {
return val.String()
}
}
return ""
}
// extractSystemPromptFromOpenAI extracts system prompt from OpenAI messages
func extractSystemPromptFromOpenAI(messages gjson.Result) string {
if !messages.IsArray() {
@@ -578,6 +605,7 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
// Truncate history if too long to prevent Kiro API errors
history = truncateHistoryIfNeeded(history)
history, currentToolResults = filterOrphanedToolResults(history, currentToolResults)
return history, currentUserMsg, currentToolResults
}
@@ -593,6 +621,61 @@ func truncateHistoryIfNeeded(history []KiroHistoryMessage) []KiroHistoryMessage
return history[len(history)-kiroMaxHistoryMessages:]
}
func filterOrphanedToolResults(history []KiroHistoryMessage, currentToolResults []KiroToolResult) ([]KiroHistoryMessage, []KiroToolResult) {
// Remove tool results with no matching tool_use in retained history.
// This happens after truncation when the assistant turn that produced tool_use
// is dropped but a later user/tool_result survives.
validToolUseIDs := make(map[string]bool)
for _, h := range history {
if h.AssistantResponseMessage == nil {
continue
}
for _, tu := range h.AssistantResponseMessage.ToolUses {
validToolUseIDs[tu.ToolUseID] = true
}
}
for i, h := range history {
if h.UserInputMessage == nil || h.UserInputMessage.UserInputMessageContext == nil {
continue
}
ctx := h.UserInputMessage.UserInputMessageContext
if len(ctx.ToolResults) == 0 {
continue
}
filtered := make([]KiroToolResult, 0, len(ctx.ToolResults))
for _, tr := range ctx.ToolResults {
if validToolUseIDs[tr.ToolUseID] {
filtered = append(filtered, tr)
continue
}
log.Debugf("kiro-openai: dropping orphaned tool_result in history[%d]: toolUseId=%s (no matching tool_use)", i, tr.ToolUseID)
}
ctx.ToolResults = filtered
if len(ctx.ToolResults) == 0 && len(ctx.Tools) == 0 {
h.UserInputMessage.UserInputMessageContext = nil
}
}
if len(currentToolResults) > 0 {
filtered := make([]KiroToolResult, 0, len(currentToolResults))
for _, tr := range currentToolResults {
if validToolUseIDs[tr.ToolUseID] {
filtered = append(filtered, tr)
continue
}
log.Debugf("kiro-openai: dropping orphaned tool_result in currentMessage: toolUseId=%s (no matching tool_use)", tr.ToolUseID)
}
if len(filtered) != len(currentToolResults) {
log.Infof("kiro-openai: dropped %d orphaned tool_result(s) from currentMessage", len(currentToolResults)-len(filtered))
}
currentToolResults = filtered
}
return history, currentToolResults
}
// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results
func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
content := msg.Get("content")

View File

@@ -384,3 +384,57 @@ func TestAssistantEndsConversation(t *testing.T) {
t.Error("Expected a 'Continue' message to be created when assistant is last")
}
}
func TestFilterOrphanedToolResults_RemovesHistoryAndCurrentOrphans(t *testing.T) {
history := []KiroHistoryMessage{
{
AssistantResponseMessage: &KiroAssistantResponseMessage{
Content: "assistant",
ToolUses: []KiroToolUse{
{ToolUseID: "keep-1", Name: "Read", Input: map[string]interface{}{}},
},
},
},
{
UserInputMessage: &KiroUserInputMessage{
Content: "user-with-mixed-results",
UserInputMessageContext: &KiroUserInputMessageContext{
ToolResults: []KiroToolResult{
{ToolUseID: "keep-1", Status: "success", Content: []KiroTextContent{{Text: "ok"}}},
{ToolUseID: "orphan-1", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
},
},
},
},
{
UserInputMessage: &KiroUserInputMessage{
Content: "user-only-orphans",
UserInputMessageContext: &KiroUserInputMessageContext{
ToolResults: []KiroToolResult{
{ToolUseID: "orphan-2", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
},
},
},
},
}
currentToolResults := []KiroToolResult{
{ToolUseID: "keep-1", Status: "success", Content: []KiroTextContent{{Text: "ok"}}},
{ToolUseID: "orphan-3", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
}
filteredHistory, filteredCurrent := filterOrphanedToolResults(history, currentToolResults)
ctx1 := filteredHistory[1].UserInputMessage.UserInputMessageContext
if ctx1 == nil || len(ctx1.ToolResults) != 1 || ctx1.ToolResults[0].ToolUseID != "keep-1" {
t.Fatalf("expected mixed history message to keep only keep-1, got: %+v", ctx1)
}
if filteredHistory[2].UserInputMessage.UserInputMessageContext != nil {
t.Fatalf("expected orphan-only history context to be removed")
}
if len(filteredCurrent) != 1 || filteredCurrent[0].ToolUseID != "keep-1" {
t.Fatalf("expected current tool results to keep only keep-1, got: %+v", filteredCurrent)
}
}

View File

@@ -75,8 +75,8 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
out, _ = sjson.Set(out, "reasoning_effort", effort)
}
}
case "adaptive":
// Claude adaptive means "enable with max capacity"; keep it as highest level
case "adaptive", "auto":
// Claude adaptive/auto means "enable with max capacity"; keep it as highest level
// and let ApplyThinking normalize per target model capability.
out, _ = sjson.Set(out, "reasoning_effort", string(thinking.LevelXHigh))
case "disabled":

View File

@@ -127,7 +127,8 @@ func (w *Watcher) reloadConfig() bool {
}
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelAlias, newConfig.OAuthModelAlias))
retryConfigChanged := oldConfig != nil && (oldConfig.RequestRetry != newConfig.RequestRetry || oldConfig.MaxRetryInterval != newConfig.MaxRetryInterval || oldConfig.MaxRetryCredentials != newConfig.MaxRetryCredentials)
forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelAlias, newConfig.OAuthModelAlias) || retryConfigChanged)
log.Infof("config successfully reloaded, triggering client reload")
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)

View File

@@ -54,6 +54,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.RequestRetry != newCfg.RequestRetry {
changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry))
}
if oldCfg.MaxRetryCredentials != newCfg.MaxRetryCredentials {
changes = append(changes, fmt.Sprintf("max-retry-credentials: %d -> %d", oldCfg.MaxRetryCredentials, newCfg.MaxRetryCredentials))
}
if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval {
changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval))
}

View File

@@ -223,6 +223,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
UsageStatisticsEnabled: false,
DisableCooling: false,
RequestRetry: 1,
MaxRetryCredentials: 1,
MaxRetryInterval: 1,
WebsocketAuth: false,
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
@@ -246,6 +247,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
UsageStatisticsEnabled: true,
DisableCooling: true,
RequestRetry: 2,
MaxRetryCredentials: 3,
MaxRetryInterval: 3,
WebsocketAuth: true,
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
@@ -283,6 +285,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
expectContains(t, details, "disable-cooling: false -> true")
expectContains(t, details, "request-log: false -> true")
expectContains(t, details, "request-retry: 1 -> 2")
expectContains(t, details, "max-retry-credentials: 1 -> 3")
expectContains(t, details, "max-retry-interval: 1 -> 3")
expectContains(t, details, "proxy-url: http://old-proxy -> http://new-proxy")
expectContains(t, details, "ws-auth: false -> true")
@@ -309,6 +312,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
UsageStatisticsEnabled: false,
DisableCooling: false,
RequestRetry: 1,
MaxRetryCredentials: 1,
MaxRetryInterval: 1,
WebsocketAuth: false,
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
@@ -361,6 +365,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
UsageStatisticsEnabled: true,
DisableCooling: true,
RequestRetry: 2,
MaxRetryCredentials: 3,
MaxRetryInterval: 3,
WebsocketAuth: true,
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
@@ -419,6 +424,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
expectContains(t, changes, "usage-statistics-enabled: false -> true")
expectContains(t, changes, "disable-cooling: false -> true")
expectContains(t, changes, "request-retry: 1 -> 2")
expectContains(t, changes, "max-retry-credentials: 1 -> 3")
expectContains(t, changes, "max-retry-interval: 1 -> 3")
expectContains(t, changes, "proxy-url: http://old-proxy -> http://new-proxy")
expectContains(t, changes, "ws-auth: false -> true")

View File

@@ -1239,6 +1239,67 @@ func TestReloadConfigFiltersAffectedOAuthProviders(t *testing.T) {
}
}
func TestReloadConfigTriggersCallbackForMaxRetryCredentialsChange(t *testing.T) {
tmpDir := t.TempDir()
authDir := filepath.Join(tmpDir, "auth")
if err := os.MkdirAll(authDir, 0o755); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
configPath := filepath.Join(tmpDir, "config.yaml")
oldCfg := &config.Config{
AuthDir: authDir,
MaxRetryCredentials: 0,
RequestRetry: 1,
MaxRetryInterval: 5,
}
newCfg := &config.Config{
AuthDir: authDir,
MaxRetryCredentials: 2,
RequestRetry: 1,
MaxRetryInterval: 5,
}
data, errMarshal := yaml.Marshal(newCfg)
if errMarshal != nil {
t.Fatalf("failed to marshal config: %v", errMarshal)
}
if errWrite := os.WriteFile(configPath, data, 0o644); errWrite != nil {
t.Fatalf("failed to write config: %v", errWrite)
}
callbackCalls := 0
callbackMaxRetryCredentials := -1
w := &Watcher{
configPath: configPath,
authDir: authDir,
lastAuthHashes: make(map[string]string),
reloadCallback: func(cfg *config.Config) {
callbackCalls++
if cfg != nil {
callbackMaxRetryCredentials = cfg.MaxRetryCredentials
}
},
}
w.SetConfig(oldCfg)
if ok := w.reloadConfig(); !ok {
t.Fatal("expected reloadConfig to succeed")
}
if callbackCalls != 1 {
t.Fatalf("expected reload callback to be called once, got %d", callbackCalls)
}
if callbackMaxRetryCredentials != 2 {
t.Fatalf("expected callback MaxRetryCredentials=2, got %d", callbackMaxRetryCredentials)
}
w.clientsMutex.RLock()
defer w.clientsMutex.RUnlock()
if w.config == nil || w.config.MaxRetryCredentials != 2 {
t.Fatalf("expected watcher config MaxRetryCredentials=2, got %+v", w.config)
}
}
func TestStartFailsWhenAuthDirMissing(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")

View File

@@ -717,6 +717,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return
}
if len(chunk.Payload) > 0 {
if handlerType == "openai-response" {
if err := validateSSEDataJSON(chunk.Payload); err != nil {
_ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
return
}
}
sentPayload = true
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
return
@@ -728,6 +734,35 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return dataChan, upstreamHeaders, errChan
}
func validateSSEDataJSON(chunk []byte) error {
for _, line := range bytes.Split(chunk, []byte("\n")) {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
data := bytes.TrimSpace(line[5:])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
continue
}
if json.Valid(data) {
continue
}
const max = 512
preview := data
if len(preview) > max {
preview = preview[:max]
}
return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview)
}
return nil
}
func statusFromError(err error) int {
if err == nil {
return 0

View File

@@ -134,6 +134,37 @@ type authAwareStreamExecutor struct {
authIDs []string
}
type invalidJSONStreamExecutor struct{}
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
ch := make(chan coreexecutor.StreamChunk, 1)
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")}
close(ch)
return &coreexecutor.StreamResult{Chunks: ch}, nil
}
func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
}
func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
return nil, &coreauth.Error{
Code: "not_implemented",
Message: "HttpRequest not implemented",
HTTPStatus: http.StatusNotImplemented,
}
}
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
@@ -524,3 +555,55 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
}
}
func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) {
executor := &invalidJSONStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth1 := &coreauth.Auth{
ID: "auth1",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test1@example.com"},
}
if _, err := manager.Register(context.Background(), auth1); err != nil {
t.Fatalf("manager.Register(auth1): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
if len(got) != 0 {
t.Fatalf("expected empty payload, got %q", string(got))
}
gotErr := false
for msg := range errChan {
if msg == nil {
continue
}
if msg.StatusCode != http.StatusBadGateway {
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode)
}
if msg.Error == nil {
t.Fatalf("expected error")
}
gotErr = true
}
if !gotErr {
t.Fatalf("expected terminal error")
}
}

View File

@@ -418,8 +418,8 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush
if errMsg.Error != nil && errMsg.Error.Error() != "" {
errText = errMsg.Error.Error()
}
body := handlers.BuildErrorResponseBody(status, errText)
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
chunk := handlers.BuildOpenAIResponsesStreamErrorChunk(status, errText, 0)
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk))
},
WriteDone: func() {
_, _ = c.Writer.Write([]byte("\n"))

View File

@@ -0,0 +1,43 @@
package openai
import (
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T) {
gin.SetMode(gin.TestMode)
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
h := NewOpenAIResponsesAPIHandler(base)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
t.Fatalf("expected gin writer to implement http.Flusher")
}
data := make(chan []byte)
errs := make(chan *interfaces.ErrorMessage, 1)
errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
body := recorder.Body.String()
if !strings.Contains(body, `"type":"error"`) {
t.Fatalf("expected responses error chunk, got: %q", body)
}
if strings.Contains(body, `"error":{`) {
t.Fatalf("expected streaming error chunk (top-level type), got HTTP error body: %q", body)
}
}

View File

@@ -0,0 +1,119 @@
package handlers
import (
"encoding/json"
"fmt"
"net/http"
"strings"
)
type openAIResponsesStreamErrorChunk struct {
Type string `json:"type"`
Code string `json:"code"`
Message string `json:"message"`
SequenceNumber int `json:"sequence_number"`
}
func openAIResponsesStreamErrorCode(status int) string {
switch status {
case http.StatusUnauthorized:
return "invalid_api_key"
case http.StatusForbidden:
return "insufficient_quota"
case http.StatusTooManyRequests:
return "rate_limit_exceeded"
case http.StatusNotFound:
return "model_not_found"
case http.StatusRequestTimeout:
return "request_timeout"
default:
if status >= http.StatusInternalServerError {
return "internal_server_error"
}
if status >= http.StatusBadRequest {
return "invalid_request_error"
}
return "unknown_error"
}
}
// BuildOpenAIResponsesStreamErrorChunk builds an OpenAI Responses streaming error chunk.
//
// Important: OpenAI's HTTP error bodies are shaped like {"error":{...}}; those are valid for
// non-streaming responses, but streaming clients validate SSE `data:` payloads against a union
// of chunks that requires a top-level `type` field.
func BuildOpenAIResponsesStreamErrorChunk(status int, errText string, sequenceNumber int) []byte {
if status <= 0 {
status = http.StatusInternalServerError
}
if sequenceNumber < 0 {
sequenceNumber = 0
}
message := strings.TrimSpace(errText)
if message == "" {
message = http.StatusText(status)
}
code := openAIResponsesStreamErrorCode(status)
trimmed := strings.TrimSpace(errText)
if trimmed != "" && json.Valid([]byte(trimmed)) {
var payload map[string]any
if err := json.Unmarshal([]byte(trimmed), &payload); err == nil {
if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) == "error" {
if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" {
message = strings.TrimSpace(m)
}
if v, ok := payload["code"]; ok && v != nil {
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
code = strings.TrimSpace(c)
} else {
code = strings.TrimSpace(fmt.Sprint(v))
}
}
if v, ok := payload["sequence_number"].(float64); ok && sequenceNumber == 0 {
sequenceNumber = int(v)
}
}
if e, ok := payload["error"].(map[string]any); ok {
if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" {
message = strings.TrimSpace(m)
}
if v, ok := e["code"]; ok && v != nil {
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
code = strings.TrimSpace(c)
} else {
code = strings.TrimSpace(fmt.Sprint(v))
}
}
}
}
}
if strings.TrimSpace(code) == "" {
code = "unknown_error"
}
data, err := json.Marshal(openAIResponsesStreamErrorChunk{
Type: "error",
Code: code,
Message: message,
SequenceNumber: sequenceNumber,
})
if err == nil {
return data
}
// Extremely defensive fallback.
data, _ = json.Marshal(openAIResponsesStreamErrorChunk{
Type: "error",
Code: "internal_server_error",
Message: message,
SequenceNumber: sequenceNumber,
})
if len(data) > 0 {
return data
}
return []byte(`{"type":"error","code":"internal_server_error","message":"internal error","sequence_number":0}`)
}

View File

@@ -0,0 +1,48 @@
package handlers
import (
"encoding/json"
"net/http"
"testing"
)
func TestBuildOpenAIResponsesStreamErrorChunk(t *testing.T) {
chunk := BuildOpenAIResponsesStreamErrorChunk(http.StatusInternalServerError, "unexpected EOF", 0)
var payload map[string]any
if err := json.Unmarshal(chunk, &payload); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if payload["type"] != "error" {
t.Fatalf("type = %v, want %q", payload["type"], "error")
}
if payload["code"] != "internal_server_error" {
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
}
if payload["message"] != "unexpected EOF" {
t.Fatalf("message = %v, want %q", payload["message"], "unexpected EOF")
}
if payload["sequence_number"] != float64(0) {
t.Fatalf("sequence_number = %v, want %v", payload["sequence_number"], 0)
}
}
func TestBuildOpenAIResponsesStreamErrorChunkExtractsHTTPErrorBody(t *testing.T) {
chunk := BuildOpenAIResponsesStreamErrorChunk(
http.StatusInternalServerError,
`{"error":{"message":"oops","type":"server_error","code":"internal_server_error"}}`,
0,
)
var payload map[string]any
if err := json.Unmarshal(chunk, &payload); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if payload["type"] != "error" {
t.Fatalf("type = %v, want %q", payload["type"], "error")
}
if payload["code"] != "internal_server_error" {
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
}
if payload["message"] != "oops" {
t.Fatalf("message = %v, want %q", payload["message"], "oops")
}
}

Some files were not shown because too many files have changed in this diff Show More