Compare commits

...

651 Commits

Author SHA1 Message Date
Luis Pater
b43610159f Merge branch 'router-for-me:main' into main 2026-02-01 05:30:36 +08:00
Luis Pater
6d8609e457 feat(config): add payload filter rules to remove JSON paths
Introduce `Filter` rules in the payload configuration to remove specified JSON paths from the payload. Update related helper functions and add examples to `config.example.yaml`.
2026-02-01 05:29:41 +08:00
Luis Pater
dcd0ae7467 Merge branch 'router-for-me:main' into main 2026-01-31 23:49:45 +08:00
Luis Pater
d216adeffc Fixed: #1372 #1366
fix(caching): ensure unique cache_control injection using count validation
2026-01-31 23:48:50 +08:00
hkfires
bb09708c02 fix(config): add codex instructions enabled change to config change details 2026-01-31 22:44:25 +08:00
hkfires
1150d972a1 fix(misc): update opencode instructions 2026-01-31 22:28:30 +08:00
hkfires
2854e04bbb fix(misc): update user agent string for opencode 2026-01-31 11:23:08 +08:00
Luis Pater
f3fd7a9fbd Merge branch 'router-for-me:main' into main 2026-01-31 04:04:52 +08:00
Luis Pater
f99cddf97f fix(translator): handle stop_reason and MAX_TOKENS for Claude responses 2026-01-31 04:03:01 +08:00
Luis Pater
0606a7762c Merge branch 'router-for-me:main' into main 2026-01-31 03:14:11 +08:00
Luis Pater
f887f9985d Merge pull request #1248 from shekohex/feat/responses-compact
feat(openai): add responses/compact support
2026-01-31 03:12:55 +08:00
Luis Pater
550da0cee8 fix(translator): include token usage in message_delta for Claude responses 2026-01-31 02:55:27 +08:00
Luis Pater
e662c020a9 Merge branch 'router-for-me:main' into main 2026-01-31 01:43:48 +08:00
Luis Pater
7ff3936efe fix(caching): ensure prompt-caching beta is always appended and add multi-turn cache control tests 2026-01-31 01:42:58 +08:00
Luis Pater
29594086c0 chore(docs): add links to mainline repository in README files 2026-01-31 01:24:29 +08:00
Luis Pater
b0433c9f2a chore(docs): update image source and config URLs in README files 2026-01-31 01:22:28 +08:00
Luis Pater
b1204b1423 Merge branch 'router-for-me:main' into main 2026-01-31 01:15:14 +08:00
Luis Pater
43ca112fff Merge pull request #157 from crossly/bugfix/kiro-token-extraction-from-metadata
fix(kiro): Support token extraction from Metadata for file-based authentication
2026-01-31 01:14:28 +08:00
Luis Pater
24cf7fa6a2 Merge pull request #156 from taetaetae/fix/kiro-api-region
fix(kiro): Do not use OIDC region for API endpoint
2026-01-31 01:13:47 +08:00
Luis Pater
bf66bcad86 Merge pull request #155 from PancakeZik/feature/use-q-endpoint
feat(kiro): switch to Amazon Q endpoint as primary
2026-01-31 01:13:15 +08:00
Luis Pater
f36a5f5654 Merge pull request #1294 from Darley-Wey/fix/claude2gemini
fix: skip empty text parts and messages to avoid Gemini API error
2026-01-31 01:05:41 +08:00
Luis Pater
c1facdff67 Merge pull request #1295 from SchneeMart/feature/claude-caching
feat(caching): implement Claude prompt caching with multi-turn support
2026-01-31 01:04:19 +08:00
ricky
0263f9d35b Restore README files 2026-01-31 00:21:17 +08:00
ricky
101498e737 Fix: Support token extraction from Metadata for file-based Kiro auth
- Modified extractKiroTokenData to support both Attributes and Metadata sources
- Fixes issue where JSON file-based tokens were not being read correctly
- FileSynthesizer stores tokens in Metadata, ConfigSynthesizer uses Attributes
- Now checks Attributes first (config.yaml), falls back to Metadata (JSON files)
- Ensures dynamic model fetching works for all Kiro authentication methods
- Prevents fallback to static model list that incorrectly includes opus for free accounts
2026-01-31 00:15:35 +08:00
Luis Pater
4ee46bc9f2 Merge pull request #1311 from router-for-me/fix/gemini-schema
fix(gemini): Removes unsupported extension fields
2026-01-30 23:55:56 +08:00
Luis Pater
c3e94a8277 Merge pull request #1317 from yinkev/feat/gemini-tools-passthrough
feat(translator): add code_execution and url_context tool passthrough
2026-01-30 23:46:44 +08:00
taetaetae
fafef32b9e fix(kiro): Do not use OIDC region for API endpoint
Kiro API endpoints only exist in us-east-1, but OIDC region can vary
by Enterprise user location (e.g., ap-northeast-2 for Korean users).

Previously, when ProfileARN was not available, the code fell back to
using OIDC region for API calls, causing DNS resolution failures:

  lookup codewhisperer.ap-northeast-2.amazonaws.com: no such host

This fix removes the OIDC region fallback for API endpoints.
The region priority is now:
1. api_region (explicit override)
2. ProfileARN region
3. us-east-1 (default)

Fixes: Issue #253 (200-400x slower response times due to DNS failures)
2026-01-31 00:05:53 +09:00
Joao
1e764de0a8 feat(kiro): switch to Amazon Q endpoint as primary
Switch from CodeWhisperer endpoint to Amazon Q endpoint for all auth types:

- Use q.{region}.amazonaws.com/generateAssistantResponse as primary endpoint
- Works universally across all AWS regions (CodeWhisperer only exists in us-east-1)
- Use application/json Content-Type instead of application/x-amz-json-1.0
- Remove X-Amz-Target header for Q endpoint (not required)
- Add x-amzn-kiro-agent-mode: vibe header
- Add x-amzn-codewhisperer-optout: true header
- Keep CodeWhisperer endpoint as fallback for compatibility

This change aligns with Amazon's consolidation of services under the Q branding
and provides better multi-region support for Enterprise/IDC users.
2026-01-30 13:50:19 +00:00
Luis Pater
b3b8d71dfc Merge pull request #154 from router-for-me/plus
v6.7.32
2026-01-30 21:34:38 +08:00
Luis Pater
ca29c42805 Merge branch 'main' into plus 2026-01-30 21:34:30 +08:00
Luis Pater
fcefa2c820 Merge pull request #152 from taetaetae/feat/kiro-dynamic-region-support
feat(kiro): Add dynamic region support for API endpoints
2026-01-30 21:30:04 +08:00
Luis Pater
6b6d030ed3 feat(auth): add custom HTTP client with utls for Claude API authentication
Introduce a custom HTTP client utilizing utls with Firefox TLS fingerprinting to bypass Cloudflare fingerprinting on Anthropic domains. Includes support for proxy configuration and enhanced connection management for HTTP/2.
2026-01-30 21:29:41 +08:00
Luis Pater
fd5b669c87 Merge pull request #150 from PancakeZik/fix/write-tool-truncation-handling
fix: handle Write tool truncation when content exceeds API limits
2026-01-30 21:15:31 +08:00
Luis Pater
30d832c9b1 Merge pull request #144 from woopencri/main
fix: handle zero output_tokens for kiro non-streaming requests
2026-01-30 21:06:20 +08:00
Luis Pater
2448691136 Merge pull request #143 from CheesesNguyen/fix/kiro-refresh-token
fix: refresh token for kiro enterprise account
2026-01-30 21:05:00 +08:00
taetaetae
e7cd7b5243 fix: Support separate OIDC and API regions via ProfileARN extraction
Address @Xm798's feedback: OIDC region may differ from API region in some
Enterprise setups (e.g., OIDC in us-east-2, API in us-east-1).

Region priority (highest to lowest):
1. api_region - explicit override for API endpoint region
2. ProfileARN - extract region from arn:aws:service:REGION:account:resource
3. region - OIDC/Identity region (fallback)
4. us-east-1 - default

Changes:
- Add extractRegionFromProfileARN() to parse region from ARN
- Update getKiroEndpointConfigs() with 4-level region priority
- Add regionSource logging for debugging
2026-01-30 21:52:02 +09:00
Luis Pater
33f89a2609 Merge pull request #140 from janckerchen/fix/github-copilot-logging
fix: support github-copilot provider in AccountInfo logging
2026-01-30 20:51:50 +08:00
Luis Pater
403a731e22 Merge pull request #139 from janckerchen/fix/github-copilot-vision-header
fix: add Copilot-Vision-Request header for vision content
2026-01-30 20:51:18 +08:00
Luis Pater
3631fab7e2 Merge pull request #153 from router-for-me/plus
v6.7.31
2026-01-30 20:46:42 +08:00
Luis Pater
b3d292a5f9 Merge branch 'main' into plus 2026-01-30 20:45:33 +08:00
taetaetae
9293c685e0 fix: Correct Amazon Q endpoint URL path
Revert the Amazon Q endpoint path to root '/' instead of '/generateAssistantResponse'.

The '/generateAssistantResponse' path is only for CodeWhisperer endpoint with
'GenerateAssistantResponse' target. Amazon Q endpoint uses 'SendMessage' target
which requires the root path.

Thanks to @gemini-code-assist for catching this copy-paste error.
2026-01-30 16:30:03 +09:00
taetaetae
38094a2339 feat(kiro): Add dynamic region support for API endpoints
## Problem
- Kiro API endpoints were hardcoded to us-east-1 region
- Enterprise users in other regions (e.g., ap-northeast-2) experienced
  significant latency (200-400x slower) due to cross-region API calls
- This is the API endpoint counterpart to quotio PR #241 which fixed
  token refresh endpoints

## Solution
- Add buildKiroEndpointConfigs(region) function for dynamic endpoint generation
- Extract region from auth.Metadata["region"] field
- Fallback to us-east-1 for backward compatibility
- Use case-insensitive authMethod comparison (consistent with quotio PR #252)

## Changes
- Add kiroDefaultRegion constant
- Convert hardcoded endpoint URLs to dynamic fmt.Sprintf with region
- Update getKiroEndpointConfigs to extract and use region from auth
- Fix isIDCAuth to use case-insensitive comparison

## Testing
- Backward compatible: defaults to us-east-1 when no region specified
- Enterprise users can now use their local region endpoints

Related:
- quotio PR #241: Dynamic region for token refresh (merged)
- quotio PR #252: authMethod case-insensitive fix
- quotio Issue #253: Performance issue report
2026-01-30 16:25:32 +09:00
kyinhub
538039f583 feat(translator): add code_execution and url_context tool passthrough
Add support for Gemini's code_execution and url_context tools in the
request translators, enabling:

- Agentic Vision: Image analysis with Python code execution for
  bounding boxes, annotations, and visual reasoning
- URL Context: Live web page content fetching and analysis

Tools are passed through using the same pattern as google_search:
- code_execution: {} -> codeExecution: {}
- url_context: {} -> urlContext: {}

Tested with Gemini 3 Flash Preview agentic vision successfully.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-29 21:14:52 -08:00
이대희
ca796510e9 refactor(gemini): optimize removeExtensionFields with post-order traversal and DeleteBytes
Amp-Thread-ID: https://ampcode.com/threads/T-019c0d09-330d-7399-b794-652b94847df1
Co-authored-by: Amp <amp@ampcode.com>
2026-01-30 13:02:58 +09:00
이대희
d0d66cdcb7 fix(gemini): Removes unsupported extension fields
Removes x-* extension fields from JSON schemas to ensure compatibility with the Gemini API.

These fields, while valid in OpenAPI/JSON Schema, are not recognized by the Gemini API and can cause issues.
The change recursively walks the schema, identifies these extension fields, and removes them, except when they define properties.

Amp-Thread-ID: https://ampcode.com/threads/T-019c0cd1-9e59-722b-83f0-e0582aba6914
Co-authored-by: Amp <amp@ampcode.com>
2026-01-30 12:31:26 +09:00
Luis Pater
d7d54fa2cc feat(ci): add cleanup step for temporary Docker tags in workflow 2026-01-30 09:15:00 +08:00
Luis Pater
31649325f0 feat(ci): add multi-arch Docker builds and manifest creation to workflow 2026-01-30 07:26:36 +08:00
Martin Schneeweiss
3a43ecb19b feat(caching): implement Claude prompt caching with multi-turn support
- Add ensureCacheControl() to auto-inject cache breakpoints
- Cache tools (last tool), system (last element), and messages (2nd-to-last user turn)
- Add prompt-caching-2024-07-31 beta header
- Return original payload on sjson error to prevent corruption
- Include verification test for caching logic

Enables up to 90% cost reduction on cached tokens.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-29 22:59:33 +01:00
Luis Pater
a709e5a12d fix(config): ensure empty mapping persists for oauth-model-alias deletions #1305 2026-01-30 04:17:56 +08:00
Luis Pater
f0ac77197b Merge pull request #1300 from sususu98/feat/log-api-response-timestamp
fix(logging): add API response timestamp and fix request timestamp timing
2026-01-30 03:27:17 +08:00
Luis Pater
da0bbf2a3f Merge pull request #1298 from sususu98/fix/restore-usageMetadata-in-gemini-translator
fix(translator): restore usageMetadata in Gemini responses from Antigravity
2026-01-30 02:59:41 +08:00
sususu98
295f34d7f0 fix(logging): capture streaming TTFB on first chunk and make timestamps required
- Add firstChunkTimestamp field to ResponseWriterWrapper for sync capture
- Capture TTFB in Write() and WriteString() before async channel send
- Add SetFirstChunkTimestamp() to StreamingLogWriter interface
- Make requestTimestamp/apiResponseTimestamp required in LogRequest()
- Remove timestamp capture from WriteAPIResponse() (now via setter)
- Fix Gemini handler to set API_RESPONSE_TIMESTAMP before writing response

This ensures accurate TTFB measurement for all streaming API formats
(OpenAI, Gemini, Claude) by capturing timestamp synchronously when
the first response chunk arrives, not when the stream finalizes.
2026-01-29 22:32:24 +08:00
sususu98
c41ce77eea fix(logging): add API response timestamp and fix request timestamp timing
Previously:
- REQUEST INFO timestamp was captured at log write time (not request arrival)
- API RESPONSE had NO timestamp at all

This fix:
- Captures REQUEST INFO timestamp when request first arrives
- Adds API RESPONSE timestamp when upstream response arrives

Changes:
- Add Timestamp field to RequestInfo, set at middleware initialization
- Set API_RESPONSE_TIMESTAMP in appendAPIResponse() and gemini handler
- Pass timestamps through logging chain to writeNonStreamingLog()
- Add timestamp output to API RESPONSE section

This enables accurate measurement of backend response latency in error logs.
2026-01-29 22:22:18 +08:00
Joao
876b86ff91 fix: handle json.Marshal error for truncated write bash input 2026-01-29 13:07:20 +00:00
Joao
acdfa1c87f fix: handle Write tool truncation when content exceeds API limits
When the Kiro/AWS CodeWhisperer API receives a Write tool request with content
that exceeds transmission limits, it truncates the tool input. This can result in:
- Empty input buffer (no input transmitted at all)
- Missing 'content' field in the parsed JSON
- Incomplete JSON that fails to parse

This fix detects these truncation scenarios and converts them to Bash tool calls
that echo an error message. This allows Claude Code to execute the Bash command,
see the error output, and the agent can then retry with smaller chunks.

Changes:
- kiro_claude_tools.go: Detect three truncation scenarios in ProcessToolUseEvent:
  1. Empty input buffer (no input transmitted)
  2. JSON parse failure with file_path but no content field
  3. Successfully parsed JSON missing content field
  When detected, emit a special '__truncated_write__' marker tool use

- kiro_executor.go: Handle '__truncated_write__' markers in streamToChannel:
  1. Extract file_path from the marker for context
  2. Create a Bash tool_use that echoes an error message
  3. Include retry guidance (700-line chunks recommended)
  4. Set hasToolUses=true to ensure stop_reason='tool_use' for agent continuation

This ensures the agent continues and can retry with smaller file chunks instead
of failing silently or showing errors to the user.
2026-01-29 12:22:55 +00:00
Luis Pater
4eb1e6093f feat(handlers): add test to verify no retries after partial stream response
Introduce `TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte` to validate that stream executions do not retry after receiving partial responses. Implement `payloadThenErrorStreamExecutor` for test coverage of this behavior.
2026-01-29 17:30:48 +08:00
Luis Pater
189a066807 Merge pull request #1296 from router-for-me/log
fix(api): update amp module only on config changes
2026-01-29 17:27:52 +08:00
hkfires
d0bada7a43 fix(config): prune oauth-model-alias when preserving config 2026-01-29 14:06:52 +08:00
sususu98
9dc0e6d08b fix(translator): restore usageMetadata in Gemini responses from Antigravity
When using Gemini API format with Antigravity backend, the executor
renames usageMetadata to cpaUsageMetadata in non-terminal chunks.
The Gemini translator was returning this internal field name directly
to clients instead of the standard usageMetadata field.

Add restoreUsageMetadata() to rename cpaUsageMetadata back to
usageMetadata before returning responses to clients.
2026-01-29 11:16:00 +08:00
hkfires
8510fc313e fix(api): update amp module only on config changes 2026-01-29 09:28:49 +08:00
Darley
2666708c30 fix: skip empty text parts and messages to avoid Gemini API error
When Claude API sends an assistant message with empty text content like:
{"role":"assistant","content":[{"type":"text","text":""}]}
The translator was creating a part object {} with no data field,
causing Gemini API to return error:
"required oneof field 'data' must have one initialized field"
This fix:
1. Skips empty text parts (text="") during translation
2. Skips entire messages when their parts array becomes empty
This ensures compatibility when clients send empty assistant messages
in their conversation history.
2026-01-29 04:13:07 +08:00
woopencri
f2b0ce13d9 fix: handle zero output_tokens for kiro non-streaming requests 2026-01-28 16:27:34 +08:00
CheesesNguyen
b8652b7387 feat: normalize authentication method to lowercase for case-insensitive matching during token refresh and introduce new CLIProxyAPIPlus component. 2026-01-28 14:54:58 +07:00
CheesesNguyen
b18b2ebe9f fix: Implement graceful token refresh degradation and enhance IDC SSO support with device registration loading for Kiro. 2026-01-28 14:47:04 +07:00
Luis Pater
9e5b1d24e8 Merge pull request #1276 from router-for-me/thinking
feat(thinking): enable thinking toggle for qwen3 and deepseek models
2026-01-28 11:16:54 +08:00
Luis Pater
a7dae6ad52 Merge remote-tracking branch 'origin/dev' into dev 2026-01-28 10:59:00 +08:00
Luis Pater
e93e05ae25 refactor: consolidate channel send logic with context-safe handlers
Optimize channel operations by introducing reusable context-aware send functions (`send` and `sendErr`) across `wsrelay`, `handlers`, and `cliproxy`. Ensure graceful handling of canceled contexts during stream operations.
2026-01-28 10:58:35 +08:00
hkfires
c8c27325dc feat(thinking): enable thinking toggle for qwen3 and deepseek models
Fix #1245
2026-01-28 09:54:05 +08:00
hkfires
c3b6f3918c chore(git): stop ignoring .idea and data directories 2026-01-28 09:52:44 +08:00
Luis Pater
bbb55a8ab4 Merge pull request #1170 from BianBianY/main
feat: optimization enable/disable auth files
2026-01-28 09:34:35 +08:00
Shady Khalifa
04b2290927 fix(codex): avoid empty prompt_cache_key 2026-01-27 19:06:42 +02:00
Shady Khalifa
53920b0399 fix(openai): drop stream for responses/compact 2026-01-27 18:27:34 +02:00
cybit
58290760a9 fix: support github-copilot provider in AccountInfo logging
Changed the provider matching logic in AccountInfo() method to use
prefix matching instead of exact matching. This allows both 'github'
(Kiro OAuth) and 'github-copilot' providers to be correctly identified
as OAuth providers, enabling proper debug logging output.

Before: Use OAuth logs were missing for github-copilot requests
After: Logs show "Use OAuth provider=github-copilot auth_file=..."

Co-Authored-By: Claude (claude-sonnet-4.5) <noreply@anthropic.com>
2026-01-27 21:56:00 +08:00
Luis Pater
8f522eed43 Merge pull request #138 from router-for-me/plus
v6.7.26
2026-01-27 20:40:12 +08:00
Luis Pater
3dc001a9d2 Merge branch 'main' into plus 2026-01-27 20:39:59 +08:00
Luis Pater
ee54ee8825 Merge pull request #137 from geen02/fix/idc-auth-method-case-sensitivity
fix: case-insensitive auth_method comparison for IDC tokens
2026-01-27 20:38:03 +08:00
Luis Pater
2395b7a180 Merge pull request #135 from gogoing1024/main
支持多个idc登录凭证保存
2026-01-27 20:36:56 +08:00
Luis Pater
7583193c2a Merge pull request #1257 from router-for-me/model
feat(api): add management model definitions endpoint
2026-01-27 20:32:04 +08:00
hkfires
7cc3bd4ba0 chore(deps): mark golang.org/x/text as indirect 2026-01-27 19:19:52 +08:00
hkfires
88a0f095e8 chore(registry): disable gemini 2.5 flash image preview model 2026-01-27 18:33:13 +08:00
hkfires
c65f64dce0 chore(registry): comment out rev19-uic3-1p model config 2026-01-27 18:33:13 +08:00
hkfires
d18cd217e1 feat(api): add management model definitions endpoint 2026-01-27 18:33:12 +08:00
Luis Pater
ba4a1ab433 Merge pull request #1261 from Darley-Wey/fix/gemini_scheme
fix(gemini): force type to string for enum fields to fix Antigravity Gemini API error
2026-01-27 17:02:25 +08:00
Darley
decddb521e fix(gemini): force type to string for enum fields to fix Antigravity Gemini API error (Relates to #1260) 2026-01-27 11:14:08 +03:30
cybit
33ab3a99f0 fix: add Copilot-Vision-Request header for vision requests
**Problem:**
GitHub Copilot API returns 400 error "missing required Copilot-Vision-Request
header for vision requests" when requests contain image content blocks, even
though the requests are valid Claude API calls.

**Root Cause:**
The GitHub Copilot executor was not detecting vision content in requests and
did not add the required `Copilot-Vision-Request: true` header.

**Solution:**
- Added `detectVisionContent()` function to check for image_url/image content blocks
- Automatically add `Copilot-Vision-Request: true` header when vision content is detected
- Applied fix to both `Execute()` and `ExecuteStream()` methods

**Testing:**
- Tested with Claude Code IDE requests containing code context screenshots
- Vision requests now succeed instead of failing with 400 errors
- Non-vision requests remain unchanged

Fixes issue where GitHub Copilot executor fails all vision-enabled requests,
causing unnecessary fallback to other providers and 0% utilization.

Co-Authored-By: Claude (claude-sonnet-4.5) <noreply@anthropic.com>
2026-01-27 15:13:54 +08:00
jyy
de6b1ada5d fix: case-insensitive auth_method comparison for IDC tokens
The background refresher was skipping token files with auth_method values
like 'IdC' or 'IDC' because the comparison was case-sensitive and only
matched lowercase 'idc'.

This fix normalizes the auth_method to lowercase before comparison in:
- token_repository.go: readTokenFile() when filtering tokens to refresh
- background_refresh.go: refreshSingle() when selecting refresh method

Fixes the issue where 'IdC' != 'idc' caused tokens to be skipped entirely.
2026-01-27 13:39:38 +09:00
gogoing1024
e08f48c7a1 Merge branch 'router-for-me:main' into main 2026-01-27 09:23:36 +08:00
Luis Pater
851712a49e Merge pull request #132 from ClubWeGo/codex/resolve-issue-#131
Resolve Issue #131
2026-01-26 23:36:16 +08:00
Luis Pater
9e34323a40 Merge branch 'router-for-me:main' into main 2026-01-26 23:35:07 +08:00
Shady Khalifa
95096bc3fc feat(openai): add responses/compact support 2026-01-26 16:36:01 +02:00
Luis Pater
70897247b2 feat(auth): add support for request_retry and disable_cooling overrides
Implement `request_retry` and `disable_cooling` metadata overrides for authentication management. Update retry and cooling logic accordingly across `Manager`, Antigravity executor, and file synthesizer. Add tests to validate new behaviors.
2026-01-26 21:59:08 +08:00
Luis Pater
9c341f5aa5 feat(auth): add skip persistence context key for file watcher events
Introduce `WithSkipPersist` to disable persistence during Manager Update/Register calls, preventing write-back loops caused by redundant file writes. Add corresponding tests and integrate with existing file store and conductor logic.
2026-01-26 18:20:19 +08:00
yuechenglong.5
f74a688fb9 refactor(auth): extract token filename generation into unified function
Add ExtractIDCIdentifier and GenerateTokenFileName functions to centralize
token filename generation logic. This improves code maintainability by:

- Extracting IDC identifier from startUrl for unique token file naming
- Supporting priority-based filename generation (email > startUrl > authMethod)
- Removing duplicate filename generation code from oauth_web.go
- Adding comprehensive unit tests for the new functions
2026-01-26 13:54:32 +08:00
Darley
e3e741d0be Default Claude tool input schema 2026-01-26 09:15:38 +08:00
Darley
7c7c5fd967 Fix Kiro tool schema defaults 2026-01-26 08:27:53 +08:00
Luis Pater
fe8c7a62aa Merge branch 'router-for-me:main' into main 2026-01-26 06:23:41 +08:00
Luis Pater
2af4a8dc12 refactor(runtime): implement retry logic for Antigravity executor with improved error handling and capacity management 2026-01-26 06:22:46 +08:00
Luis Pater
0f53b952b2 Merge pull request #1225 from router-for-me/log
Add request_id to error logs and extract error messages
2026-01-25 22:08:46 +08:00
Luis Pater
7b2ae7377a chore(auth): add net/url import to auth_files.go for URL handling 2026-01-25 21:53:20 +08:00
Luis Pater
c2ab288c7d Merge pull request #130 from router-for-me/plus
v6.7.22
2026-01-25 21:51:20 +08:00
Luis Pater
dbb433fcf8 Merge branch 'main' into plus 2026-01-25 21:51:02 +08:00
Luis Pater
2abf00b5a6 Merge pull request #126 from jellyfish-p/main
feat(kiro): 添加用于令牌额度查询的api-call兼容
2026-01-25 21:49:07 +08:00
Luis Pater
275839e5c9 Merge pull request #124 from gogoing1024/main
fix(kiro): always attempt token refresh on 401 before checking retry …
2026-01-25 21:48:03 +08:00
hkfires
f30ffd5f5e feat(executor): add request_id to error logs
Extract error.message from JSON error responses when summarizing error bodies for debug logs
2026-01-25 21:31:46 +08:00
Luis Pater
bc9a24d705 docs(readme): reposition CPA-XXX Panel section for improved visibility 2026-01-25 18:58:32 +08:00
Luis Pater
2c879f13ef Merge pull request #1216 from ferretgeek/add-cpa-xxx-panel
docs: 新增 CPA-XXX 社区面板项目
2026-01-25 18:57:32 +08:00
Gemini
07b4a08979 docs: translate CPA-XXX description to English 2026-01-25 18:00:28 +08:00
jellyfish-p
497339f055 feat(kiro): 添加用于令牌额度查询的api-call兼容 2026-01-25 11:36:52 +08:00
Gemini
7f612bb069 docs: add CPA-XXX panel to community list
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2026-01-25 10:45:51 +08:00
hkfires
5743b78694 test(claude): update expectations for system message handling 2026-01-25 08:31:29 +08:00
Luis Pater
2e6a2b655c Merge pull request #1132 from XYenon/fix/gemini-models-displayname-override
fix(gemini): preserve displayName and description in models list
2026-01-25 03:40:04 +08:00
Luis Pater
cb47ac21bf Merge pull request #1179 from mallendeo/main
fix(claude): skip built-in tools in OAuth tool prefix
2026-01-25 03:31:58 +08:00
Luis Pater
a1394b4596 Merge pull request #1183 from Darley-Wey/fix/api-align
fix(api): enhance ClaudeModels response to align with api.anthropic.com
2026-01-25 03:30:14 +08:00
Luis Pater
9e97948f03 Merge pull request #1185 from router-for-me/auth
Refactor authentication handling for Antigravity, Claude, Codex, and Gemini
2026-01-25 03:28:53 +08:00
yuechenglong.5
8f780e7280 fix(kiro): always attempt token refresh on 401 before checking retry count
Refactor 401 error handling in both executeWithRetry and
executeStreamWithRetry to always attempt token refresh regardless of
remaining retry attempts. Previously, token refresh was only attempted
when retries remained, which could leave valid refreshed tokens unused.

Also add auth directory resolution in RefreshManager.Initialize to
properly resolve the base directory path before creating the token
repository.
2026-01-24 20:02:09 +08:00
Yang Bian
f7bfa8a05c Merge branch 'upstream-main' 2026-01-24 16:28:08 +08:00
Darley
46c6fb1e7a fix(api): enhance ClaudeModels response to align with api.anthropic.com 2026-01-24 04:41:08 +03:30
hkfires
9f9fec5d4c fix(auth): improve antigravity token exchange errors 2026-01-24 09:04:15 +08:00
hkfires
e95be10485 fix(auth): validate antigravity token userinfo email 2026-01-24 08:33:52 +08:00
hkfires
f3d58fa0ce fix(auth): correct antigravity oauth redirect and expiry 2026-01-24 08:33:52 +08:00
hkfires
8c0eaa1f71 refactor(auth): export Gemini constants and use in handler 2026-01-24 08:33:52 +08:00
hkfires
405df58f72 refactor(auth): export Codex constants and slim down handler 2026-01-24 08:33:52 +08:00
hkfires
e7f13aa008 refactor(api): slim down RequestAnthropicToken to use internal/auth 2026-01-24 08:33:51 +08:00
hkfires
7cb6a9b89a refactor(auth): export Claude OAuth constants for reuse 2026-01-24 08:33:51 +08:00
hkfires
9aa5344c29 refactor(api): slim down RequestAntigravityToken to use internal/auth 2026-01-24 08:33:51 +08:00
hkfires
8ba0ebbd2a refactor(sdk): slim down Antigravity authenticator to use internal/auth 2026-01-24 08:33:51 +08:00
hkfires
c65407ab9f refactor(auth): extract Antigravity OAuth constants to internal/auth 2026-01-24 08:33:51 +08:00
hkfires
9e59685212 refactor(auth): implement Antigravity AuthService in internal/auth 2026-01-24 08:33:51 +08:00
hkfires
4a4dfaa910 refactor(auth): replace sanitizeAntigravityFileName with antigravity.CredentialFileName 2026-01-24 08:33:51 +08:00
Luis Pater
0d6ecb0191 Fixed: #1077
refactor(translator): improve tools handling by separating functionDeclarations and googleSearch nodes
2026-01-24 05:51:11 +08:00
Mauricio Allende
f16461bfe7 fix(claude): skip built-in tools in OAuth tool prefix 2026-01-23 21:29:39 +00:00
Luis Pater
9fccc86b71 fix(executor): include requested model in payload configuration 2026-01-24 05:06:02 +08:00
Luis Pater
74683560a7 chore(deps): update go.mod to add golang.org/x/sync and golang.org/x/text 2026-01-24 05:04:09 +08:00
Luis Pater
1e4f9dd438 Merge pull request #123 from router-for-me/plus
v6.7.20
2026-01-24 05:02:41 +08:00
Luis Pater
b9ff916494 Merge branch 'main' into plus 2026-01-24 05:02:32 +08:00
Luis Pater
9bf4a0cad2 Merge pull request #120 from Xm798/fix/kiro-auth-method-case
fix(auth): normalize Kiro authMethod to lowercase on token import
2026-01-24 04:58:50 +08:00
Luis Pater
c32e2a8196 fix(auth): handle context cancellation in executor methods 2026-01-24 04:56:55 +08:00
Luis Pater
873d41582f Merge pull request #1125 from NightHammer1000/dev
Filter out Top_P when Temp is set on Claude
2026-01-24 02:03:33 +08:00
Luis Pater
6fb7d85558 Merge pull request #1137 from augustVino/fix/remove_empty_systemmsg
fix(translator): ensure system message is only added if it contains c…
2026-01-24 02:02:18 +08:00
hkfires
d5e3e32d58 fix(auth): normalize plan type filenames to lowercase 2026-01-23 20:13:09 +08:00
Chén Mù
f353a54555 Merge pull request #1171 from router-for-me/auth
refactor(auth): remove unused provider execution helpers
2026-01-23 19:43:42 +08:00
Chén Mù
1d6e2e751d Merge pull request #1140 from sxjeru/main
fix(auth): handle quota cooldown in retry logic for transient errors
2026-01-23 19:43:17 +08:00
hkfires
cc50b63422 refactor(auth): remove unused provider execution helpers 2026-01-23 19:12:55 +08:00
Luis Pater
15ae83a15b Merge pull request #1169 from router-for-me/payload
feat(executor): apply payload rules using requested model
2026-01-23 18:41:31 +08:00
hkfires
81b369aed9 fix(auth): include requested model in executor metadata 2026-01-23 18:30:08 +08:00
Yang Bian
c8620d1633 feat: optimization enable/disable auth files 2026-01-23 18:03:09 +08:00
hkfires
ecc850bfb7 feat(executor): apply payload rules using requested model 2026-01-23 16:38:41 +08:00
Chén Mù
19b4ef33e0 Merge pull request #1102 from aldinokemal/main
feat(management): add PATCH endpoint to enable/disable auth files
2026-01-23 09:05:24 +08:00
hkfires
7ca045d8b9 fix(executor): adjust model-specific request payload 2026-01-22 20:28:08 +08:00
Cyrus
25b9df478c fix(auth): normalize authMethod to lowercase on Kiro token import
- Add strings.ToLower() normalization in LoadKiroIDEToken()
- Add same normalization in LoadKiroTokenFromPath()
- Fixes issue where Kiro IDE exports "IdC" but code expects "idc"
2026-01-22 19:54:48 +08:00
hkfires
abfca6aab2 refactor(util): reorder gemini schema cleaner helpers 2026-01-22 18:38:48 +08:00
Chén Mù
3c71c075db Merge pull request #1131 from sowar1987/fix/gemini-malformed-function-call
Fix Gemini tool calling for Antigravity (malformed_function_call)
2026-01-22 18:07:03 +08:00
sowar1987
9c2992bfb2 test: align signature cache tests with cache behavior
Co-Authored-By: Warp <agent@warp.dev>
2026-01-22 17:12:47 +08:00
sowar1987
269a1c5452 refactor: reuse placeholder reason description
Co-Authored-By: Warp <agent@warp.dev>
2026-01-22 17:12:47 +08:00
sowar1987
22ce65ac72 test: update signature cache tests
Revert gemini translator changes for scheme A

Co-Authored-By: Warp <agent@warp.dev>
2026-01-22 17:12:47 +08:00
sowar1987
a2f8f59192 Fix Gemini function-calling INVALID_ARGUMENT by relaxing Gemini tool validation and cleaning schema 2026-01-22 17:11:07 +08:00
XYenon
8c7c446f33 fix(gemini): preserve displayName and description in models list
Previously GeminiModels handler unconditionally overwrote displayName
and description with the model name, losing the original values defined
in model definitions (e.g., 'Gemini 3 Pro Preview').

Now only set these fields as fallback when they are missing or empty.
2026-01-22 15:19:27 +08:00
Luis Pater
51611c25d7 Merge branch 'router-for-me:main' into main 2026-01-21 22:12:28 +08:00
Luis Pater
eb1bbaa63b Merge pull request #119 from linlang781/main
支持Kiro sso idc
2026-01-21 22:11:58 +08:00
sxjeru
30a59168d7 fix(auth): handle quota cooldown in retry logic for transient errors 2026-01-21 21:48:23 +08:00
yuechenglong.5
4c8026ac3d chore(build): 更新 .gitignore 文件
- 添加 *.bak 文件扩展名到忽略列表
2026-01-21 21:38:47 +08:00
gogoing1024
8aeb4b7d54 Merge pull request #1 from gogoing1024/main
Merge pull request #1 from linlang781/main
2026-01-21 21:09:34 +08:00
gogoing1024
b2172cb047 Merge pull request #1 from linlang781/main
1
2026-01-21 21:07:24 +08:00
hkfires
c8884f5e25 refactor(translator): enhance signature handling in Claude and Gemini requests, streamline cache usage and remove unnecessary tests 2026-01-21 20:21:49 +08:00
Luis Pater
d9c6317c84 refactor(cache, translator): refine signature caching logic and tests, replace session-based logic with model group handling 2026-01-21 18:30:05 +08:00
Vino
d29ec95526 fix(translator): ensure system message is only added if it contains content 2026-01-21 16:45:50 +08:00
Luis Pater
ef4508dbc8 refactor(cache, translator): remove session ID from signature caching and clean up logic 2026-01-21 13:37:10 +08:00
Luis Pater
f775e46fe2 refactor(translator): remove session ID logic from signature caching and associated tests 2026-01-21 12:45:07 +08:00
Luis Pater
65ad5c0c9d refactor(cache): simplify signature caching by removing sessionID parameter 2026-01-21 12:38:05 +08:00
Luis Pater
88bf4e77ec fix(translator): update HasValidSignature to require modelName parameter for improved validation 2026-01-21 11:31:37 +08:00
yuechenglong.5
194f66ca9c feat(kiro): 添加后台令牌刷新通知机制
- 在 BackgroundRefresher 中添加 onTokenRefreshed 回调函数和并发安全锁
- 实现 WithOnTokenRefreshed 选项函数用于设置刷新成功回调
- 在 RefreshManager 中添加 SetOnTokenRefreshed 方法支持运行时更新回调
- 为 KiroExecutor 添加 reloadAuthFromFile 方法实现文件重新加载回退机制
- 在 Watcher 中实现 NotifyTokenRefreshed 方法处理刷新通知并更新内存Auth对象
- 通过 Service.GetWatcher 连接刷新器回调到 Watcher 通知链路
- 添加方案A和方案B双重保障解决后台刷新与内存对象时间差问题
2026-01-21 11:03:07 +08:00
Luis Pater
a4f8015caa test(logging): add unit tests for GinLogrusRecovery middleware panic handling 2026-01-21 10:57:27 +08:00
Luis Pater
ffd129909e Merge pull request #1130 from router-for-me/agty
fix(executor): only strip maxOutputTokens for non-claude models
2026-01-21 10:50:39 +08:00
hkfires
9332316383 fix(translator): preserve thinking blocks by skipping signature 2026-01-21 10:49:20 +08:00
hkfires
6dcbbf64c3 fix(executor): only strip maxOutputTokens for non-claude models 2026-01-21 10:49:20 +08:00
yuechenglong.5
c9aa1ff99d Merge remote-tracking branch 'origin/main'
# Conflicts:
#	internal/auth/kiro/oauth_web.go
2026-01-21 10:31:55 +08:00
Luis Pater
2ce3553612 feat(cache): handle gemini family in signature cache with fallback validator logic 2026-01-21 10:11:21 +08:00
Luis Pater
2e14f787d4 feat(translator): enhance ConvertGeminiRequestToAntigravity with model name and refine reasoning block handling 2026-01-21 08:31:23 +08:00
Luis Pater
523b41ccd2 test(responses): add comprehensive tests for SSE event ordering and response transformations 2026-01-21 07:08:59 +08:00
N1GHT
09970dc7af Accept Geminis Review Suggestion 2026-01-20 17:51:36 +01:00
N1GHT
d81abd401c Returned the Code Comment I trashed 2026-01-20 17:36:27 +01:00
N1GHT
a6cba25bc1 Small fix to filter out Top_P when Temperature is set on Claude to make requests go through 2026-01-20 17:34:26 +01:00
Luis Pater
c6fa1d0e67 Merge pull request #1117 from router-for-me/cache
fix(translator): enhance signature cache clearing logic and update test cases with model name
2026-01-20 23:18:48 +08:00
Luis Pater
ac56e1e88b Merge pull request #1116 from bexcodex/fix/antigravity
Fix antigravity malformed_function_call
2026-01-20 22:40:00 +08:00
781456868@qq.com
a9ee971e1c fix(kiro): improve auto-refresh and IDC auth file handling
Amp-Thread-ID: https://ampcode.com/threads/T-019bdb94-80e3-7302-be0f-a69937826d13
Co-authored-by: Amp <amp@ampcode.com>
2026-01-20 21:57:45 +08:00
781456868@qq.com
73cef3a25a Merge remote-tracking branch 'upstream/main' 2026-01-20 21:57:16 +08:00
hkfires
9b72ea9efa fix(translator): enhance signature cache clearing logic and update test cases with model name 2026-01-20 20:02:29 +08:00
bexcodex
9f364441e8 Fix antigravity malformed_function_call 2026-01-20 19:54:54 +08:00
Luis Pater
e49a1c07bf chore(translator): update cache functions to include model name parameter in tests 2026-01-20 18:36:51 +08:00
Luis Pater
5364a2471d fix(endpoint_compat): update GetModelInfo to include missing parameter for improved registry compatibility 2026-01-20 13:56:57 +08:00
Luis Pater
fef4fdb0eb Merge pull request #117 from router-for-me/plus
v6.7.15
2026-01-20 13:50:53 +08:00
Luis Pater
c2bf600a39 Merge branch 'main' into plus 2026-01-20 13:50:41 +08:00
Luis Pater
8d9f4edf9b feat(translator): unify model group references by introducing GetModelGroup helper function 2026-01-20 13:45:25 +08:00
Luis Pater
020e61d0da feat(translator): improve signature handling by associating with model name in cache functions 2026-01-20 13:31:36 +08:00
Luis Pater
6184c43319 Fixed: #1109
feat(translator): enhance session ID derivation with user_id parsing in Claude
2026-01-20 12:35:40 +08:00
Luis Pater
2cbe4a790c chore(translator): remove unnecessary whitespace in gemini_openai_response code 2026-01-20 11:47:33 +08:00
Luis Pater
68b3565d7b Merge branch 'main' into dev (PR #961) 2026-01-20 11:42:22 +08:00
Luis Pater
3f385a8572 feat(auth): add "antigravity" provider to ignored access_token fields in filestore 2026-01-20 11:38:31 +08:00
Luis Pater
9823dc35e1 feat(auth): hash account ID for improved uniqueness in credential filenames 2026-01-20 11:37:52 +08:00
Luis Pater
059bfee91b feat(auth): add hashed account ID to credential filenames for team plans 2026-01-20 11:36:29 +08:00
Luis Pater
7beaf0eaa2 Merge pull request #869 2026-01-20 11:16:53 +08:00
Luis Pater
1fef90ff58 Merge pull request #877 from zhiqing0205/main
feat(codex): include plan type in auth filename
2026-01-20 11:11:25 +08:00
Luis Pater
8447fd27a0 fix(login): remove emojis from interactive prompt messages 2026-01-20 11:09:56 +08:00
Luis Pater
7831cba9f6 refactor(claude): remove redundant system instructions check in Claude executor 2026-01-20 11:02:52 +08:00
Luis Pater
e02b2d58d5 Merge pull request #868 2026-01-20 10:57:24 +08:00
Luis Pater
28726632a9 Merge pull request #861 from umairimtiaz9/fix/gemini-cli-backend-project-id
fix(auth): use backend project ID for free tier Gemini CLI OAuth users
2026-01-20 10:32:17 +08:00
yuechenglong.5
0f63d973be Merge remote-tracking branch 'origin/main' 2026-01-20 10:20:03 +08:00
Luis Pater
3b26129c82 Merge pull request #1108 from router-for-me/modelinfo
feat(registry): support provider-specific model info lookup
2026-01-20 10:18:42 +08:00
Luis Pater
d4bb4e6624 refactor(antigravity): remove unused client signature handling in thinking objects 2026-01-20 10:17:55 +08:00
yuechenglong.5
fa2abd560a chore: cherry-pick 文档更新和删除测试文件
- docs: 添加 Kiro OAuth web 认证端点说明 (ace7c0c)
- chore: 删除包含敏感数据的测试文件 (8f06f6a)
- 保留本地修改: refresh_manager, token_repository 等
2026-01-20 10:17:39 +08:00
Luis Pater
0766c49f93 Merge pull request #994 from adrenjc/fix/cross-model-thinking-signature
fix(antigravity): prevent corrupted thought signature when switching models
2026-01-20 10:14:05 +08:00
Luis Pater
a7ffc77e3d Merge branch 'dev' into fix/cross-model-thinking-signature 2026-01-20 10:10:43 +08:00
hkfires
e641fde25c feat(registry): support provider-specific model info lookup 2026-01-20 10:01:17 +08:00
yuechenglong.5
564c2d763e Merge upstream/main (08779cc) - sync with original repo updates 2026-01-20 09:52:11 +08:00
Luis Pater
5717c7f2f4 Merge pull request #1103 from dinhkarate/feat/imagen
feat(vertex): add Imagen image generation model support
2026-01-20 07:11:18 +08:00
dinhkarate
8734d4cb90 feat(vertex): add Imagen image generation model support
Add support for Imagen 3.0 and 4.0 image generation models in Vertex AI:

- Add 5 Imagen model definitions (4.0, 4.0-ultra, 4.0-fast, 3.0, 3.0-fast)
- Implement :predict action routing for Imagen models
- Convert Imagen request/response format to match Gemini structure like gemini-3-pro-image
- Transform prompts to Imagen's instances/parameters format
- Convert base64 image responses to Gemini-compatible inline data
2026-01-20 01:26:37 +07:00
Aldino Kemal
2f6004d74a perf(management): optimize auth lookup in PatchAuthFileStatus
Use GetByID() for O(1) map lookup first, falling back to iteration
only for FileName matching. Consistent with pattern in disableAuth().
2026-01-19 20:05:37 +07:00
Luis Pater
08779cc8a8 Merge branch 'router-for-me:main' into main 2026-01-19 21:00:58 +08:00
Luis Pater
5baa753539 Merge pull request #1099 from router-for-me/claude
refactor(claude): move max_tokens constraint enforcement to Apply method
2026-01-19 20:55:59 +08:00
781456868@qq.com
92fb6b012a feat(kiro): add manual token refresh button to OAuth web UI
Amp-Thread-ID: https://ampcode.com/threads/T-019bd642-9806-75d8-9101-27812e0eb6ab
Co-authored-by: Amp <amp@ampcode.com>
2026-01-19 20:55:51 +08:00
Luis Pater
ead98e4bca Merge pull request #1101 from router-for-me/argy
fix(executor): stop rewriting thinkingLevel for gemini
2026-01-19 20:55:22 +08:00
Aldino Kemal
a1634909e8 feat(management): add PATCH endpoint to enable/disable auth files
Add new PATCH /v0/management/auth-files/status endpoint that allows
toggling the disabled state of auth files without deleting them.
This enables users to temporarily disable credentials from the
management UI.
2026-01-19 19:50:36 +07:00
781456868@qq.com
8f06f6a9ed chore: remove test files containing sensitive data
Amp-Thread-ID: https://ampcode.com/threads/T-019bd618-7e42-715a-960d-dd45425851e3
Co-authored-by: Amp <amp@ampcode.com>
2026-01-19 20:31:33 +08:00
781456868@qq.com
ace7c0ccb4 docs: add Kiro OAuth web authentication endpoint /v0/oauth/kiro 2026-01-19 20:28:40 +08:00
781456868@qq.com
f87fe0a0e8 feat: proactive token refresh 10 minutes before expiry
Amp-Thread-ID: https://ampcode.com/threads/T-019bd618-7e42-715a-960d-dd45425851e3
Co-authored-by: Amp <amp@ampcode.com>
2026-01-19 20:09:38 +08:00
781456868@qq.com
87edc6f35e Merge remote-tracking branch 'upstream/main' 2026-01-19 20:09:17 +08:00
hkfires
1d2fe55310 fix(executor): stop rewriting thinkingLevel for gemini 2026-01-19 19:49:39 +08:00
hkfires
c175821cc4 feat(registry): expand antigravity model config
Remove static Name mapping and add entries for claude-sonnet-4-5,
tab_flash_lite_preview, and gpt-oss-120b-medium configs
2026-01-19 19:32:00 +08:00
hkfires
239a28793c feat(claude): clamp thinking budget to max_tokens constraints 2026-01-19 16:32:20 +08:00
hkfires
c421d653e7 refactor(claude): move max_tokens constraint enforcement to Apply method 2026-01-19 15:50:35 +08:00
Luis Pater
2542c2920d Merge pull request #1096 from router-for-me/usage
feat(translator): report cached token usage in Claude output
2026-01-19 11:52:18 +08:00
hkfires
52e46ced1b fix(translator): avoid forcing RFC 8259 system prompt 2026-01-19 11:33:27 +08:00
hkfires
cf9daf470c feat(translator): report cached token usage in Claude output 2026-01-19 11:23:44 +08:00
Luis Pater
ac7738bdeb Merge pull request #114 from router-for-me/plus
v6.7.9
2026-01-19 04:03:26 +08:00
Luis Pater
2d9f6c104c Merge branch 'main' into plus 2026-01-19 04:03:17 +08:00
Luis Pater
5d0460ece2 Merge pull request #112 from clstb/main
Add Github Copilot support for management interface
2026-01-19 04:02:09 +08:00
Luis Pater
140d6211cc feat(translator): add reasoning state tracking and improve reasoning summary handling
- Introduced `oaiToResponsesStateReasoning` to track reasoning data.
- Enhanced logic for emitting reasoning summary events and managing state transitions.
- Updated output generation to handle multiple reasoning entries consistently.
2026-01-19 03:58:28 +08:00
Luis Pater
60f9a1442c Merge pull request #1088 from router-for-me/thinking
Thinking
2026-01-18 17:01:59 +08:00
hkfires
cb6caf3f87 fix(thinking): update ValidateConfig to include fromSuffix parameter and adjust budget validation logic 2026-01-18 16:37:14 +08:00
781456868@qq.com
c9301a6d18 docs: update README with new features and Docker deployment guide 2026-01-18 15:07:29 +08:00
781456868@qq.com
0e77e93e5d feat: add Kiro OAuth web, rate limiter, metrics, fingerprint, background refresh and model converter 2026-01-18 15:04:29 +08:00
Luis Pater
99c7abbbf1 Merge pull request #1067 from router-for-me/auth-files
refactor(auth): simplify filename prefixes for qwen and iflow tokens
2026-01-18 13:41:59 +08:00
Luis Pater
8f511ac33c Merge pull request #1076 from sususu98/fix/antigravity-enum-string
fix(antigravity): convert non-string enum values to strings for Gemini API
2026-01-18 13:40:53 +08:00
Luis Pater
1046152119 Merge pull request #1068 from 0xtbug/dev
docs(readme): add ZeroLimit to projects based on CLIProxyAPI
2026-01-18 13:37:50 +08:00
Luis Pater
f88228f1c5 Merge pull request #1081 from router-for-me/thinking
Refine thinking validation and cross‑provider payload conversion
2026-01-18 13:34:28 +08:00
Luis Pater
62e2b672d9 refactor(logging): centralize log directory resolution logic
- Introduced `ResolveLogDirectory` function in `logging` package to standardize log directory determination across components.
- Replaced redundant logic in `server`, `global_logger`, and `handlers` with the new utility function.
2026-01-18 12:40:57 +08:00
hkfires
03005b5d29 refactor(thinking): add Gemini family provider grouping for strict validation 2026-01-18 11:30:53 +08:00
hkfires
c7e8830a56 refactor(thinking): pass source and target formats to ApplyThinking for cross-format validation
Update ApplyThinking signature to accept fromFormat and toFormat parameters
instead of a single provider string. This enables:

- Proper level-to-budget conversion when source is level-based (openai/codex)
  and target is budget-based (gemini/claude)
- Strict budget range validation when source and target formats match
- Level clamping to nearest supported level for cross-format requests
- Format alias resolution in SDK translator registry for codex/openai-response

Also adds ErrBudgetOutOfRange error code and improves iflow config extraction
to fall back to openai format when iflow-specific config is not present.
2026-01-18 10:30:15 +08:00
hkfires
d5ef4a6d15 refactor(translator): remove registry model lookups from thinking config conversions 2026-01-18 10:30:14 +08:00
hkfires
97b67e0e49 test(thinking): split E2E coverage into suffix and body parameter test functions
Refactor thinking configuration tests by separating model name suffix-based
scenarios from request body parameter-based scenarios into distinct test
functions with independent case numbering.

Architectural improvements:
- Extract thinkingTestCase struct to package level for shared usage
- Add getTestModels() helper returning complete model fixture set
- Introduce runThinkingTests() runner with protocol-specific field detection
- Register level-subset-model fixture with constrained low/high level support
- Extend iflow protocol handling for glm-test and minimax-test models
- Add same-protocol strict boundary validation cases (80-89)
- Replace error responses with clamped values for boundary-exceeding budgets
2026-01-18 10:30:14 +08:00
sususu98
dd6d78cb31 fix(antigravity): convert non-string enum values to strings for Gemini API
Gemini API requires all enum values in function declarations to be
strings. Some MCP tools (e.g., roxybrowser) define schemas with numeric
enums like `"enum": [0, 1, 2]`, causing INVALID_ARGUMENT errors.

Add convertEnumValuesToStrings() to automatically convert numeric and
boolean enum values to their string representations during schema
transformation.
2026-01-18 02:00:02 +00:00
Luis Pater
46433a25f8 fix(translator): add check for empty text to prevent invalid serialization in gemini and antigravity 2026-01-18 00:50:10 +08:00
clstb
b4e070697d feat: support github copilot in management ui 2026-01-17 17:22:45 +01:00
Tubagus
c8843edb81 Update README_CN.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-17 11:33:29 +07:00
Tubagus
f89feb881c Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-17 11:33:18 +07:00
Tubagus
dbba71028e docs(readme): add ZeroLimit to projects based on CLIProxyAPI 2026-01-17 11:30:15 +07:00
Tubagus
8549a92e9a docs(readme): add ZeroLimit to projects based on CLIProxyAPI
Added ZeroLimit app to the list of projects in README.
2026-01-17 11:29:22 +07:00
hkfires
109cffc010 refactor(auth): simplify filename prefixes for qwen and iflow tokens 2026-01-17 12:20:58 +08:00
Luis Pater
f8f3ad84fc Fixed: #1064
feat(translator): improve system message handling and content indexing across translators

- Updated logic for processing system messages in `claude`, `gemini`, `gemini-cli`, and `antigravity` translators.
- Introduced indexing for `systemInstruction.parts` to ensure proper ordering and handling of multi-part content.
- Added safeguards for accurate content transformation and serialization.
2026-01-17 05:40:56 +08:00
Luis Pater
93d7883513 Merge pull request #110 from PancakeZik/fix/system-prompt-reinjection
fix: prevent system prompt re-injection on subsequent turns
2026-01-17 05:19:11 +08:00
Luis Pater
015a3e8a83 Merge branch 'router-for-me:main' into main 2026-01-17 05:17:38 +08:00
Luis Pater
bc7167e9fe feat(runtime): add model alias support and enhance payload rule matching
- Introduced `payloadModelAliases` and `payloadModelCandidates` functions to support model aliases for improved flexibility.
- Updated rule matching logic to handle multiple model candidates.
- Refactored variable naming in executor to improve code clarity and consistency.
2026-01-17 05:05:24 +08:00
Luis Pater
384578a88c feat(cliproxy, gemini): improve ID matching logic and enrich normalized model output
- Enhanced ID matching in `cliproxy` by adding additional conditions to better handle ID equality cases.
- Updated `gemini` handlers to include `displayName` and `description` in normalized models for enriched metadata.
2026-01-17 04:44:09 +08:00
Joao
6b074653f2 fix: prevent system prompt re-injection on subsequent turns
When tool results are sent back to the model, the system prompt was being
re-injected into the user message content, causing the model to think the
user had pasted the system prompt again. This was especially noticeable
after multiple tool uses.

The fix checks if there is conversation history (len(history) > 0). If so,
it's a subsequent turn and we skip system prompt injection. The system
prompt is only injected on the first turn (len(history) == 0).

This ensures:
- First turn: system prompt is injected
- Tool result turns: system prompt is NOT re-injected
- New conversations: system prompt is injected fresh
2026-01-16 20:16:44 +00:00
Luis Pater
65b4e1ec6c feat(codex): enable instruction toggling and update role terminology
- Added conditional logic for Codex instruction injection based on configuration.
- Updated role terminology from "user" to "developer" for better alignment with context.
2026-01-17 04:12:29 +08:00
Luis Pater
06afa29f2d Merge branch 'router-for-me:main' into main 2026-01-16 20:01:35 +08:00
Luis Pater
6600d58ba2 feat(codex): enhance input transformation and remove unused safety_identifier field
- Added logic to transform `inputResults` into structured JSON for improved processing.
- Removed redundant `safety_identifier` field in executor payload to streamline requests.
2026-01-16 19:59:01 +08:00
Luis Pater
25e9be3ced Merge pull request #103 from ChrAlpha/feat/add-gpt-5.2-codex-copilot
feat(openai): responses API support for GitHub Copilot provider
2026-01-16 18:33:53 +08:00
Luis Pater
ccb2aaf2fe Merge branch 'router-for-me:main' into main 2026-01-16 18:29:56 +08:00
Luis Pater
961c6f67da Merge pull request #100 from novadev94/fix/readd_kiro_auto
fix(kiro): re-add kiro-auto to registry
2026-01-16 18:29:43 +08:00
Luis Pater
dc4305f75a Merge pull request #107 from zccing/main
fix(kiro): correct Amazon Q endpoint URL path
2026-01-16 18:28:45 +08:00
Chén Mù
4dc7af5a5d Merge pull request #1054 from router-for-me/codex
fix(codex): ensure instructions field exists
2026-01-16 15:40:12 +08:00
hkfires
902bea24b4 fix(codex): ensure instructions field exists 2026-01-16 15:38:10 +08:00
Cc
778cf4af9e feat(kiro): add agent-mode and optout headers for non-IDC auth
- Add x-amzn-kiro-agent-mode: vibe for non-IDC auth (Social, Builder ID)
  IDC auth continues to use "spec" mode
- Add x-amzn-codewhisperer-optout: true for all auth types
  This opts out of data sharing for service improvement (privacy)

These changes align with other Kiro implementations (kiro.rs, KiroGate,
kiro-gateway, AIClient-2-API) and make requests more similar to real
Kiro IDE clients.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 14:21:38 +08:00
hkfires
c3ef46f409 feat(config): supplement missing default aliases during antigravity migration 2026-01-16 13:37:46 +08:00
Cc
4721c58d9c fix(kiro): correct Amazon Q endpoint URL path
The Q endpoint was using `/` which caused all requests to fail with
400 or UnknownOperationException. Changed to `/generateAssistantResponse`
which is the correct path for the Q endpoint.

This fix restores the Q endpoint failover functionality.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 13:22:43 +08:00
Luis Pater
aa0b63e214 refactor(config): clarify Codex instruction toggle documentation 2026-01-16 12:50:09 +08:00
Luis Pater
3c4e7997c3 Merge branch 'router-for-me:main' into main 2026-01-16 12:47:23 +08:00
Luis Pater
1afc3a5f65 feat(auth): add support for kiro OAuth model alias
- Introduced `kiro` channel and alias resolution in `oauth_model_alias` logic.
- Updated supported channels documentation and examples to include `kiro` and `github-copilot`.
- Enhanced unit tests to validate `kiro` alias functionality.
2026-01-16 12:47:05 +08:00
Luis Pater
ea3d22831e refactor(codex): update terminology to "official instructions" for clarity 2026-01-16 12:44:57 +08:00
Luis Pater
3b4d6d359b Merge pull request #1049 from router-for-me/codex
feat(codex): add config toggle for codex instructions injection
2026-01-16 12:38:35 +08:00
hkfires
48cba39a12 feat(codex): add config toggle for codex instructions injection 2026-01-16 12:30:12 +08:00
Luis Pater
bca244df67 Merge branch 'router-for-me:main' into main 2026-01-16 11:37:33 +08:00
Luis Pater
cec4e251bd feat(translator): preserve text field in serialized output during chat completions processing 2026-01-16 11:35:34 +08:00
Luis Pater
526dd866ba refactor(gemini): replace static model handling with dynamic model registry lookup 2026-01-16 10:39:16 +08:00
Luis Pater
c29839d2ed Merge remote-tracking branch 'origin/main' into pr-104
# Conflicts:
#	config.example.yaml
#	internal/config/config.go
#	sdk/cliproxy/auth/model_name_mappings.go
2026-01-16 09:40:07 +08:00
Luis Pater
b31ddc7bf1 Merge branch 'dev' 2026-01-16 08:21:59 +08:00
Chén Mù
22e1ad3d8a Merge pull request #1018 from pikeman20/main
feat(docker): use environment variables for volume paths
2026-01-16 08:19:23 +08:00
Luis Pater
f571b1deb0 feat(config): add support for raw JSON payload rules
- Introduced `default-raw` and `override-raw` rules to handle raw JSON values.
- Enhanced `PayloadConfig` to validate and sanitize raw JSON payload rules.
- Updated executor logic to apply `default-raw` and `override-raw` rules.
- Extended example YAML to include usage of raw JSON rules.
2026-01-16 08:15:28 +08:00
Luis Pater
67f8732683 Merge pull request #1033 from router-for-me/reasoning
Refactor thinking
2026-01-15 20:33:13 +08:00
hkfires
2b387e169b feat(iflow): add iflow-rome model definition 2026-01-15 20:23:55 +08:00
hkfires
199cf480b0 refactor(thinking): remove support for non-standard thinking configurations
This change removes the translation logic for several non-standard, proprietary extensions used to configure thinking/reasoning. Specifically, support for `extra_body.google.thinking_config` and the Anthropic-style `thinking` object has been dropped from the OpenAI request translators.

This simplification streamlines the translators, focusing them on the standard `reasoning_effort` parameter. It also removes the need to look up model information from the registry within these components.

BREAKING CHANGE: Support for non-standard thinking configurations via `extra_body.google.thinking_config` and the Anthropic-style `thinking` object has been removed. Clients should now use the standard `reasoning_effort` parameter to control reasoning.
2026-01-15 19:32:12 +08:00
ChrAlpha
18daa023cb fix(openai): improve error handling for response conversion failures 2026-01-15 19:13:54 +08:00
hkfires
4ad6189487 refactor(thinking): extract antigravity logic into a dedicated provider 2026-01-15 19:08:22 +08:00
ChrAlpha
8950d92682 feat(openai): implement endpoint resolution and response handling for Chat and Responses models 2026-01-15 18:30:01 +08:00
hkfires
fe5b3c80cb refactor(config): rename oauth-model-mappings to oauth-model-alias 2026-01-15 18:03:26 +08:00
ChrAlpha
0ffcce3ec8 feat(registry): add supported endpoints for GitHub Copilot models
Enhance model definitions by including supported API endpoints for each model. This allows for better integration and usage tracking with the GitHub Copilot API.
2026-01-15 16:32:28 +08:00
hkfires
e0ffec885c fix(aistudio): remove levels from model definitions 2026-01-15 16:06:46 +08:00
hkfires
ff4ff6bc2f feat(thinking): support zero as a valid thinking budget for capable models 2026-01-15 15:41:10 +08:00
ChrAlpha
f4fcfc5867 feat(registry): add GPT-5.2-Codex model to GitHub Copilot provider
Add gpt-5.2-codex model definition to GetGitHubCopilotModels() function,
  enabling access to OpenAI GPT-5.2 Codex through the GitHub Copilot API.
2026-01-15 14:14:09 +08:00
Luis Pater
7248f65c36 feat(auth): prevent filestore writes on unchanged metadata
- Added `metadataEqualIgnoringTimestamps` to compare metadata while ignoring volatile fields.
- Prevented redundant writes caused by changes in timestamp-related fields.
- Improved efficiency in filestore operations by skipping unnecessary updates.
2026-01-15 14:05:23 +08:00
hkfires
5c40a2db21 refactor(thinking): simplify ModeNone and budget validation logic 2026-01-15 14:03:08 +08:00
Luis Pater
d6111344c5 Merge branch 'router-for-me:main' into main 2026-01-15 13:30:28 +08:00
Luis Pater
086eb3df7a refactor(auth): simplify file handling logic and remove redundant comparison functions
feat(auth): fetch and update Antigravity project ID from metadata during filestore operations

- Added support to retrieve and update `project_id` using the access token if missing in metadata.
- Integrated HTTP client to fetch project ID dynamically.
- Enhanced metadata persistence logic.
2026-01-15 13:29:14 +08:00
hkfires
ee2976cca0 refactor(thinking): improve logging for user-defined models 2026-01-15 13:06:41 +08:00
hkfires
8bc6df329f fix(auth): apply API key model mapping to request model 2026-01-15 13:06:41 +08:00
hkfires
bcd4d9595f fix(thinking): refine ModeNone handling based on provider capabilities 2026-01-15 13:06:41 +08:00
hkfires
5a77b7728e refactor(thinking): improve budget clamping and logging with provider/model context 2026-01-15 13:06:41 +08:00
hkfires
1fbbba6f59 feat(logging): order log fields for improved readability 2026-01-15 13:06:41 +08:00
hkfires
847be0e99d fix(auth): use base model name for auth matching by stripping suffix 2026-01-15 13:06:41 +08:00
hkfires
f6a2d072e6 refactor(thinking): refine configuration logging 2026-01-15 13:06:41 +08:00
hkfires
ed8b0f25ee fix(thinking): use LookupModelInfo for model data 2026-01-15 13:06:41 +08:00
hkfires
6e4a602c60 fix(thinking): map reasoning_effort to thinkingConfig 2026-01-15 13:06:40 +08:00
hkfires
2262479365 refactor(thinking): remove legacy utilities and simplify model mapping 2026-01-15 13:06:40 +08:00
hkfires
33d66959e9 test(thinking): remove legacy unit and integration tests 2026-01-15 13:06:40 +08:00
hkfires
7f1b2b3f6e fix(thinking): improve model lookup and validation 2026-01-15 13:06:40 +08:00
hkfires
40ee065eff fix(thinking): use static lookup to avoid alias issues 2026-01-15 13:06:40 +08:00
hkfires
a75fb6af90 refactor(antigravity): remove hardcoded model aliases 2026-01-15 13:06:39 +08:00
hkfires
72f2125668 fix(executor): properly handle thinking application errors 2026-01-15 13:06:39 +08:00
hkfires
e8f5888d8e fix(thinking): fix auth matching for thinking suffix and json field conflicts 2026-01-15 13:06:39 +08:00
hkfires
0b06d637e7 refactor: improve thinking logic 2026-01-15 13:06:39 +08:00
Luis Pater
496f6770a5 Merge branch 'router-for-me:main' into main 2026-01-15 12:09:22 +08:00
Luis Pater
5a7e5bd870 feat(auth): add Antigravity onboarding with tier selection
- Updated `ideType` to `ANTIGRAVITY` in request payload.
- Introduced tier-selection logic to determine default tier for onboarding.
- Added `antigravityOnboardUser` function for project ID retrieval via polling.
- Enhanced error handling and response decoding for onboarding flow.
2026-01-15 11:43:02 +08:00
Luis Pater
6f8a8f8136 feat(selector): add priority support for auth selection 2026-01-15 07:08:24 +08:00
pikeman20
5df195ea82 feat(docker): use environment variables for volume paths
This change introduces environment variable interpolation for volume paths, allowing users to customize where configuration, authentication, and log data are stored.

Why: Makes the project easier to deploy on various hosting environments that require decoupled data management without needing to modify the core docker-compose.yml..

Key points:

Defaults to existing paths (./config.yaml, ./auths, ./logs) to ensure zero breaking changes for current users.

Follows the existing naming convention used in the project.

Enhances portability for CI/CD and cloud-native deployments.
2026-01-15 05:42:51 +07:00
Nova
f82f70df5c fix(kiro): re-add kiro-auto to registry
Reference: https://github.com/router-for-me/CLIProxyAPIPlus/pull/16
Revert: a594338bc5
2026-01-15 03:26:22 +07:00
Luis Pater
5a2bf191fc Merge pull request #98 from router-for-me/plus
v6.6.105
2026-01-15 03:31:04 +08:00
Luis Pater
a235fb1507 Merge branch 'main' into plus 2026-01-15 03:30:56 +08:00
Luis Pater
0d66522ed8 Merge pull request #95 from ZqinKing/main
feat(kiro): 实现动态工具压缩功能
2026-01-15 03:29:49 +08:00
Luis Pater
b163f8ed9e Fixed: #1004
feat(translator): add function name to response output item serialization

- Included `item.name` in the serialized response output to enhance output item handling.
2026-01-15 03:27:00 +08:00
ZqinKing
83e5f60b8b fix(kiro): scale description compression by needed size
Compute a size-reduction based keep ratio and use it to trim
tool descriptions, avoiding forced minimum truncation when the
target size already fits. This aligns compression with actual
payload reduction needs and prevents over-compression.
2026-01-14 16:22:46 +08:00
ZqinKing
5b433f962f feat(kiro): 实现动态工具压缩功能
## 背景
当 Claude Code 发送过多工具信息时,可能超出 Kiro API 请求限制导致 500 错误。
现有的工具描述截断(KiroMaxToolDescLen = 10237)只能限制单个工具的描述长度,
无法解决整体工具列表过大的问题。

## 解决方案
实现动态工具压缩功能,采用两步压缩策略:
1. 先检查原始大小,超过 20KB 才进行压缩
2. 第一步:简化 input_schema,只保留 type/enum/required 字段
3. 第二步:按比例缩短 description(最短 50 字符)
4. 保留全部工具和 skills 可调用,不丢弃任何工具

## 新增文件
- internal/translator/kiro/claude/tool_compression.go
  - calculateToolsSize(): 计算工具列表的 JSON 序列化大小
  - simplifyInputSchema(): 简化 input_schema,递归处理嵌套 properties
  - compressToolDescription(): 按比例压缩描述,支持 UTF-8 安全截断
  - compressToolsIfNeeded(): 主压缩函数,实现两步压缩策略

- internal/translator/kiro/claude/tool_compression_test.go
  - 完整的单元测试覆盖所有新增函数
  - 测试 UTF-8 安全性
  - 测试压缩效果

## 修改文件
- internal/translator/kiro/common/constants.go
  - 新增 ToolCompressionTargetSize = 20KB (压缩目标大小阈值)
  - 新增 MinToolDescriptionLength = 50 (描述最短长度)

- internal/translator/kiro/claude/kiro_claude_request.go
  - 在 convertClaudeToolsToKiro() 函数末尾调用 compressToolsIfNeeded()

## 测试结果
- 70KB 工具压缩至 17KB (74.7% 压缩率)
- 所有单元测试通过

## 预期效果
- 80KB+ tools 压缩至 ~15KB
- 不影响工具调用功能
2026-01-14 11:07:07 +08:00
Luis Pater
a1da6ff5ac Fixed: #499 #985
feat(oauth): add support for customizable OAuth callback ports

- Introduced `oauth-callback-port` flag to override default callback ports.
- Updated SDK and login flows for `iflow`, `gemini`, `antigravity`, `codex`, `claude`, and `openai` to respect configurable callback ports.
- Refactored internal OAuth servers to dynamically assign ports based on the provided options.
- Revised tests and documentation to reflect the new flag and behavior.
2026-01-14 04:29:15 +08:00
adrenjc
5977af96a0 fix(antigravity): prevent corrupted thought signature when switching models
When switching from Claude models (e.g., Opus 4.5) to Gemini models
(e.g., Flash) mid-conversation via Antigravity OAuth, the client-provided
thinking signatures from Claude would cause "Corrupted thought signature"
errors since they are incompatible with Gemini API.

Changes:
- Remove fallback to client-provided signatures in thinking block handling
- Only use cached signatures (from same-session Gemini responses)
- Skip thinking blocks without valid cached signatures
- tool_use blocks continue to use skip_thought_signature_validator when
  no valid signature is available

This ensures cross-model switching works correctly while preserving
signature validation for same-model conversations.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-13 18:24:05 +08:00
Luis Pater
9b33fbf1cd Merge branch 'router-for-me:main' into main 2026-01-13 03:15:46 +08:00
Luis Pater
43652d044c refactor(config): replace nonstream-keepalive with nonstream-keepalive-interval
- Updated `SDKConfig` to use `nonstream-keepalive-interval` (seconds) instead of the boolean `nonstream-keepalive`.
- Refactored handlers and logic to incorporate the new interval-based configuration.
- Updated config diff, tests, and example YAML to reflect the changes.
2026-01-13 03:14:38 +08:00
Luis Pater
b1b379ea18 feat(api): add non-streaming keep-alive support for idle timeout prevention
- Introduced `StartNonStreamingKeepAlive` to emit periodic blank lines during non-streaming responses.
- Added `nonstream-keepalive` configuration option in `SDKConfig`.
- Updated handlers to utilize `StartNonStreamingKeepAlive` and ensure proper cleanup.
- Extended config diff and tests to include `nonstream-keepalive` changes.
2026-01-13 02:36:07 +08:00
hkfires
21ac161b21 fix(test): implement missing HttpRequest method in stream bootstrap mock 2026-01-12 16:33:43 +08:00
Luis Pater
94e979865e Fixed: #897
refactor(executor): remove `prompt_cache_retention` from request payloads
2026-01-12 10:46:47 +08:00
Luis Pater
6c324f2c8b Fixed: #936
feat(cliproxy): support multiple aliases for OAuth model mappings

- Updated mapping logic to allow multiple aliases per upstream model name.
- Adjusted `SanitizeOAuthModelMappings` to ensure aliases remain unique within channels.
- Added test cases to validate multi-alias scenarios.
- Updated example config to clarify multi-alias support.
2026-01-12 10:40:34 +08:00
Luis Pater
e0194d8511 fix(ci): revert Docker image build and push workflow for tagging releases 2026-01-12 00:29:34 +08:00
Luis Pater
216dafe44b Merge branch 'router-for-me:main' into main 2026-01-12 00:27:15 +08:00
Luis Pater
d7dc9660af Merge pull request #93 from jc01rho/main
feat(config): add github-copilot support to oauth-model-mappings and oauth-excluded-models
2026-01-12 00:26:58 +08:00
jc01rho
e0e30df323 Delete .github/workflows/docker-image.yml 2026-01-12 01:22:13 +09:00
Luis Pater
543dfd67e0 refactor(cache): remove max entries logic and extend signature TTL to 3 hours 2026-01-12 00:20:44 +08:00
jc01rho
bbd3eafde0 Delete .github/workflows/auto-sync.yml 2026-01-12 01:19:49 +09:00
jc01rho
e9cd355893 Add auto-sync workflow configuration file 2026-01-12 01:11:11 +09:00
jc01rho
c3e39267b8 Create auto-sync 2026-01-12 01:10:58 +09:00
Woohyun Rho
b477aff611 fix(login): use response project ID when API returns different project 2026-01-12 01:05:57 +09:00
Luis Pater
28bd1323a2 Merge pull request #971 from router-for-me/codex
feat(codex): add OpenCode instructions based on user agent
2026-01-11 16:01:13 +08:00
hkfires
220ca45f74 fix(codex): only override instructions when upstream provides them 2026-01-11 15:52:21 +08:00
hkfires
70a82d80ac fix(codex): only override instructions in responses for OpenCode UA 2026-01-11 15:19:37 +08:00
hkfires
ac626111ac feat(codex): add OpenCode instructions based on user agent 2026-01-11 13:36:35 +08:00
Woohyun Rho
8f6740fcef fix(iflow): add missing applyExcludedModels call for iflow provider 2026-01-11 03:01:50 +09:00
Woohyun Rho
d829ac4cf7 docs(config): add github-copilot and kiro to oauth-excluded-models documentation 2026-01-11 02:48:05 +09:00
Woohyun Rho
f064f6e59d feat(config): add github-copilot to oauth-model-mappings supported channels 2026-01-11 01:59:38 +09:00
extremk
5bb9c2a2bd Add candidate count parameter to OpenAI request 2026-01-10 18:50:13 +08:00
extremk
0b5bbe9234 Add candidate count handling in OpenAI request 2026-01-10 18:49:29 +08:00
extremk
14c74e5e84 Handle 'n' parameter for candidate count in requests
Added handling for the 'n' parameter to set candidate count in generationConfig.
2026-01-10 18:48:33 +08:00
extremk
6448d0ee7c Add candidate count handling in OpenAI request 2026-01-10 18:47:41 +08:00
extremk
b0c17af2cf Enhance Gemini to OpenAI response conversion
Refactor response handling to support multiple candidates and improve parameter management.
2026-01-10 18:46:25 +08:00
Luis Pater
8f27fd5c42 feat(executor): add HttpRequest method with credential injection for GitHub Copilot and Kiro executors 2026-01-10 16:44:58 +08:00
Luis Pater
a9823ba58a Merge branch 'router-for-me:main' into main 2026-01-10 16:27:52 +08:00
Luis Pater
8cfe26f10c Merge branch 'sdk' into dev 2026-01-10 16:26:23 +08:00
Luis Pater
80db2dc254 Merge pull request #955 from router-for-me/api
feat(codex): add subscription date fields to ID token claims
2026-01-10 16:26:07 +08:00
Luis Pater
e8e3bc8616 feat(executor): add HttpRequest support across executors for better http request handling 2026-01-10 16:25:25 +08:00
Luis Pater
ab5f5386e4 Merge branch 'router-for-me:main' into main 2026-01-10 14:53:04 +08:00
Luis Pater
bc3195c8d8 refactor(logger): remove unnecessary request details limit logic 2026-01-10 14:46:59 +08:00
hkfires
6494330c6b feat(codex): add subscription date fields to ID token claims 2026-01-10 11:15:20 +08:00
Luis Pater
89e34bf1e6 Merge pull request #82 from FakerL/feat/kiro-oauth-model-mappings
feat(kiro): add OAuth model name mappings support for Kiro
2026-01-10 05:43:16 +08:00
Luis Pater
2574eec2ed Merge pull request #92 from router-for-me/main
v6.6.96
2026-01-10 01:15:21 +08:00
Luis Pater
514b9bf9fc Merge origin/main into pr-92 2026-01-10 01:12:22 +08:00
Luis Pater
4d7f389b69 Fixed: #941
fix(translator): ensure fallback to valid originalRequestRawJSON in response handling
2026-01-10 01:01:09 +08:00
Luis Pater
95f87d5669 Merge pull request #947 from pykancha/fix-memory-leak
Resolve memory leaks causing OOM in k8s deployment
2026-01-10 00:40:47 +08:00
Luis Pater
c83365a349 Merge pull request #938 from router-for-me/log
refactor(logging): clean up oauth logs and debugs
2026-01-10 00:02:45 +08:00
Luis Pater
6b3604cf2b Merge pull request #943 from ben-vargas/fix-tool-mappings
Fix Claude OAuth tool name mapping (proxy_)
2026-01-09 23:52:29 +08:00
Luis Pater
af6bdca14f Fixed: #942
fix(executor): ignore non-SSE lines in OpenAI-compatible streams
2026-01-09 23:41:50 +08:00
Luis Pater
58d45b4d58 Merge pull request #91 from router-for-me/plus
v6.6.93
2026-01-09 21:52:31 +08:00
Luis Pater
1906ebcfce Merge branch 'main' into plus 2026-01-09 21:52:24 +08:00
hemanta212
1c773c428f fix: Remove investigation artifacts 2026-01-09 17:47:59 +05:45
Ben Vargas
e785bfcd12 Use unprefixed Claude request for translation
Keep the upstream payload prefixed for OAuth while passing the unprefixed request body into response translators. This avoids proxy_ leaking into OpenAI Responses echoed tool metadata while preserving the Claude OAuth workaround.
2026-01-09 00:54:35 -07:00
hemanta212
47dacce6ea fix(server): resolve memory leaks causing OOM in k8s deployment
- usage/logger_plugin: cap modelStats.Details at 1000 entries per model
- cache/signature_cache: add background cleanup for expired sessions (10 min)
- management/handler: add background cleanup for stale IP rate-limit entries (1 hr)
- executor/cache_helpers: add mutex protection and TTL cleanup for codexCacheMap (15 min)
- executor/codex_executor: use thread-safe cache accessors

Add reproduction tests demonstrating leak behavior before/after fixes.

Amp-Thread-ID: https://ampcode.com/threads/T-019ba0fc-1d7b-7338-8e1d-ca0520412777
Co-authored-by: Amp <amp@ampcode.com>
2026-01-09 13:33:46 +05:45
Ben Vargas
dcac3407ab Fix Claude OAuth tool name mapping
Prefix tool names with proxy_ for Claude OAuth requests and strip the prefix from streaming and non-streaming responses to restore client-facing names.

Updates the Claude executor to:
- add prefixing for tools, tool_choice, and tool_use messages when using OAuth tokens
- strip the prefix from tool_use events in SSE and non-streaming payloads
- add focused unit tests for prefix/strip helpers
2026-01-09 00:10:38 -07:00
hkfires
7004295e1d build(docker): move stats export execution after image build 2026-01-09 11:24:00 +08:00
hkfires
ee62ef4745 refactor(logging): clean up oauth logs and debugs 2026-01-09 11:20:55 +08:00
Luis Pater
ef6bafbf7e fix(executor): handle context cancellation and deadline errors explicitly 2026-01-09 10:48:29 +08:00
Luis Pater
ed28b71e87 refactor(amp): remove duplicate comments in response rewriter 2026-01-09 08:21:13 +08:00
Luis Pater
d47b7dc79a refactor(response): enhance parameter handling for Codex to Claude conversion 2026-01-09 05:20:19 +08:00
Luis Pater
49b9709ce5 Merge pull request #787 from sususu98/fix/antigravity-429-retry-delay-parsing
fix(antigravity): parse retry-after delay from 429 response body
2026-01-09 04:45:25 +08:00
Luis Pater
a2eba2cdf5 Merge pull request #763 from mvelbaum/feature/improve-oauth-use-logging
feat(logging): disambiguate OAuth credential selection in debug logs
2026-01-09 04:43:21 +08:00
Luis Pater
3d01b3cfe8 Merge pull request #553 from XInTheDark/fix/builtin-tools-web-search
fix(translator): preserve built-in tools (web_search) to Responses API
2026-01-09 04:40:13 +08:00
Luis Pater
af2efa6f7e Merge pull request #605 from soilSpoon/feature/amp-compat
feature: Improves Amp client compatibility
2026-01-09 04:28:17 +08:00
Luis Pater
d73b61d367 Merge pull request #901 from uzhao/vscode-plugin
Vscode plugin
2026-01-08 22:22:27 +08:00
Luis Pater
d3533f81fc Merge branch 'router-for-me:main' into main 2026-01-08 21:06:24 +08:00
Luis Pater
59a448b645 feat(executor): centralize systemInstruction handling for Claude and Gemini-3-Pro models 2026-01-08 21:05:33 +08:00
Luis Pater
3de7a7f0cd Merge branch 'router-for-me:main' into main 2026-01-08 20:32:08 +08:00
Chén Mù
4adb9eed77 Merge pull request #921 from router-for-me/atgy
fix(executor): update gemini model identifier to gemini-3-pro-preview
2026-01-08 19:20:32 +08:00
hkfires
b6a0f7a07f fix(executor): update gemini model identifier to gemini-3-pro-preview
Update the model name check in `buildRequest` to target "gemini-3-pro-preview" instead of "gemini-3-pro" when applying specific system instruction handling.
2026-01-08 19:14:52 +08:00
Luis Pater
b2566368f8 Merge branch 'router-for-me:main' into main 2026-01-08 12:45:39 +08:00
Luis Pater
1b2f907671 feat(executor): update system instruction handling for Claude and Gemini-3-Pro models 2026-01-08 12:42:26 +08:00
Luis Pater
bda04eed8a feat(executor): add model-specific support for "gemini-3-pro" in execution and payload handling 2026-01-08 12:27:03 +08:00
Luis Pater
e0735977b5 Merge branch 'router-for-me:main' into main 2026-01-08 11:17:28 +08:00
Luis Pater
67985d8226 feat(executor): enhance Antigravity payload with user role and dynamic system instructions 2026-01-08 10:55:25 +08:00
Jianyang Zhao
cbcb061812 Update README_CN.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-07 20:07:01 -05:00
Jianyang Zhao
9fc2e1b3c8 Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-07 20:06:55 -05:00
Jianyang Zhao
3b484aea9e Add Claude Proxy VSCode to README_CN.md
Added information about Claude Proxy VSCode extension.
2026-01-07 20:03:07 -05:00
Jianyang Zhao
963a0950fa Add Claude Proxy VSCode extension to README
Added Claude Proxy VSCode extension to the README.
2026-01-07 20:02:50 -05:00
Luis Pater
1fb4f2b12e Merge branch 'router-for-me:main' into main 2026-01-07 18:18:15 +08:00
Luis Pater
f4ba1ab910 fix(executor): remove unused tokenRefreshTimeout constant and pass zero timeout to HTTP client 2026-01-07 18:16:49 +08:00
Luis Pater
2662f91082 feat(management): add PostOAuthCallback handler to token requester interface 2026-01-07 10:47:32 +08:00
Luis Pater
f5967069f2 docs: remove 9Router from community projects in README 2026-01-07 02:58:49 +08:00
Luis Pater
80f5523685 Merge branch 'router-for-me:main' into main 2026-01-07 01:24:12 +08:00
Luis Pater
c1db2c7d7c Merge pull request #888 from router-for-me/api-call-TOKEN-fix
fix(management): refresh antigravity token for api-call $TOKEN$
2026-01-07 01:19:24 +08:00
LTbinglingfeng
5e5d8142f9 fix(auth): error when antigravity refresh token missing during refresh 2026-01-07 01:09:50 +08:00
LTbinglingfeng
b01619b441 fix(management): refresh antigravity token for api-call $TOKEN$ 2026-01-07 00:14:02 +08:00
Luis Pater
109cf3928a Merge pull request #88 from router-for-me/plus
v6.6.85
2026-01-06 23:20:47 +08:00
Luis Pater
4794645dec Merge branch 'main' into plus 2026-01-06 23:20:39 +08:00
Luis Pater
f861bd6a94 docs: add 9Router to community projects in README 2026-01-06 23:15:28 +08:00
Luis Pater
6dbfdd140d Merge pull request #871 from decolua/patch-1
Update README.md
2026-01-06 22:58:53 +08:00
zhiqing0205
aa8526edc0 fix(codex): use unicode title casing for plan 2026-01-06 10:24:02 +08:00
zhiqing0205
ac3ca0ad8e feat(codex): include plan type in auth filename 2026-01-06 02:25:56 +08:00
decolua
386ccffed4 Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-05 20:54:33 +07:00
FakerL
08d21b76e2 Update sdk/auth/filestore.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-05 21:38:26 +08:00
decolua
ffddd1c90a Update README.md 2026-01-05 20:29:26 +07:00
Zhi Yang
33aa665555 fix(auth): persist access_token on refresh for providers that need it
Previously, metadataEqualIgnoringTimestamps() ignored access_token for all
providers, which prevented refreshed tokens from being persisted to disk/database.
This caused tokens to be lost on server restart for providers like iFlow.

This change makes the behavior provider-specific:
- Providers like gemini/gemini-cli that issue new tokens on every refresh and
  can re-fetch when needed will continue to ignore access_token (optimization)
- Other providers like iFlow will now persist access_token changes

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-05 13:25:46 +00:00
maoring24
00280b6fe8 feat(claude): add native request cloaking for non-claude-code clients
integrate claude-cloak functionality to disguise api requests:
- add CloakConfig with mode (auto/always/never) and strict-mode options
- generate fake user_id in claude code format (user_[hex]_account__session_[uuid])
- inject claude code system prompt (configurable strict mode)
- obfuscate sensitive words with zero-width characters
- auto-detect claude code clients via user-agent

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 20:32:51 +08:00
Zhi Yang
08e8fddf73 feat(kiro): add OAuth model name mappings support for Kiro
Add Kiro to the list of supported channels for OAuth model name mappings,
allowing users to map Kiro model IDs (e.g., kiro-claude-opus-4-5) to
canonical model names (e.g., claude-opus-4-5-20251101).

The Kiro case is implemented as a separate switch block to keep it
isolated from upstream CLIProxyAPI providers, making future merges
from the upstream repository cleaner.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 07:32:08 +00:00
Luis Pater
8f8dfd081b Merge pull request #850 from can1357/main
feat(translator): add developer role support for Gemini translators
2026-01-05 11:27:24 +08:00
Luis Pater
9f1b445c7c docs: add ProxyPilot to community projects in Chinese README 2026-01-05 11:23:48 +08:00
Luis Pater
ae933dfe14 Merge pull request #858 from Finesssee/add-proxypilot
docs: add ProxyPilot to community projects
2026-01-05 11:20:52 +08:00
Luis Pater
5d33d6b8ea Merge branch 'router-for-me:main' into main 2026-01-05 10:42:32 +08:00
Luis Pater
e124db723b Merge pull request #862 from router-for-me/gemini
fix(gemini): abort default injection on existing thinking keys
2026-01-05 10:41:07 +08:00
hkfires
05444cf32d fix(gemini): abort default injection on existing thinking keys 2026-01-05 10:24:30 +08:00
Luis Pater
478aff1189 Merge branch 'router-for-me:main' into main 2026-01-05 09:26:23 +08:00
Luis Pater
8edbda57cf feat(translator): add thoughtSignature to node parts for Gemini and Antigravity requests
Enhanced node structure by including `thoughtSignature` for inline data parts in Gemini OpenAI, Gemini CLI, and Antigravity request handlers to improve traceability of thought processes.
2026-01-05 09:25:17 +08:00
CodeIgnitor
52760a4eaa fix(auth): use backend project ID for free tier Gemini CLI OAuth users
Fixes issue where free tier users cannot access Gemini 3 preview models
due to frontend/backend project ID mapping.

## Problem
Google's Gemini API uses a frontend/backend project mapping system for
free tier users:
- Frontend projects (e.g., gen-lang-client-*) are user-visible
- Backend projects (e.g., mystical-victor-*) host actual API access
- Only backend projects have access to preview models (gemini-3-*)

Previously, CLIProxyAPI ignored the backend project ID returned by
Google's onboarding API and kept using the frontend ID, preventing
access to preview models.

## Solution
### CLI (internal/cmd/login.go)
- Detect free tier users (gen-lang-client-* projects or FREE/LEGACY tier)
- Show interactive prompt allowing users to choose frontend or backend
- Default to backend (recommended for preview model access)
- Pro users: maintain original behavior (keep frontend ID)

### Web UI (internal/api/handlers/management/auth_files.go)
- Detect free tier users using same logic
- Automatically use backend project ID (recommended choice)
- Pro users: maintain original behavior (keep frontend ID)

### Deduplication (internal/cmd/login.go)
- Add deduplication when user selects ALL projects
- Prevents redundant API calls when multiple frontend projects map to
  same backend
- Skips duplicate project IDs in activation loop

## Impact
- Free tier users: Can now access gemini-3-pro-preview and
  gemini-3-flash-preview models
- Pro users: No change in behavior (backward compatible)
- Only affects Gemini CLI OAuth (not antigravity or API key auth)

## Testing
- Tested with free tier account selecting single project
- Tested with free tier account selecting ALL projects
- Verified deduplication prevents redundant onboarding calls
- Confirmed pro user behavior unchanged
2026-01-05 02:41:24 +05:00
Finessse
821249a5ed docs: add ProxyPilot to community projects 2026-01-04 18:19:41 +07:00
Luis Pater
2331b9a2e7 Merge branch 'router-for-me:main' into main 2026-01-04 18:10:49 +08:00
Luis Pater
ee33863b47 Merge pull request #857 from router-for-me/management-update
Management update
2026-01-04 18:07:13 +08:00
Supra4E8C
cd22c849e2 feat(management): 更新OAuth模型映射的清理逻辑以增强数据安全性 2026-01-04 17:57:34 +08:00
Supra4E8C
f0e73efda2 feat(management): add vertex api key and oauth model mappings endpoints 2026-01-04 17:32:00 +08:00
Supra4E8C
3156109c71 feat(management): 支持管理接口调整日志大小/强制前缀/路由策略 2026-01-04 12:21:49 +08:00
can1357
6762e081f3 feat(translator): add developer role support for Gemini translators
Treat OpenAI's "developer" role the same as "system" role in request
translation for gemini, gemini-cli, and antigravity backends.
2026-01-03 21:01:01 +01:00
Luis Pater
5ca3508284 Merge pull request #80 from router-for-me/plus
v6.6.81
2026-01-04 01:38:47 +08:00
Luis Pater
771fec9447 Merge branch 'main' into plus 2026-01-04 01:38:40 +08:00
Luis Pater
7815ee338d fix(translator): adjust message_delta emission boundary in Claude-to-OpenAI conversion
Fixed incorrect boundary logic for `message_delta` emission, ensuring proper handling of usage updates and `emitMessageStopIfNeeded` within the response loop.
2026-01-04 01:36:51 +08:00
Luis Pater
44b6c872e2 feat(config): add support for Fork in OAuth model mappings with alias handling
Implemented `Fork` flag in `ModelNameMapping` to allow aliases as additional models while preserving the original model ID. Updated the `applyOAuthModelMappings` logic, added tests for `Fork` behavior, and updated documentation and examples accordingly.
2026-01-04 01:18:29 +08:00
Luis Pater
7a77b23f2d feat(executor): add token refresh timeout and improve context handling during refresh
Introduced `tokenRefreshTimeout` constant for token refresh operations and enhanced context propagation for `refreshToken` by embedding roundtrip information if available. Adjusted `refreshAuth` to ensure default context initialization and handle cancellation errors appropriately.
2026-01-04 00:26:08 +08:00
Luis Pater
672e8549c0 docs: reorganize README to adjust CodMate placement
Moved CodMate entry under ProxyPal in both English and Chinese README files for consistency in structure and better readability.
2026-01-03 21:31:53 +08:00
Luis Pater
66f5269a23 Merge pull request #837 from loocor/main
docs: add CodMate to community projects
2026-01-03 21:30:15 +08:00
Luis Pater
4eaf769894 Merge branch 'router-for-me:main' into main 2026-01-03 04:55:26 +08:00
Luis Pater
ebec293497 feat(api): integrate TokenStore for improved auth entry management
Replaced file-based auth entry counting with `TokenStore`-backed implementation, enhancing flexibility and context-aware token management. Updated related logic to reflect this change.
2026-01-03 04:53:47 +08:00
Luis Pater
e02ceecd35 feat(registry): introduce ModelRegistryHook for monitoring model registrations and unregistrations
Added support for external hooks to observe model registry events using the `ModelRegistryHook` interface. Implemented thread-safe, non-blocking execution of hooks with panic recovery. Comprehensive tests added to verify hook behavior during registration, unregistration, blocking, and panic scenarios.
2026-01-02 23:18:40 +08:00
Luis Pater
9116392a45 Merge branch 'router-for-me:main' into main 2026-01-02 20:49:51 +08:00
Luis Pater
c8b33a8cc3 Merge pull request #824 from router-for-me/script
feat(script): add usage statistics preservation across container rebuilds
2026-01-02 20:42:25 +08:00
Loocor
dca8d5ded8 Add CodMate app information to README
Added CodMate section to README with app details.
2026-01-02 17:15:38 +08:00
Loocor
2a7fd1e897 Add CodMate description to README_CN.md
添加 CodMate 应用的描述,提供 CLI AI 会话管理功能。
2026-01-02 17:15:09 +08:00
Luis Pater
b9d1e70ac2 Merge pull request #830 from router-for-me/gemini
fix(util): disable default thinking for gemini-3 series
2026-01-02 10:59:24 +08:00
hkfires
fdf5720217 fix(gemini): remove default thinking for gemini 3 models 2026-01-02 10:55:59 +08:00
hkfires
f40bd0cd51 feat(script): add usage statistics preservation across container rebuilds 2026-01-02 10:01:20 +08:00
hkfires
e33676bb87 fix(util): disable default thinking for gemini-3 series 2026-01-02 09:43:40 +08:00
Luis Pater
b1f1cee1e5 feat(executor): refine payload handling by integrating original request context
Updated `applyPayloadConfig` to `applyPayloadConfigWithRoot` across payload translation logic, enabling validation against the original request payload when available. Added support for improved model normalization and translation consistency.
2026-01-02 03:28:37 +08:00
Luis Pater
a1ecc9ab00 Merge branch 'router-for-me:main' into main 2026-01-02 00:04:50 +08:00
Luis Pater
2a663d5cba feat(executor): enhance payload translation with original request context
Refactored `applyPayloadConfig` to `applyPayloadConfigWithRoot`, adding support for default rule validation against the original payload when available. Updated all executors to use `applyPayloadConfigWithRoot` and incorporate an optional original request payload for translations.
2026-01-02 00:03:26 +08:00
Luis Pater
ba486ca6b7 Merge branch 'router-for-me:main' into main 2026-01-01 21:03:16 +08:00
Luis Pater
750b930679 Merge pull request #823 from router-for-me/translator
feat(translator): enhance Claude-to-OpenAI conversion with thinking block and tool result handling
2026-01-01 20:16:10 +08:00
hkfires
3902fd7501 fix(iflow): remove thinking field from request body in thinking config handler 2026-01-01 19:40:28 +08:00
hkfires
4fc3d5e935 refactor(iflow): simplify thinking config handling for GLM and MiniMax models 2026-01-01 19:31:08 +08:00
hkfires
2d2f4572a7 fix(translator): remove unnecessary whitespace trimming in reasoning text collection 2026-01-01 12:39:09 +08:00
hkfires
8f4c46f38d fix(translator): emit tool_result messages before user content in Claude-to-OpenAI conversion 2026-01-01 11:11:43 +08:00
hkfires
b6ba51bc2a feat(translator): add thinking block and tool result handling for Claude-to-OpenAI conversion 2026-01-01 09:41:25 +08:00
Luis Pater
6a66d32d37 Merge pull request #803 from HsnSaboor/fix-invalid-function-names-sanitization-v2
feat(translator): resolve invalid function name errors by sanitizing Claude tool names
2026-01-01 01:15:50 +08:00
Luis Pater
8d15723195 feat(registry): add GetAvailableModelsByProvider method for retrieving models by provider 2025-12-31 23:37:46 +08:00
Chén Mù
736e0aae86 Merge pull request #814 from router-for-me/aistudio
Fix model alias thinking suffix
2025-12-31 03:08:05 -08:00
hkfires
8bf3305b2b fix(thinking): fallback to upstream model for thinking support when alias not in registry 2025-12-31 18:07:13 +08:00
hkfires
d00e3ea973 feat(thinking): add numeric budget to thinkingLevel conversion fallback 2025-12-31 17:14:47 +08:00
hkfires
89db4e9481 fix(thinking): use model alias for thinking config resolution in mapped models 2025-12-31 17:09:22 +08:00
hkfires
e332419081 feat(registry): add thinking support for gemini-2.5-computer-use-preview model 2025-12-31 17:09:22 +08:00
Luis Pater
19232c6388 Merge branch 'router-for-me:main' into main 2025-12-31 11:52:49 +08:00
Luis Pater
e998b1229a feat(updater): add fallback URL and logic for missing management asset 2025-12-31 11:51:20 +08:00
Luis Pater
bbed134bd1 feat(api): add GetAuthStatus method to ManagementTokenRequester interface 2025-12-31 09:40:48 +08:00
Saboor Hassan
47b9503112 chore: revert changes to internal/translator to comply with path guard
This commit reverts all modifications within internal/translator. A separate issue
will be created for the maintenance team to integrate SanitizeFunctionName into
the translators.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 02:19:26 +05:00
Saboor Hassan
3b9253c2be fix(translator): resolve invalid function name errors by sanitizing Claude tool names
This commit centralizes tool name sanitization in SanitizeFunctionName,
applying character compliance, starting character rules, and length limits.
It also fixes a regression in gemini_schema tests and preserves MCP-specific
shortening logic while ensuring compliance.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 02:14:46 +05:00
Saboor Hassan
d241359153 fix(translator): address PR feedback for tool name sanitization
- Pre-compile sanitization regex for better performance.
- Optimize SanitizeFunctionName for conciseness and correctness.
- Handle 64-char edge cases by truncating before prepending underscore.
- Fix bug in Antigravity translator (incorrect join index).
- Refactor Gemini translators to avoid redundant sanitization calls.
- Add comprehensive unit tests including 64-char edge cases.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 01:54:41 +05:00
Saboor Hassan
f4d4249ba5 feat(translator): sanitize tool/function names for upstream provider compatibility
Implemented SanitizeFunctionName utility to ensure Claude tool names meet
Gemini/Upstream strict naming conventions (alphanumeric, starts with letter/underscore, max 64 chars).
Applied sanitization to tool definitions and usage in all relevant translators.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-12-31 01:41:07 +05:00
Luis Pater
06075c0e93 Merge pull request #75 from router-for-me/plus
v6.6.71
2025-12-30 23:34:44 +08:00
Luis Pater
e85c9d9322 Merge branch 'main' into plus 2025-12-30 23:33:35 +08:00
Chén Mù
cb56cb250e Merge pull request #800 from router-for-me/modelmappings
feat(watcher): add model mappings change detection
2025-12-30 06:50:42 -08:00
hkfires
e0381a6ae0 refactor(watcher): extract model summary functions to dedicated file 2025-12-30 22:39:12 +08:00
hkfires
2c01b2ef64 feat(watcher): add Gemini models and OAuth model mappings change detection 2025-12-30 22:39:12 +08:00
Chén Mù
e947266743 Merge pull request #795 from router-for-me/modelmappings
refactor(executor): resolve upstream model at conductor level before execution
2025-12-30 05:31:19 -08:00
Luis Pater
c6b0e85b54 Fixed: #790
fix(gemini): include full text in response output events
2025-12-30 20:44:13 +08:00
hkfires
26efbed05c refactor(executor): remove redundant upstream model parameter from translateRequest 2025-12-30 20:20:42 +08:00
hkfires
96340bf136 refactor(executor): resolve upstream model at conductor level before execution 2025-12-30 19:31:54 +08:00
hkfires
b055e00c1a fix(executor): use upstream model for thinking config and payload translation 2025-12-30 17:49:44 +08:00
sususu
414db44c00 fix(antigravity): parse retry-after delay from 429 response body
When receiving HTTP 429 (Too Many Requests) responses, parse the retry
delay from the response body using parseRetryDelay and populate the
statusErr.retryAfter field. This allows upstream callers to respect
the server's requested retry timing.

Applied to all error paths in Execute, executeClaudeNonStream,
ExecuteStream, CountTokens, and refreshToken functions.
2025-12-30 16:07:32 +08:00
Chén Mù
857c880f99 Merge pull request #785 from router-for-me/gemini
feat(gemini): add per-key model alias support for Gemini provider
2025-12-29 23:32:40 -08:00
hkfires
ce7474d953 feat(cliproxy): propagate thinking support metadata to aliased models 2025-12-30 15:16:54 +08:00
hkfires
70fdd70b84 refactor(cliproxy): extract generic buildConfigModels function for model info generation 2025-12-30 13:35:22 +08:00
hkfires
08ab6a7d77 feat(gemini): add per-key model alias support for Gemini provider 2025-12-30 13:27:57 +08:00
Luis Pater
5c95129884 Merge branch 'router-for-me:main' into main 2025-12-30 11:48:29 +08:00
Luis Pater
9fa2a7e9df Merge pull request #782 from router-for-me/modelmappings
refactor(config): rename model-name-mappings to oauth-model-mappings
2025-12-30 11:40:12 +08:00
hkfires
d443c86620 refactor(config): rename model mapping fields from from/to to name/alias 2025-12-30 11:07:59 +08:00
hkfires
7be3f1c36c refactor(config): rename model-name-mappings to oauth-model-mappings 2025-12-30 11:07:58 +08:00
Luis Pater
f6ab6d97b9 fix(logging): add isDirWritable utility to enhance log dir validation in ConfigureLogOutput 2025-12-30 10:48:25 +08:00
Luis Pater
bc866bac49 fix(logging): refactor ConfigureLogOutput to accept config object and adjust log directory handling 2025-12-30 10:28:25 +08:00
Luis Pater
50e6d845f4 feat(cliproxy): introduce global model name mappings for improved aliasing and routing 2025-12-30 08:13:06 +08:00
Luis Pater
a8cb01819d Merge pull request #772 from soffchen/main
fix: Implement fallback log directory for file logging on read-only system
2025-12-30 02:24:49 +08:00
Luis Pater
1a99cfded4 Merge branch 'router-for-me:main' into main 2025-12-30 00:38:41 +08:00
Luis Pater
530273906b Merge pull request #776 from router-for-me/fix-ag-claude
fix(antigravity): inject required placeholder when properties exist w…
2025-12-30 00:37:01 +08:00
Supra4E8C
06ddf575d9 fix(antigravity): inject required placeholder when properties exist without required 2025-12-29 23:55:59 +08:00
Luis Pater
cf369d4684 Merge branch 'router-for-me:main' into main 2025-12-29 22:41:44 +08:00
hkfires
3099114cbb refactor(api): simplify codex id token claims extraction 2025-12-29 19:48:02 +08:00
Soff
44b63f0767 fix: Return an error if the user home directory cannot be determined for the fallback log path. 2025-12-29 18:46:15 +08:00
Soff Chen
6705d20194 fix: Implement fallback log directory for file logging on read-only systems. 2025-12-29 18:35:48 +08:00
Chén Mù
a38a9c0b0f Merge pull request #770 from router-for-me/api
feat(api): add id token claims extraction for codex auth entries
2025-12-29 00:44:41 -08:00
hkfires
8286caa366 feat(api): add id token claims extraction for codex auth entries 2025-12-29 16:34:16 +08:00
Chén Mù
bd1ec8424d Merge pull request #767 from router-for-me/amp
feat(amp): add per-client upstream API key mapping support
2025-12-28 22:10:11 -08:00
hkfires
225e2c6797 feat(amp): add per-client upstream API key mapping support 2025-12-29 12:26:25 +08:00
Luis Pater
d8fc485513 fix(translators): correct key path for system_instruction.parts in Claude request logic 2025-12-29 11:54:26 +08:00
hkfires
f137eb0ac4 chore: add codex, agents, and opencode dirs to ignore files 2025-12-29 08:42:29 +08:00
Chén Mù
f39a460487 Merge pull request #761 from router-for-me/log
fix(logging): improve request/response capture
2025-12-28 16:13:10 -08:00
Luis Pater
ee171bc563 feat(api): add ManagementTokenRequester interface for management token request endpoints 2025-12-29 02:42:29 +08:00
hkfires
a95428f204 fix(handlers): preserve upstream response logs before duplicate detection 2025-12-28 22:35:36 +08:00
Michael Velbaum
cb3bdffb43 refactor(logging): streamline auth selection debug messages
Reduce duplicate Debugf calls by appending proxy info via an optional suffix and keep the debug-level guard inside the helper.
2025-12-28 16:10:11 +02:00
Michael Velbaum
48f19aab51 refactor(logging): pass request entry into auth selection log
Avoid re-creating the request-scoped log entry in the helper and use a switch for account type dispatch.
2025-12-28 15:51:11 +02:00
Michael Velbaum
48f6d7abdf refactor(logging): dedupe auth selection debug logs
Extract repeated debug logging for selected auth credentials into a helper so execute, count, and stream paths stay consistent.
2025-12-28 15:42:35 +02:00
Michael Velbaum
79fbcb3ec4 fix(logging): quote OAuth account field
Use strconv.Quote when embedding the OAuth account in debug logs so unexpected characters (e.g. quotes) can't break key=value parsing.
2025-12-28 15:32:54 +02:00
Michael Velbaum
0e4148b229 feat(logging): disambiguate OAuth credential selection in debug logs
When multiple OAuth providers share an account email, the existing "Use OAuth" debug lines are ambiguous and hard to correlate with management usage stats. Include provider, auth file, and auth index in the selection log, and only compute these fields when debug logging is enabled to avoid impacting normal request performance.

Before:
[debug] Use OAuth user@example.com for model gemini-3-flash-preview
[debug] Use OAuth user@example.com (project-1234) for model gemini-3-flash-preview

After:
[debug] Use OAuth provider=antigravity auth_file=antigravity-user_example_com.json auth_index=1a2b3c4d5e6f7788 account="user@example.com" for model gemini-3-flash-preview
[debug] Use OAuth provider=gemini-cli auth_file=gemini-user@example.com-project-1234.json auth_index=99aabbccddeeff00 account="user@example.com (project-1234)" for model gemini-3-flash-preview
2025-12-28 15:22:36 +02:00
hkfires
3ca5fb1046 fix(handlers): match raw error text before JSON body for duplicate detection 2025-12-28 19:35:36 +08:00
hkfires
a091d12f4e fix(logging): improve request/response capture 2025-12-28 19:04:31 +08:00
Luis Pater
e3d8d726e6 Merge branch 'router-for-me:main' into main 2025-12-28 15:09:33 +08:00
Luis Pater
457924828a Merge pull request #757 from ben-vargas/fix-thinking-toolchoice-conflict
Fix: disable thinking when tool_choice forces tool use
2025-12-28 14:04:30 +08:00
Ben Vargas
aca2ef6359 Fix: disable thinking when tool_choice forces tool use
Anthropic API does not allow extended thinking when tool_choice is set
to "any" or a specific tool. This was causing 400 errors when using
features like Amp's /handoff command which forces tool_choice.

Added disableThinkingIfToolChoiceForced() that removes thinking config
when incompatible tool_choice is detected, applied to both streaming
and non-streaming paths.

Fixes router-for-me/CLIProxyAPI#630
2025-12-27 16:31:37 -07:00
Luis Pater
ade7194792 feat(management): add generic API call handler to management endpoints 2025-12-28 04:40:32 +08:00
Luis Pater
0f51e73baa Merge branch 'router-for-me:main' into main 2025-12-28 03:07:58 +08:00
Luis Pater
3a436e116a feat(cliproxy): implement model aliasing and hashing for Codex configurations, enhance request routing logic, and normalize Codex model entries 2025-12-28 03:06:51 +08:00
Luis Pater
d06e2dc83c Merge branch 'router-for-me:main' into main 2025-12-28 02:10:16 +08:00
Luis Pater
336867853b Merge pull request #756 from leaph/check-ai-thinking-settings
feat(iflow): add model-specific thinking configs for GLM-4.7 and Mini…
2025-12-28 02:08:27 +08:00
leaph
6403ff4ec4 feat(iflow): add model-specific thinking configs for GLM-4.7 and MiniMax-M2.1
- GLM-4.7: Uses extra_body={"thinking": {"type": "enabled"}, "clear_thinking": false}
- MiniMax-M2.1: Uses reasoning_split=true for OpenAI-style reasoning separation
- Added preserveReasoningContentInMessages() to support re-injection of reasoning
  content in assistant message history for multi-turn conversations
- Added ThinkingSupport to MiniMax-M2.1 model definition
2025-12-27 18:39:15 +01:00
Luis Pater
d222469b44 Update issue templates 2025-12-28 01:22:42 +08:00
Luis Pater
790a17ce98 Merge pull request #70 from router-for-me/plus
v6.6.60
2025-12-28 00:57:14 +08:00
Luis Pater
d473c952fb Merge branch 'main' into plus 2025-12-28 00:56:04 +08:00
Luis Pater
7646a2b877 Fixed: #749
fix(translators): ensure `gjson.String` content is non-empty before setting `parts` in OpenAI request logic
2025-12-28 00:54:26 +08:00
Luis Pater
62090f2568 Merge pull request #750 from router-for-me/config
fix(config): preserve original config structure and avoid default value pollution
2025-12-27 22:10:01 +08:00
Luis Pater
d35152bbef Merge branch 'router-for-me:main' into main 2025-12-27 22:03:50 +08:00
Luis Pater
c281f4cbaf Fixed: #747
fix(translators): rename and integrate `usageMetadata` as `cpaUsageMetadata` in Claude processing logic
2025-12-27 22:02:11 +08:00
hkfires
09455f9e85 fix(config): make streaming keepalive and retries ints 2025-12-27 20:56:47 +08:00
hkfires
c8e72ba0dc fix(config): smart merge writes non-default new keys only 2025-12-27 20:28:54 +08:00
hkfires
375ef252ab docs(config): clarify merge mapping behavior 2025-12-27 19:30:21 +08:00
hkfires
ee552f8720 chore(config): update ignore patterns 2025-12-27 19:13:14 +08:00
hkfires
2e88c4858e fix(config): avoid adding new keys when merging 2025-12-27 19:00:47 +08:00
Luis Pater
3f50da85c1 Merge pull request #745 from router-for-me/auth
fix(auth): make provider rotation atomic
2025-12-27 13:01:22 +08:00
hkfires
8be06255f7 fix(auth): make provider rotation atomic 2025-12-27 12:56:48 +08:00
Luis Pater
60936b5185 Merge branch 'router-for-me:main' into main 2025-12-27 03:57:03 +08:00
Luis Pater
72274099aa Fixed: #738
fix(translators): refine prompt token calculation by incorporating cached tokens in Claude response handling
2025-12-27 03:56:11 +08:00
Luis Pater
b7f7b3a1d8 Merge branch 'router-for-me:main' into main 2025-12-27 01:26:33 +08:00
Luis Pater
dcae098e23 Fixed: #736
fix(translators): handle gjson string types in Claude request processing to ensure consistent content parsing
2025-12-27 01:25:47 +08:00
Luis Pater
618606966f Merge pull request #68 from router-for-me/plus
v6.6.56
2025-12-26 12:14:42 +08:00
Luis Pater
05f249d77f Merge branch 'main' into plus 2025-12-26 12:14:35 +08:00
Luis Pater
2eb05ec640 Merge pull request #727 from nguyenphutrong/main
docs: add Quotio to community projects
2025-12-26 11:53:09 +08:00
Luis Pater
3ce0d76aa4 feat(usage): add import/export functionality for usage statistics and enhance deduplication logic 2025-12-26 11:49:51 +08:00
Trong Nguyen
a00b79d9be docs(readme): add Quotio to community projects section 2025-12-26 10:46:05 +07:00
Luis Pater
9fe6a215e6 Merge branch 'router-for-me:main' into main 2025-12-26 05:10:19 +08:00
Luis Pater
33e53a2a56 fix(translators): ensure correct handling and output of multimodal assistant content across request handlers 2025-12-26 05:08:04 +08:00
Luis Pater
cd5b80785f Merge pull request #722 from hungthai1401/bugfix/remove-extra-args
Fixed incorrect function signature call to `NewBaseAPIHandlers`
2025-12-26 02:56:56 +08:00
Thai Nguyen Hung
54f71aa273 fix(test): remove extra argument from ExecuteStreamWithAuthManager call 2025-12-25 21:55:35 +07:00
Luis Pater
3f949b7f84 Merge pull request #704 from tinyc0der/add-index
fix(openai): add index field to image response for LiteLLM compatibility
2025-12-25 21:35:12 +08:00
Luis Pater
cf8b2dcc85 Merge pull request #67 from router-for-me/plus
v6.6.54
2025-12-25 21:07:45 +08:00
Luis Pater
8e24d9dc34 Merge branch 'main' into plus 2025-12-25 21:07:37 +08:00
Luis Pater
443c4538bb feat(config): add commercial-mode to optimize HTTP middleware for lower memory usage 2025-12-25 21:05:01 +08:00
TinyCoder
a7fc2ee4cf refactor(image): avoid using json.Marshal 2025-12-25 14:21:01 +07:00
Luis Pater
8e749ac22d docs(readme): update GLM model version from 4.6 to 4.7 in README and README_CN 2025-12-24 23:59:48 +08:00
Luis Pater
69e09d9bc7 docs(readme): update GLM model version from 4.6 to 4.7 in README and README_CN 2025-12-24 23:46:27 +08:00
Luis Pater
ed57d82bc1 Merge branch 'router-for-me:main' into main 2025-12-24 23:31:09 +08:00
Luis Pater
06ad527e8c Fixed: #696
fix(translators): adjust prompt token calculation by subtracting cached tokens across Gemini, OpenAI, and Claude handlers
2025-12-24 23:29:18 +08:00
Luis Pater
7af5a90a0b Merge pull request #65 from router-for-me/plus
v6.6.52
2025-12-24 22:28:45 +08:00
Luis Pater
7551faff79 Merge branch 'main' into plus 2025-12-24 22:27:12 +08:00
Luis Pater
b7409dd2de Merge pull request #706 from router-for-me/log
Log
2025-12-24 22:24:39 +08:00
hkfires
5ba325a8fc refactor(logging): standardize request id formatting and layout 2025-12-24 22:03:07 +08:00
Luis Pater
d502840f91 Merge pull request #695 from NguyenSiTrung/main
feat: add cached token parsing for Gemini , Antigravity API responses
2025-12-24 21:58:55 +08:00
hkfires
99238a4b59 fix(logging): normalize warning level to warn 2025-12-24 21:11:37 +08:00
hkfires
6d43a2ff9a refactor(logging): inline request id in log output 2025-12-24 21:07:18 +08:00
Luis Pater
cdb9c2e6e8 Merge remote-tracking branch 'origin/main' into router-for-me/main 2025-12-24 20:18:53 +08:00
Luis Pater
3faa1ca9af Merge pull request #700 from router-for-me/log
refactor(sdk/auth): rename manager.go to conductor.go
2025-12-24 19:36:24 +08:00
Luis Pater
9d975e0375 feat(models): add support for GLM-4.7 and MiniMax-M2.1 2025-12-24 19:30:57 +08:00
hkfires
2a6d8b78d4 feat(api): add endpoint to retrieve request logs by ID 2025-12-24 19:24:51 +08:00
TinyCoder
671558a822 fix(openai): add index field to image response for LiteLLM compatibility
LiteLLM's Pydantic model requires an index field in each image object.
Without it, responses fail validation with "images.0.index Field required".
2025-12-24 17:43:31 +07:00
Luis Pater
6b80ec79a0 Merge pull request #61 from tinyc0der/fix/kiro-tool-results
fix(kiro): Handle tool results correctly in OpenAI format translation
2025-12-24 17:23:41 +08:00
Luis Pater
d3f4783a24 Merge pull request #57 from PancakeZik/my-idc-changes
feat: add AWS Identity Center (IDC) authentication support
2025-12-24 17:20:01 +08:00
Luis Pater
1cb6bdbc87 Merge pull request #62 from router-for-me/plus
v6.6.50
2025-12-24 17:14:05 +08:00
Luis Pater
96ddfc1f24 Merge branch 'main' into plus 2025-12-24 17:13:51 +08:00
TinyCoder
c169b32570 refactor(kiro): Remove unused variables in OpenAI translator
Remove dead code that was never used:
- toolCallIDToName map: built but never read from
- seenToolCallIDs: declared but never populated, only suppressed with _
2025-12-24 15:10:35 +07:00
TinyCoder
36a512fdf2 fix(kiro): Handle tool results correctly in OpenAI format translation
Fix three issues in Kiro OpenAI translator that caused "Improperly formed request"
errors when processing LiteLLM-translated requests with tool_use/tool_result:

1. Skip merging tool role messages in MergeAdjacentMessages() to preserve
   individual tool_call_id fields

2. Track pendingToolResults and attach to the next user message instead of
   only the last message. Create synthetic user message when conversation
   ends with tool results.

3. Insert synthetic user message with tool results before assistant messages
   to maintain proper alternating user/assistant structure. This fixes the case
   where LiteLLM translates Anthropic user messages containing only tool_result
   blocks into tool role messages followed by assistant.

Adds unit tests covering all tool result handling scenarios.
2025-12-24 15:10:35 +07:00
hkfires
26fbb77901 refactor(sdk/auth): rename manager.go to conductor.go 2025-12-24 15:21:03 +08:00
NguyenSiTrung
a277302262 Merge remote-tracking branch 'upstream/main' 2025-12-24 10:54:09 +07:00
NguyenSiTrung
969c1a5b72 refactor: extract parseGeminiFamilyUsageDetail helper to reduce duplication 2025-12-24 10:22:31 +07:00
NguyenSiTrung
872339bceb feat: add cached token parsing for Gemini API responses 2025-12-24 10:20:11 +07:00
Luis Pater
5dc0dbc7aa Merge pull request #697 from Cubence-com/main
docs(readme): add Cubence sponsor and fix PackyCode link
2025-12-24 11:19:32 +08:00
Luis Pater
ee6fc4e8a1 Merge pull request #60 from router-for-me/pr-59-resolve-conflicts
v6.6.50(解决 #59 冲突)
2025-12-24 11:12:09 +08:00
Luis Pater
8fee16aecd Merge PR #59: v6.6.50 (resolve conflicts) 2025-12-24 11:06:10 +08:00
Luis Pater
2b7ba54a2f Merge pull request #688 from router-for-me/feature/request-id-tracking
feat(logging): implement request ID tracking and propagation
2025-12-24 10:54:13 +08:00
hkfires
007c3304f2 feat(logging): scope request ID tracking to AI API endpoints 2025-12-24 09:17:09 +08:00
hkfires
e76ba0ede9 feat(logging): implement request ID tracking and propagation 2025-12-24 08:32:17 +08:00
Luis Pater
c06ac07e23 Merge pull request #686 from ajkdrag/main
feat: regex support for model-mappings
2025-12-24 04:37:44 +08:00
Luis Pater
e592a57458 Merge branch 'router-for-me:main' into main 2025-12-24 04:25:06 +08:00
Luis Pater
66769ec657 fix(translators): update role from tool to user in Gemini and Gemini-CLI requests 2025-12-24 04:24:07 +08:00
Luis Pater
f413feec61 refactor(handlers): streamline error and data channel handling in streaming logic
Improved consistency across OpenAI, Claude, and Gemini handlers by replacing initial `select` statement with a `for` loop for better readability and error-handling robustness.
2025-12-24 04:07:24 +08:00
Luis Pater
2e538e3486 Merge pull request #661 from jroth1111/fix/streaming-bootstrap-forwarder
fix: improve streaming bootstrap and forwarding
2025-12-24 03:51:40 +08:00
Luis Pater
9617a7b0d6 Merge pull request #621 from dacsang97/fix/antigravity-prompt-caching
Antigravity Prompt Caching Fix
2025-12-24 03:50:25 +08:00
Luis Pater
7569320770 Merge branch 'dev' into fix/antigravity-prompt-caching 2025-12-24 03:49:46 +08:00
Fetters
8d25cf0d75 fix(readme): update PackyCode sponsorship link and remove redundant tbody 2025-12-23 23:44:40 +08:00
Fetters
64e85e7019 docs(readme): add Cubence sponsor 2025-12-23 23:30:57 +08:00
Luis Pater
a862984dca Merge pull request #58 from router-for-me/plus
v6.6.48
2025-12-23 22:34:15 +08:00
Luis Pater
f0365f0465 Merge branch 'main' into plus 2025-12-23 22:34:08 +08:00
Luis Pater
6d1e20e940 fix(claude_executor): update header logic for API key handling
Refined header assignment to use `x-api-key` for Anthropic API requests, ensuring correct authorization behavior based on request attributes and URL validation.
2025-12-23 22:30:25 +08:00
altamash
0c0aae1eac Robust change detection: replaced string concat with struct-based compare in hasModelMappingsChanged; removed boolTo01.
•  Performance: pre-allocate map and regex slice capacities in UpdateMappings.
   •  Verified with amp module tests (all passing)
2025-12-23 18:52:28 +05:30
altamash
5dcf7cb846 feat: regex support for model-mappings 2025-12-23 18:41:58 +05:30
Joao
349b2ba3af refactor: improve error handling and code quality
- Handle errors in promptInput instead of ignoring them
- Improve promptSelect to provide feedback on invalid input and re-prompt
- Use sentinel errors (ErrAuthorizationPending, ErrSlowDown) instead of
  string-based error checking with strings.Contains
- Move hardcoded x-amz-user-agent header to idcAmzUserAgent constant

Addresses code review feedback from Gemini Code Assist.
2025-12-23 10:20:14 +00:00
Joao
98db5aabd0 feat: persist refreshed IDC tokens to auth file
Add persistRefreshedAuth function to write refreshed tokens back to the
auth file after inline token refresh. This prevents repeated token
refreshes on every request when the token expires.

Changes:
- Add persistRefreshedAuth() to kiro_executor.go
- Call persist after all token refresh paths (401, 403, pre-request)
- Remove unused log import from sdk/auth/kiro.go
2025-12-23 10:00:14 +00:00
Luis Pater
e52b542e22 Merge pull request #684 from packyme/main
docs(readme): add PackyCode sponsor
2025-12-23 17:19:25 +08:00
SmallL-U
8f6abb8a86 fix(readme): correct closing tbody tag 2025-12-23 17:17:57 +08:00
SmallL-U
ed8eaae964 docs(readme): add PackyCode sponsor 2025-12-23 17:11:34 +08:00
Joao
7fd98f3556 feat: add IDC auth support with Kiro IDE headers 2025-12-23 08:18:10 +00:00
Luis Pater
e8de87ee90 Merge branch 'router-for-me:main' into main 2025-12-23 08:48:29 +08:00
Luis Pater
4e572ec8b9 fix(translators): handle string system instructions in Claude translators
Updated Antigravity, Gemini, and Gemini-CLI translators to process `systemResult` of type `string` for system instructions. Ensures properly formatted JSON with dynamic content assignment.
2025-12-23 08:44:36 +08:00
Luis Pater
6c7f18c448 Merge branch 'router-for-me:main' into main 2025-12-23 03:50:35 +08:00
Luis Pater
24bc9cba67 Fixed: #639
fix(antigravity): validate function arguments before serialization

Ensure `function.arguments` is a valid JSON before setting raw bytes, fallback to setting as parameterized content if invalid.
2025-12-23 03:49:45 +08:00
Luis Pater
97356b1a04 Merge branch 'router-for-me:main' into main 2025-12-23 03:17:44 +08:00
Luis Pater
1084b53fba Fixed: #655
refactor(antigravity): clean up tool key filtering and improve signature caching logic
2025-12-23 03:16:51 +08:00
Luis Pater
b1aecc2bf1 Merge branch 'router-for-me:main' into main 2025-12-23 02:49:37 +08:00
Luis Pater
83b90e106f refactor(antigravity): add sandbox URL constant and update base URLs routine 2025-12-23 02:47:56 +08:00
Luis Pater
f52114dab2 Merge branch 'router-for-me:main' into main 2025-12-23 02:26:33 +08:00
Luis Pater
5106caf641 Fixed: #654
feat: handle array input for system instructions in translators

Enhanced Gemini, Gemini-CLI, and Antigravity translators to process array content for system instructions. Adds support for assigning roles and handling multiple content parts dynamically.
2025-12-23 02:24:26 +08:00
Luis Pater
12370ee84e Merge branch 'router-for-me:main' into main 2025-12-22 22:53:29 +08:00
Luis Pater
b84ccc6e7a feat: add unit tests for routing strategies and implement dynamic selector updates
Added comprehensive tests for `FillFirstSelector` and `RoundRobinSelector` to ensure proper behavior, including deterministic, cyclical, and concurrent scenarios. Introduced dynamic routing strategy updates in `service.go`, normalizing strategies and seamlessly switching between `fill-first` and `round-robin`. Updated `Manager` to support selector changes via the new `SetSelector` method.
2025-12-22 22:52:23 +08:00
Luis Pater
e19ddb53e7 Merge pull request #663 from jroth1111/feat/fill-first-selector
feat: add fill-first routing strategy
2025-12-22 22:26:32 +08:00
gwizz
5bf89dd757 fix: keep streaming defaults legacy-safe 2025-12-23 00:53:18 +11:00
gwizz
2a0100b2d6 docs: add routing strategy example 2025-12-23 00:39:18 +11:00
gwizz
4442574e53 fix: stop streaming loop on context cancel 2025-12-23 00:37:55 +11:00
gwizz
c020fa60d0 fix: keep round-robin as default routing 2025-12-22 23:39:41 +11:00
gwizz
b078be4613 feat: add fill-first routing strategy 2025-12-22 23:38:10 +11:00
gwizz
71a6dffbb6 fix: improve streaming bootstrap and forwarding 2025-12-22 23:34:23 +11:00
Evan Nguyen
24e8e20b59 Merge branch 'main' into fix/antigravity-prompt-caching 2025-12-21 19:43:24 +07:00
Evan Nguyen
a87f09bad2 feat(antigravity): add session ID generation and mutex for random source 2025-12-21 17:50:41 +07:00
evann
bc6c4cdbfc feat(antigravity): add logging for cached token setting errors in responses 2025-12-19 16:49:50 +07:00
evann
404546ce93 refactor(antigravity): regarding production endpoint caching 2025-12-19 16:36:54 +07:00
evann
6dd1cf1dd6 Merge branch 'main' into fix/antigravity-prompt-caching 2025-12-19 16:34:28 +07:00
evann
9058d406a3 feat(antigravity): enhance prompt caching support and update agent version 2025-12-19 16:33:41 +07:00
이대희
31bd90c748 feature: Improves Amp client compatibility
Ensures compatibility with the Amp client by suppressing
"thinking" blocks when "tool_use" blocks are also present in
the response.

The Amp client has issues rendering both types of blocks
simultaneously. This change filters out "thinking" blocks in
such cases, preventing rendering problems.
2025-12-19 08:18:27 +09:00
Muzhen Gaming
0b834fcb54 fix(translator): preserve built-in tools across openai<->responses
- Pass through non-function tool definitions like web_search

- Translate tool_choice for built-in tools and function tools

- Add regression tests for built-in tool passthrough
2025-12-15 21:18:54 +08:00
253 changed files with 33810 additions and 7774 deletions

View File

@@ -13,8 +13,6 @@ Dockerfile
docs/*
README.md
README_CN.md
MANAGEMENT_API.md
MANAGEMENT_API_CN.md
LICENSE
# Runtime data folders (should be mounted as volumes)
@@ -25,10 +23,14 @@ config.yaml
# Development/editor
bin/*
.claude/*
.vscode/*
.claude/*
.codex/*
.gemini/*
.serena/*
.agent/*
.agents/*
.opencode/*
.bmad/*
_bmad/*
_bmad-output/*

View File

@@ -7,6 +7,13 @@ assignees: ''
---
**Is it a request payload issue?**
[ ] Yes, this is a request payload issue. I am using a client/cURL to send a request payload, but I received an unexpected error.
[ ] No, it's another issue.
**If it's a request payload issue, you MUST know**
Our team doesn't have any GODs or ORACLEs or MIND READERs. Please make sure to attach the request log or curl payload.
**Describe the bug**
A clear and concise description of what the bug is.

View File

@@ -10,13 +10,11 @@ env:
DOCKERHUB_REPO: eceasy/cli-proxy-api-plus
jobs:
docker:
docker_amd64:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
@@ -29,18 +27,113 @@ jobs:
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- name: Build and push
- name: Build and push (amd64)
uses: docker/build-push-action@v6
with:
context: .
platforms: |
linux/amd64
linux/arm64
platforms: linux/amd64
push: true
build-args: |
VERSION=${{ env.VERSION }}
COMMIT=${{ env.COMMIT }}
BUILD_DATE=${{ env.BUILD_DATE }}
tags: |
${{ env.DOCKERHUB_REPO }}:latest
${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}
${{ env.DOCKERHUB_REPO }}:latest-amd64
${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}-amd64
docker_arm64:
runs-on: ubuntu-24.04-arm
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Generate Build Metadata
run: |
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- name: Build and push (arm64)
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/arm64
push: true
build-args: |
VERSION=${{ env.VERSION }}
COMMIT=${{ env.COMMIT }}
BUILD_DATE=${{ env.BUILD_DATE }}
tags: |
${{ env.DOCKERHUB_REPO }}:latest-arm64
${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}-arm64
docker_manifest:
runs-on: ubuntu-latest
needs:
- docker_amd64
- docker_arm64
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Generate Build Metadata
run: |
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- name: Create and push multi-arch manifests
run: |
docker buildx imagetools create \
--tag "${DOCKERHUB_REPO}:latest" \
"${DOCKERHUB_REPO}:latest-amd64" \
"${DOCKERHUB_REPO}:latest-arm64"
docker buildx imagetools create \
--tag "${DOCKERHUB_REPO}:${VERSION}" \
"${DOCKERHUB_REPO}:${VERSION}-amd64" \
"${DOCKERHUB_REPO}:${VERSION}-arm64"
- name: Cleanup temporary tags
continue-on-error: true
env:
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
run: |
set -euo pipefail
namespace="${DOCKERHUB_REPO%%/*}"
repo_name="${DOCKERHUB_REPO#*/}"
token="$(
curl -fsSL \
-H 'Content-Type: application/json' \
-d "{\"username\":\"${DOCKERHUB_USERNAME}\",\"password\":\"${DOCKERHUB_TOKEN}\"}" \
'https://hub.docker.com/v2/users/login/' \
| python3 -c 'import json,sys; print(json.load(sys.stdin)["token"])'
)"
delete_tag() {
local tag="$1"
local url="https://hub.docker.com/v2/repositories/${namespace}/${repo_name}/tags/${tag}/"
local http_code
http_code="$(curl -sS -o /dev/null -w "%{http_code}" -X DELETE -H "Authorization: JWT ${token}" "${url}" || true)"
if [ "${http_code}" = "204" ] || [ "${http_code}" = "404" ]; then
echo "Docker Hub tag removed (or missing): ${DOCKERHUB_REPO}:${tag} (HTTP ${http_code})"
return 0
fi
echo "Docker Hub tag delete failed: ${DOCKERHUB_REPO}:${tag} (HTTP ${http_code})"
return 0
}
delete_tag "latest-amd64"
delete_tag "latest-arm64"
delete_tag "${VERSION}-amd64"
delete_tag "${VERSION}-arm64"

12
.gitignore vendored
View File

@@ -12,11 +12,15 @@ bin/*
logs/*
conv/*
temp/*
refs/*
# Storage backends
pgstore/*
gitstore/*
objectstore/*
# Static assets
static/*
refs/*
# Authentication data
auths/*
@@ -30,14 +34,20 @@ GEMINI.md
# Tooling metadata
.vscode/*
.codex/*
.claude/*
.gemini/*
.serena/*
.agent/*
.agents/*
.agents/*
.opencode/*
.bmad/*
_bmad/*
_bmad-output/*
.mcp/cache/
# macOS
.DS_Store
._*
*.bak

View File

@@ -13,11 +13,87 @@ The Plus release stays in lockstep with the mainline features.
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
## New Features (Plus Enhanced)
- **OAuth Web Authentication**: Browser-based OAuth login for Kiro with beautiful web UI
- **Rate Limiter**: Built-in request rate limiting to prevent API abuse
- **Background Token Refresh**: Automatic token refresh 10 minutes before expiration
- **Metrics & Monitoring**: Request metrics collection for monitoring and debugging
- **Device Fingerprint**: Device fingerprint generation for enhanced security
- **Cooldown Management**: Smart cooldown mechanism for API rate limits
- **Usage Checker**: Real-time usage monitoring and quota management
- **Model Converter**: Unified model name conversion across providers
- **UTF-8 Stream Processing**: Improved streaming response handling
## Kiro Authentication
### Web-based OAuth Login
Access the Kiro OAuth web interface at:
```
http://your-server:8080/v0/oauth/kiro
```
This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
- AWS Builder ID login
- AWS Identity Center (IDC) login
- Token import from Kiro IDE
## Quick Deployment with Docker
### One-Command Deployment
```bash
# Create deployment directory
mkdir -p ~/cli-proxy && cd ~/cli-proxy
# Create docker-compose.yml
cat > docker-compose.yml << 'EOF'
services:
cli-proxy-api:
image: eceasy/cli-proxy-api-plus:latest
container_name: cli-proxy-api-plus
ports:
- "8317:8317"
volumes:
- ./config.yaml:/CLIProxyAPI/config.yaml
- ./auths:/root/.cli-proxy-api
- ./logs:/CLIProxyAPI/logs
restart: unless-stopped
EOF
# Download example config
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
# Pull and start
docker compose pull && docker compose up -d
```
### Configuration
Edit `config.yaml` before starting:
```yaml
# Basic configuration example
server:
port: 8317
# Add your provider configurations here
```
### Update to Latest Version
```bash
cd ~/cli-proxy
docker compose pull && docker compose up -d
```
## Contributing
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
If you need to submit any non-third-party provider changes, please open them against the mainline repository.
If you need to submit any non-third-party provider changes, please open them against the [mainline](https://github.com/router-for-me/CLIProxyAPI) repository.
## License

View File

@@ -13,11 +13,87 @@
- 新增 GitHub Copilot 支持OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
## 新增功能 (Plus 增强版)
- **OAuth Web 认证**: 基于浏览器的 Kiro OAuth 登录,提供美观的 Web UI
- **请求限流器**: 内置请求限流,防止 API 滥用
- **后台令牌刷新**: 过期前 10 分钟自动刷新令牌
- **监控指标**: 请求指标收集,用于监控和调试
- **设备指纹**: 设备指纹生成,增强安全性
- **冷却管理**: 智能冷却机制,应对 API 速率限制
- **用量检查器**: 实时用量监控和配额管理
- **模型转换器**: 跨供应商的统一模型名称转换
- **UTF-8 流处理**: 改进的流式响应处理
## Kiro 认证
### 网页端 OAuth 登录
访问 Kiro OAuth 网页认证界面:
```
http://your-server:8080/v0/oauth/kiro
```
提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
- AWS Builder ID 登录
- AWS Identity Center (IDC) 登录
- 从 Kiro IDE 导入令牌
## Docker 快速部署
### 一键部署
```bash
# 创建部署目录
mkdir -p ~/cli-proxy && cd ~/cli-proxy
# 创建 docker-compose.yml
cat > docker-compose.yml << 'EOF'
services:
cli-proxy-api:
image: eceasy/cli-proxy-api-plus:latest
container_name: cli-proxy-api-plus
ports:
- "8317:8317"
volumes:
- ./config.yaml:/CLIProxyAPI/config.yaml
- ./auths:/root/.cli-proxy-api
- ./logs:/CLIProxyAPI/logs
restart: unless-stopped
EOF
# 下载示例配置
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
# 拉取并启动
docker compose pull && docker compose up -d
```
### 配置说明
启动前请编辑 `config.yaml`
```yaml
# 基本配置示例
server:
port: 8317
# 在此添加你的供应商配置
```
### 更新到最新版本
```bash
cd ~/cli-proxy
docker compose pull && docker compose up -d
```
## 贡献
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
如果需要提交任何非第三方供应商支持的 Pull Request请提交到主线版本。
如果需要提交任何非第三方供应商支持的 Pull Request请提交到[主线](https://github.com/router-for-me/CLIProxyAPI)版本。
## 许可证

BIN
assets/cubence.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

BIN
assets/packycode.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.1 KiB

View File

@@ -17,6 +17,7 @@ import (
"github.com/joho/godotenv"
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -74,6 +75,7 @@ func main() {
var iflowLogin bool
var iflowCookie bool
var noBrowser bool
var oauthCallbackPort int
var antigravityLogin bool
var kiroLogin bool
var kiroGoogleLogin bool
@@ -96,6 +98,7 @@ func main() {
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
@@ -434,7 +437,7 @@ func main() {
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
if err = logging.ConfigureLogOutput(cfg); err != nil {
log.Errorf("failed to configure log output: %v", err)
return
}
@@ -454,7 +457,8 @@ func main() {
// Create login options to be used in authentication flows.
options := &cmd.LoginOptions{
NoBrowser: noBrowser,
NoBrowser: noBrowser,
CallbackPort: oauthCallbackPort,
}
// Register the shared token store once so all components use the same persistence backend.
@@ -530,6 +534,13 @@ func main() {
}
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
// 初始化并启动 Kiro token 后台刷新
if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg)
defer kiro.StopGlobalRefreshManager()
}
cmd.StartService(cfg, configFilePath, password)
}
}

View File

@@ -35,10 +35,14 @@ auth-dir: "~/.cli-proxy-api"
api-keys:
- "your-api-key-1"
- "your-api-key-2"
- "your-api-key-3"
# Enable debug logging
debug: false
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
commercial-mode: false
# Open OAuth URLs in incognito/private browser mode.
# Useful when you want to login with a different account without logging out from your current session.
# Default: false (but Kiro auth defaults to true for multi-account support)
@@ -71,9 +75,25 @@ quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
# Routing strategy for selecting credentials when multiple match.
routing:
strategy: "round-robin" # round-robin (default), fill-first
# When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
nonstream-keepalive-interval: 0
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
# streaming:
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
# When true, enable official Codex instructions injection for Codex API requests.
# When false (default), CodexInstructionsForModel returns immediately without modification.
codex-instructions-enabled: false
# Gemini API keys
# gemini-api-key:
# - api-key: "AIzaSy...01"
@@ -82,6 +102,9 @@ ws-auth: false
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080"
# models:
# - name: "gemini-2.5-flash" # upstream model name
# alias: "gemini-flash" # client alias mapped to the upstream model
# excluded-models:
# - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
@@ -97,6 +120,9 @@ ws-auth: false
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# models:
# - name: "gpt-5-codex" # upstream model name
# alias: "codex-latest" # client alias mapped to the upstream model
# excluded-models:
# - "gpt-5.1" # exclude specific models (exact match)
# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
@@ -114,12 +140,21 @@ ws-auth: false
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# models:
# - name: "claude-3-5-sonnet-20241022" # upstream model name
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
# excluded-models:
# - "claude-opus-4-5-20251101" # exclude specific models (exact match)
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
# - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
# cloak: # optional: request cloaking for non-Claude-Code clients
# mode: "auto" # "auto" (default): cloak only when client is not Claude Code
# # "always": always apply cloaking
# # "never": never apply cloaking
# strict-mode: false # false (default): prepend Claude Code prompt to user system messages
# # true: strip all user system messages, keep only Claude Code prompt
# sensitive-words: # optional: words to obfuscate with zero-width characters
# - "API"
# - "proxy"
# Kiro (AWS CodeWhisperer) configuration
# Note: Kiro API currently only operates in us-east-1 region
@@ -155,9 +190,9 @@ ws-auth: false
# headers:
# X-Custom-Header: "custom-value"
# models: # optional: map aliases to upstream model names
# - name: "gemini-2.0-flash" # upstream model name
# - name: "gemini-2.5-flash" # upstream model name
# alias: "vertex-flash" # client-visible alias
# - name: "gemini-1.5-pro"
# - name: "gemini-2.5-pro"
# alias: "vertex-pro"
# Amp Integration
@@ -166,6 +201,18 @@ ws-auth: false
# upstream-url: "https://ampcode.com"
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
# upstream-api-key: ""
# # Per-client upstream API key mapping
# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
# # Useful when different clients need to use different Amp accounts/quotas.
# # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
# upstream-api-keys:
# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
# api-keys: # Client keys that use this upstream key
# - "your-api-key-1"
# - "your-api-key-2"
# - upstream-api-key: "amp_key_for_team_b"
# api-keys:
# - "your-api-key-3"
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
# restrict-management-to-localhost: false
# # Force model mappings to run before checking local API keys (default: false)
@@ -175,14 +222,68 @@ ws-auth: false
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
# # but you have a similar model available (e.g., Claude Sonnet 4).
# model-mappings:
# - from: "claude-opus-4.5" # Model requested by Amp CLI
# to: "claude-sonnet-4" # Route to this available model instead
# - from: "gpt-5"
# to: "gemini-2.5-pro"
# - from: "claude-3-opus-20240229"
# to: "claude-3-5-sonnet-20241022"
# - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
# to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
# - from: "claude-sonnet-4-5-20250929"
# to: "gemini-claude-sonnet-4-5-thinking"
# - from: "claude-haiku-4-5-20251001"
# to: "gemini-2.5-flash"
# Global OAuth model name aliases (per channel)
# These aliases rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# You can repeat the same name with different aliases to expose multiple client model names.
#oauth-model-alias:
# antigravity:
# - name: "rev19-uic3-1p"
# alias: "gemini-2.5-computer-use-preview-10-2025"
# - name: "gemini-3-pro-image"
# alias: "gemini-3-pro-image-preview"
# - name: "gemini-3-pro-high"
# alias: "gemini-3-pro-preview"
# - name: "gemini-3-flash"
# alias: "gemini-3-flash-preview"
# - name: "claude-sonnet-4-5"
# alias: "gemini-claude-sonnet-4-5"
# - name: "claude-sonnet-4-5-thinking"
# alias: "gemini-claude-sonnet-4-5-thinking"
# - name: "claude-opus-4-5-thinking"
# alias: "gemini-claude-opus-4-5-thinking"
# gemini-cli:
# - name: "gemini-2.5-pro" # original model name under this channel
# alias: "g2.5p" # client-visible alias
# fork: true # when true, keep original and also add the alias as an extra model (default: false)
# vertex:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# aistudio:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# antigravity:
# - name: "gemini-3-pro-preview"
# alias: "g3p"
# claude:
# - name: "claude-sonnet-4-5-20250929"
# alias: "cs4.5"
# codex:
# - name: "gpt-5"
# alias: "g5"
# qwen:
# - name: "qwen3-coder-plus"
# alias: "qwen-plus"
# iflow:
# - name: "glm-4.7"
# alias: "glm-god"
# kiro:
# - name: "kiro-claude-opus-4-5"
# alias: "op45"
# github-copilot:
# - name: "gpt-5"
# alias: "copilot-gpt5"
# OAuth provider excluded models
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
# oauth-excluded-models:
# gemini-cli:
# - "gemini-2.5-pro" # exclude specific models (exact match)
@@ -203,18 +304,41 @@ ws-auth: false
# - "vision-model"
# iflow:
# - "tstars2.0"
# kiro:
# - "kiro-claude-haiku-4-5"
# github-copilot:
# - "raptor-mini"
# Optional payload configuration
# payload:
# default: # Default rules only set parameters when they are missing in the payload.
# - models:
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
# params: # JSON path (gjson/sjson syntax) -> value
# "generationConfig.thinkingConfig.thinkingBudget": 32768
# default-raw: # Default raw rules set parameters using raw JSON when missing (must be valid JSON).
# - models:
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
# "generationConfig.responseJsonSchema": "{\"type\":\"object\",\"properties\":{\"answer\":{\"type\":\"string\"}}}"
# override: # Override rules always set parameters, overwriting any existing values.
# - models:
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
# params: # JSON path (gjson/sjson syntax) -> value
# "reasoning.effort": "high"
# override-raw: # Override raw rules always set parameters using raw JSON (must be valid JSON).
# - models:
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
# "response_format": "{\"type\":\"json_schema\",\"json_schema\":{\"name\":\"answer\",\"schema\":{\"type\":\"object\"}}}"
# filter: # Filter rules remove specified parameters from the payload.
# - models:
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
# params: # JSON paths (gjson/sjson syntax) to remove from the payload
# - "generationConfig.thinkingConfig.thinkingBudget"
# - "generationConfig.responseJsonSchema"

View File

@@ -5,9 +5,115 @@
# This script automates the process of building and running the Docker container
# with version information dynamically injected at build time.
# Exit immediately if a command exits with a non-zero status.
# Hidden feature: Preserve usage statistics across rebuilds
# Usage: ./docker-build.sh --with-usage
# First run prompts for management API key, saved to temp/stats/.api_secret
set -euo pipefail
STATS_DIR="temp/stats"
STATS_FILE="${STATS_DIR}/.usage_backup.json"
SECRET_FILE="${STATS_DIR}/.api_secret"
WITH_USAGE=false
get_port() {
if [[ -f "config.yaml" ]]; then
grep -E "^port:" config.yaml | sed -E 's/^port: *["'"'"']?([0-9]+)["'"'"']?.*$/\1/'
else
echo "8317"
fi
}
export_stats_api_secret() {
if [[ -f "${SECRET_FILE}" ]]; then
API_SECRET=$(cat "${SECRET_FILE}")
else
if [[ ! -d "${STATS_DIR}" ]]; then
mkdir -p "${STATS_DIR}"
fi
echo "First time using --with-usage. Management API key required."
read -r -p "Enter management key: " -s API_SECRET
echo
echo "${API_SECRET}" > "${SECRET_FILE}"
chmod 600 "${SECRET_FILE}"
fi
}
check_container_running() {
local port
port=$(get_port)
if ! curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
echo "Error: cli-proxy-api service is not responding at localhost:${port}"
echo "Please start the container first or use without --with-usage flag."
exit 1
fi
}
export_stats() {
local port
port=$(get_port)
if [[ ! -d "${STATS_DIR}" ]]; then
mkdir -p "${STATS_DIR}"
fi
check_container_running
echo "Exporting usage statistics..."
EXPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -H "X-Management-Key: ${API_SECRET}" \
"http://localhost:${port}/v0/management/usage/export")
HTTP_CODE=$(echo "${EXPORT_RESPONSE}" | tail -n1)
RESPONSE_BODY=$(echo "${EXPORT_RESPONSE}" | sed '$d')
if [[ "${HTTP_CODE}" != "200" ]]; then
echo "Export failed (HTTP ${HTTP_CODE}): ${RESPONSE_BODY}"
exit 1
fi
echo "${RESPONSE_BODY}" > "${STATS_FILE}"
echo "Statistics exported to ${STATS_FILE}"
}
import_stats() {
local port
port=$(get_port)
echo "Importing usage statistics..."
IMPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST \
-H "X-Management-Key: ${API_SECRET}" \
-H "Content-Type: application/json" \
-d @"${STATS_FILE}" \
"http://localhost:${port}/v0/management/usage/import")
IMPORT_CODE=$(echo "${IMPORT_RESPONSE}" | tail -n1)
IMPORT_BODY=$(echo "${IMPORT_RESPONSE}" | sed '$d')
if [[ "${IMPORT_CODE}" == "200" ]]; then
echo "Statistics imported successfully"
else
echo "Import failed (HTTP ${IMPORT_CODE}): ${IMPORT_BODY}"
fi
rm -f "${STATS_FILE}"
}
wait_for_service() {
local port
port=$(get_port)
echo "Waiting for service to be ready..."
for i in {1..30}; do
if curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
break
fi
sleep 1
done
sleep 2
}
if [[ "${1:-}" == "--with-usage" ]]; then
WITH_USAGE=true
export_stats_api_secret
fi
# --- Step 1: Choose Environment ---
echo "Please select an option:"
echo "1) Run using Pre-built Image (Recommended)"
@@ -18,7 +124,14 @@ read -r -p "Enter choice [1-2]: " choice
case "$choice" in
1)
echo "--- Running with Pre-built Image ---"
if [[ "${WITH_USAGE}" == "true" ]]; then
export_stats
fi
docker compose up -d --remove-orphans --no-build
if [[ "${WITH_USAGE}" == "true" ]]; then
wait_for_service
import_stats
fi
echo "Services are starting from remote image."
echo "Run 'docker compose logs -f' to see the logs."
;;
@@ -38,16 +151,25 @@ case "$choice" in
# Build and start the services with a local-only image tag
export CLI_PROXY_IMAGE="cli-proxy-api:local"
echo "Building the Docker image..."
docker compose build \
--build-arg VERSION="${VERSION}" \
--build-arg COMMIT="${COMMIT}" \
--build-arg BUILD_DATE="${BUILD_DATE}"
if [[ "${WITH_USAGE}" == "true" ]]; then
export_stats
fi
echo "Starting the services..."
docker compose up -d --remove-orphans --pull never
if [[ "${WITH_USAGE}" == "true" ]]; then
wait_for_service
import_stats
fi
echo "Build complete. Services are starting."
echo "Run 'docker compose logs -f' to see the logs."
;;
@@ -55,4 +177,4 @@ case "$choice" in
echo "Invalid choice. Please enter 1 or 2."
exit 1
;;
esac
esac

View File

@@ -22,7 +22,7 @@ services:
- "51121:51121"
- "11451:11451"
volumes:
- ./config.yaml:/CLIProxyAPI/config.yaml
- ./auths:/root/.cli-proxy-api
- ./logs:/CLIProxyAPI/logs
- ${CLI_PROXY_CONFIG_PATH:-./config.yaml}:/CLIProxyAPI/config.yaml
- ${CLI_PROXY_AUTH_PATH:-./auths}:/root/.cli-proxy-api
- ${CLI_PROXY_LOG_PATH:-./logs}:/CLIProxyAPI/logs
restart: unless-stopped

View File

@@ -14,6 +14,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
@@ -122,7 +123,9 @@ func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Re
httpReq.Header.Set("Content-Type", "application/json")
// Inject credentials via PrepareRequest hook.
_ = (MyExecutor{}).PrepareRequest(httpReq, a)
if errPrep := (MyExecutor{}).PrepareRequest(httpReq, a); errPrep != nil {
return clipexec.Response{}, errPrep
}
resp, errDo := client.Do(httpReq)
if errDo != nil {
@@ -130,13 +133,28 @@ func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Re
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
// Best-effort close; log if needed in real projects.
fmt.Fprintf(os.Stderr, "close response body error: %v\n", errClose)
}
}()
body, _ := io.ReadAll(resp.Body)
return clipexec.Response{Payload: body}, nil
}
func (MyExecutor) HttpRequest(ctx context.Context, a *coreauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("myprov executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if errPrep := (MyExecutor{}).PrepareRequest(httpReq, a); errPrep != nil {
return nil, errPrep
}
client := buildHTTPClient(a)
return client.Do(httpReq)
}
func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
return clipexec.Response{}, errors.New("count tokens not implemented")
}
@@ -199,8 +217,8 @@ func main() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := svc.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
panic(err)
if errRun := svc.Run(ctx); errRun != nil && !errors.Is(errRun, context.Canceled) {
panic(errRun)
}
_ = os.Stderr // keep os import used (demo only)
_ = time.Second

View File

@@ -0,0 +1,140 @@
// Package main demonstrates how to use coreauth.Manager.HttpRequest/NewHttpRequest
// to execute arbitrary HTTP requests with provider credentials injected.
//
// This example registers a minimal custom executor that injects an Authorization
// header from auth.Attributes["api_key"], then performs two requests against
// httpbin.org to show the injected headers.
package main
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
)
const providerKey = "echo"
// EchoExecutor is a minimal provider implementation for demonstration purposes.
type EchoExecutor struct{}
func (EchoExecutor) Identifier() string { return providerKey }
func (EchoExecutor) PrepareRequest(req *http.Request, auth *coreauth.Auth) error {
if req == nil || auth == nil {
return nil
}
if auth.Attributes != nil {
if apiKey := strings.TrimSpace(auth.Attributes["api_key"]); apiKey != "" {
req.Header.Set("Authorization", "Bearer "+apiKey)
}
}
return nil
}
func (EchoExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("echo executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if errPrep := (EchoExecutor{}).PrepareRequest(httpReq, auth); errPrep != nil {
return nil, errPrep
}
return http.DefaultClient.Do(httpReq)
}
func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
}
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
return nil, errors.New("echo executor: ExecuteStream not implemented")
}
func (EchoExecutor) Refresh(context.Context, *coreauth.Auth) (*coreauth.Auth, error) {
return nil, errors.New("echo executor: Refresh not implemented")
}
func (EchoExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
return clipexec.Response{}, errors.New("echo executor: CountTokens not implemented")
}
func main() {
log.SetLevel(log.InfoLevel)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
core := coreauth.NewManager(nil, nil, nil)
core.RegisterExecutor(EchoExecutor{})
auth := &coreauth.Auth{
ID: "demo-echo",
Provider: providerKey,
Attributes: map[string]string{
"api_key": "demo-api-key",
},
}
// Example 1: Build a prepared request and execute it using your own http.Client.
reqPrepared, errReqPrepared := core.NewHttpRequest(
ctx,
auth,
http.MethodGet,
"https://httpbin.org/anything",
nil,
http.Header{"X-Example": []string{"prepared"}},
)
if errReqPrepared != nil {
panic(errReqPrepared)
}
respPrepared, errDoPrepared := http.DefaultClient.Do(reqPrepared)
if errDoPrepared != nil {
panic(errDoPrepared)
}
defer func() {
if errClose := respPrepared.Body.Close(); errClose != nil {
log.Errorf("close response body error: %v", errClose)
}
}()
bodyPrepared, errReadPrepared := io.ReadAll(respPrepared.Body)
if errReadPrepared != nil {
panic(errReadPrepared)
}
fmt.Printf("Prepared request status: %d\n%s\n\n", respPrepared.StatusCode, bodyPrepared)
// Example 2: Execute a raw request via core.HttpRequest (auto inject + do).
rawBody := []byte(`{"hello":"world"}`)
rawReq, errRawReq := http.NewRequestWithContext(ctx, http.MethodPost, "https://httpbin.org/anything", bytes.NewReader(rawBody))
if errRawReq != nil {
panic(errRawReq)
}
rawReq.Header.Set("Content-Type", "application/json")
rawReq.Header.Set("X-Example", "executed")
respExec, errDoExec := core.HttpRequest(ctx, auth, rawReq)
if errDoExec != nil {
panic(errDoExec)
}
defer func() {
if errClose := respExec.Body.Close(); errClose != nil {
log.Errorf("close response body error: %v", errClose)
}
}()
bodyExec, errReadExec := io.ReadAll(respExec.Body)
if errReadExec != nil {
panic(errReadExec)
}
fmt.Printf("Manager HttpRequest status: %d\n%s\n", respExec.StatusCode, bodyExec)
}

5
go.mod
View File

@@ -13,6 +13,7 @@ require (
github.com/joho/godotenv v1.5.1
github.com/klauspost/compress v1.17.4
github.com/minio/minio-go/v7 v7.0.66
github.com/refraction-networking/utls v1.8.2
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/sirupsen/logrus v1.9.3
github.com/tidwall/gjson v1.18.0
@@ -21,6 +22,7 @@ require (
golang.org/x/crypto v0.45.0
golang.org/x/net v0.47.0
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.18.0
golang.org/x/term v0.37.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
@@ -39,6 +41,7 @@ require (
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/fxamacker/cbor/v2 v2.9.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-git/gcfg/v2 v2.0.2 // indirect
@@ -68,8 +71,8 @@ require (
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect

6
go.sum
View File

@@ -35,6 +35,8 @@ github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
@@ -120,6 +122,8 @@ github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmd
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
@@ -157,6 +161,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=

View File

@@ -0,0 +1,737 @@
package management
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const defaultAPICallTimeout = 60 * time.Second
const (
geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOAuthClientSecret = "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"
)
var geminiOAuthScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
const (
antigravityOAuthClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityOAuthClientSecret = "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"
)
var antigravityOAuthTokenURL = "https://oauth2.googleapis.com/token"
type apiCallRequest struct {
AuthIndexSnake *string `json:"auth_index"`
AuthIndexCamel *string `json:"authIndex"`
AuthIndexPascal *string `json:"AuthIndex"`
Method string `json:"method"`
URL string `json:"url"`
Header map[string]string `json:"header"`
Data string `json:"data"`
}
type apiCallResponse struct {
StatusCode int `json:"status_code"`
Header map[string][]string `json:"header"`
Body string `json:"body"`
}
// APICall makes a generic HTTP request on behalf of the management API caller.
// It is protected by the management middleware.
//
// Endpoint:
//
// POST /v0/management/api-call
//
// Authentication:
//
// Same as other management APIs (requires a management key and remote-management rules).
// You can provide the key via:
// - Authorization: Bearer <key>
// - X-Management-Key: <key>
//
// Request JSON (supports both application/json and application/cbor):
// - auth_index / authIndex / AuthIndex (optional):
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
// If omitted or not found, credential-specific proxy/token substitution is skipped.
// - method (required): HTTP method, e.g. GET, POST, PUT, PATCH, DELETE.
// - url (required): Absolute URL including scheme and host, e.g. "https://api.example.com/v1/ping".
// - header (optional): Request headers map.
// Supports magic variable "$TOKEN$" which is replaced using the selected credential:
// 1) metadata.access_token
// 2) attributes.api_key
// 3) metadata.token / metadata.id_token / metadata.cookie
// Example: {"Authorization":"Bearer $TOKEN$"}.
// Note: if you need to override the HTTP Host header, set header["Host"].
// - data (optional): Raw request body as string (useful for POST/PUT/PATCH).
//
// Proxy selection (highest priority first):
// 1. Selected credential proxy_url
// 2. Global config proxy-url
// 3. Direct connect (environment proxies are not used)
//
// Response (returned with HTTP 200 when the APICall itself succeeds):
//
// Format matches request Content-Type (application/json or application/cbor)
// - status_code: Upstream HTTP status code.
// - header: Upstream response headers.
// - body: Upstream response body as string.
//
// Example:
//
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
// -H "Authorization: Bearer <MANAGEMENT_KEY>" \
// -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"GET","url":"https://api.example.com/v1/ping","header":{"Authorization":"Bearer $TOKEN$"}}'
//
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
// -H "Authorization: Bearer 831227" \
// -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
func (h *Handler) APICall(c *gin.Context) {
// Detect content type
contentType := strings.ToLower(strings.TrimSpace(c.GetHeader("Content-Type")))
isCBOR := strings.Contains(contentType, "application/cbor")
var body apiCallRequest
// Parse request body based on content type
if isCBOR {
rawBody, errRead := io.ReadAll(c.Request.Body)
if errRead != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
return
}
if errUnmarshal := cbor.Unmarshal(rawBody, &body); errUnmarshal != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid cbor body"})
return
}
} else {
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
}
method := strings.ToUpper(strings.TrimSpace(body.Method))
if method == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing method"})
return
}
urlStr := strings.TrimSpace(body.URL)
if urlStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"})
return
}
parsedURL, errParseURL := url.Parse(urlStr)
if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
return
}
authIndex := firstNonEmptyString(body.AuthIndexSnake, body.AuthIndexCamel, body.AuthIndexPascal)
auth := h.authByIndex(authIndex)
reqHeaders := body.Header
if reqHeaders == nil {
reqHeaders = map[string]string{}
}
var hostOverride string
var token string
var tokenResolved bool
var tokenErr error
for key, value := range reqHeaders {
if !strings.Contains(value, "$TOKEN$") {
continue
}
if !tokenResolved {
token, tokenErr = h.resolveTokenForAuth(c.Request.Context(), auth)
tokenResolved = true
}
if auth != nil && token == "" {
if tokenErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token refresh failed"})
return
}
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token not found"})
return
}
if token == "" {
continue
}
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
}
var requestBody io.Reader
if body.Data != "" {
requestBody = strings.NewReader(body.Data)
}
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
if errNewRequest != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"})
return
}
for key, value := range reqHeaders {
if strings.EqualFold(key, "host") {
hostOverride = strings.TrimSpace(value)
continue
}
req.Header.Set(key, value)
}
if hostOverride != "" {
req.Host = hostOverride
}
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
}
httpClient.Transport = h.apiCallTransport(auth)
resp, errDo := httpClient.Do(req)
if errDo != nil {
log.WithError(errDo).Debug("management APICall request failed")
c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
}()
respBody, errReadAll := io.ReadAll(resp.Body)
if errReadAll != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
return
}
response := apiCallResponse{
StatusCode: resp.StatusCode,
Header: resp.Header,
Body: string(respBody),
}
// Return response in the same format as the request
if isCBOR {
cborData, errMarshal := cbor.Marshal(response)
if errMarshal != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to encode cbor response"})
return
}
c.Data(http.StatusOK, "application/cbor", cborData)
} else {
c.JSON(http.StatusOK, response)
}
}
func firstNonEmptyString(values ...*string) string {
for _, v := range values {
if v == nil {
continue
}
if out := strings.TrimSpace(*v); out != "" {
return out
}
}
return ""
}
func tokenValueForAuth(auth *coreauth.Auth) string {
if auth == nil {
return ""
}
if v := tokenValueFromMetadata(auth.Metadata); v != "" {
return v
}
if auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
return v
}
}
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
if v := tokenValueFromMetadata(shared.MetadataSnapshot()); v != "" {
return v
}
}
return ""
}
func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) (string, error) {
if auth == nil {
return "", nil
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if provider == "gemini-cli" {
token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
return token, errToken
}
if provider == "antigravity" {
token, errToken := h.refreshAntigravityOAuthAccessToken(ctx, auth)
return token, errToken
}
return tokenValueForAuth(auth), nil
}
func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
if ctx == nil {
ctx = context.Background()
}
if auth == nil {
return "", nil
}
metadata, updater := geminiOAuthMetadata(auth)
if len(metadata) == 0 {
return "", fmt.Errorf("gemini oauth metadata missing")
}
base := make(map[string]any)
if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
base = cloneMap(tokenRaw)
}
var token oauth2.Token
if len(base) > 0 {
if raw, errMarshal := json.Marshal(base); errMarshal == nil {
_ = json.Unmarshal(raw, &token)
}
}
if token.AccessToken == "" {
token.AccessToken = stringValue(metadata, "access_token")
}
if token.RefreshToken == "" {
token.RefreshToken = stringValue(metadata, "refresh_token")
}
if token.TokenType == "" {
token.TokenType = stringValue(metadata, "token_type")
}
if token.Expiry.IsZero() {
if expiry := stringValue(metadata, "expiry"); expiry != "" {
if ts, errParseTime := time.Parse(time.RFC3339, expiry); errParseTime == nil {
token.Expiry = ts
}
}
}
conf := &oauth2.Config{
ClientID: geminiOAuthClientID,
ClientSecret: geminiOAuthClientSecret,
Scopes: geminiOAuthScopes,
Endpoint: google.Endpoint,
}
ctxToken := ctx
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
Transport: h.apiCallTransport(auth),
}
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
src := conf.TokenSource(ctxToken, &token)
currentToken, errToken := src.Token()
if errToken != nil {
return "", errToken
}
merged := buildOAuthTokenMap(base, currentToken)
fields := buildOAuthTokenFields(currentToken, merged)
if updater != nil {
updater(fields)
}
return strings.TrimSpace(currentToken.AccessToken), nil
}
func (h *Handler) refreshAntigravityOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
if ctx == nil {
ctx = context.Background()
}
if auth == nil {
return "", nil
}
metadata := auth.Metadata
if len(metadata) == 0 {
return "", fmt.Errorf("antigravity oauth metadata missing")
}
current := strings.TrimSpace(tokenValueFromMetadata(metadata))
if current != "" && !antigravityTokenNeedsRefresh(metadata) {
return current, nil
}
refreshToken := stringValue(metadata, "refresh_token")
if refreshToken == "" {
return "", fmt.Errorf("antigravity refresh token missing")
}
tokenURL := strings.TrimSpace(antigravityOAuthTokenURL)
if tokenURL == "" {
tokenURL = "https://oauth2.googleapis.com/token"
}
form := url.Values{}
form.Set("client_id", antigravityOAuthClientID)
form.Set("client_secret", antigravityOAuthClientSecret)
form.Set("grant_type", "refresh_token")
form.Set("refresh_token", refreshToken)
req, errReq := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
if errReq != nil {
return "", errReq
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
Transport: h.apiCallTransport(auth),
}
resp, errDo := httpClient.Do(req)
if errDo != nil {
return "", errDo
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
}()
bodyBytes, errRead := io.ReadAll(resp.Body)
if errRead != nil {
return "", errRead
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return "", fmt.Errorf("antigravity oauth token refresh failed: status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
}
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
}
if errUnmarshal := json.Unmarshal(bodyBytes, &tokenResp); errUnmarshal != nil {
return "", errUnmarshal
}
if strings.TrimSpace(tokenResp.AccessToken) == "" {
return "", fmt.Errorf("antigravity oauth token refresh returned empty access_token")
}
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
now := time.Now()
auth.Metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
if strings.TrimSpace(tokenResp.RefreshToken) != "" {
auth.Metadata["refresh_token"] = strings.TrimSpace(tokenResp.RefreshToken)
}
if tokenResp.ExpiresIn > 0 {
auth.Metadata["expires_in"] = tokenResp.ExpiresIn
auth.Metadata["timestamp"] = now.UnixMilli()
auth.Metadata["expired"] = now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
}
auth.Metadata["type"] = "antigravity"
if h != nil && h.authManager != nil {
auth.LastRefreshedAt = now
auth.UpdatedAt = now
_, _ = h.authManager.Update(ctx, auth)
}
return strings.TrimSpace(tokenResp.AccessToken), nil
}
func antigravityTokenNeedsRefresh(metadata map[string]any) bool {
// Refresh a bit early to avoid requests racing token expiry.
const skew = 30 * time.Second
if metadata == nil {
return true
}
if expStr, ok := metadata["expired"].(string); ok {
if ts, errParse := time.Parse(time.RFC3339, strings.TrimSpace(expStr)); errParse == nil {
return !ts.After(time.Now().Add(skew))
}
}
expiresIn := int64Value(metadata["expires_in"])
timestampMs := int64Value(metadata["timestamp"])
if expiresIn > 0 && timestampMs > 0 {
exp := time.UnixMilli(timestampMs).Add(time.Duration(expiresIn) * time.Second)
return !exp.After(time.Now().Add(skew))
}
return true
}
func int64Value(raw any) int64 {
switch typed := raw.(type) {
case int:
return int64(typed)
case int32:
return int64(typed)
case int64:
return typed
case uint:
return int64(typed)
case uint32:
return int64(typed)
case uint64:
if typed > uint64(^uint64(0)>>1) {
return 0
}
return int64(typed)
case float32:
return int64(typed)
case float64:
return int64(typed)
case json.Number:
if i, errParse := typed.Int64(); errParse == nil {
return i
}
case string:
if s := strings.TrimSpace(typed); s != "" {
if i, errParse := json.Number(s).Int64(); errParse == nil {
return i
}
}
}
return 0
}
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
if auth == nil {
return nil, nil
}
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
snapshot := shared.MetadataSnapshot()
return snapshot, func(fields map[string]any) { shared.MergeMetadata(fields) }
}
return auth.Metadata, func(fields map[string]any) {
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
for k, v := range fields {
auth.Metadata[k] = v
}
}
}
func stringValue(metadata map[string]any, key string) string {
if len(metadata) == 0 || key == "" {
return ""
}
if v, ok := metadata[key].(string); ok {
return strings.TrimSpace(v)
}
return ""
}
func cloneMap(in map[string]any) map[string]any {
if len(in) == 0 {
return nil
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func buildOAuthTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
merged := cloneMap(base)
if merged == nil {
merged = make(map[string]any)
}
if tok == nil {
return merged
}
if raw, errMarshal := json.Marshal(tok); errMarshal == nil {
var tokenMap map[string]any
if errUnmarshal := json.Unmarshal(raw, &tokenMap); errUnmarshal == nil {
for k, v := range tokenMap {
merged[k] = v
}
}
}
return merged
}
func buildOAuthTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
fields := make(map[string]any, 5)
if tok != nil && tok.AccessToken != "" {
fields["access_token"] = tok.AccessToken
}
if tok != nil && tok.TokenType != "" {
fields["token_type"] = tok.TokenType
}
if tok != nil && tok.RefreshToken != "" {
fields["refresh_token"] = tok.RefreshToken
}
if tok != nil && !tok.Expiry.IsZero() {
fields["expiry"] = tok.Expiry.Format(time.RFC3339)
}
if len(merged) > 0 {
fields["token"] = cloneMap(merged)
}
return fields
}
func tokenValueFromMetadata(metadata map[string]any) string {
if len(metadata) == 0 {
return ""
}
if v, ok := metadata["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if tokenRaw, ok := metadata["token"]; ok && tokenRaw != nil {
switch typed := tokenRaw.(type) {
case string:
if v := strings.TrimSpace(typed); v != "" {
return v
}
case map[string]any:
if v, ok := typed["access_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := typed["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
case map[string]string:
if v := strings.TrimSpace(typed["access_token"]); v != "" {
return v
}
if v := strings.TrimSpace(typed["accessToken"]); v != "" {
return v
}
}
}
if v, ok := metadata["token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["id_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["cookie"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
return ""
}
func (h *Handler) authByIndex(authIndex string) *coreauth.Auth {
authIndex = strings.TrimSpace(authIndex)
if authIndex == "" || h == nil || h.authManager == nil {
return nil
}
auths := h.authManager.List()
for _, auth := range auths {
if auth == nil {
continue
}
auth.EnsureIndex()
if auth.Index == authIndex {
return auth
}
}
return nil
}
func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
var proxyCandidates []string
if auth != nil {
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
if h != nil && h.cfg != nil {
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
for _, proxyStr := range proxyCandidates {
if transport := buildProxyTransport(proxyStr); transport != nil {
return transport
}
}
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok || transport == nil {
return &http.Transport{Proxy: nil}
}
clone := transport.Clone()
clone.Proxy = nil
return clone
}
func buildProxyTransport(proxyStr string) *http.Transport {
proxyStr = strings.TrimSpace(proxyStr)
if proxyStr == "" {
return nil
}
proxyURL, errParse := url.Parse(proxyStr)
if errParse != nil {
log.WithError(errParse).Debug("parse proxy URL failed")
return nil
}
if proxyURL.Scheme == "" || proxyURL.Host == "" {
log.Debug("proxy URL missing scheme/host")
return nil
}
if proxyURL.Scheme == "socks5" {
var proxyAuth *proxy.Auth
if proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
return nil
}
return &http.Transport{
Proxy: nil,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
}
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil
}

View File

@@ -0,0 +1,149 @@
package management
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin"
)
func TestAPICall_CBOR_Support(t *testing.T) {
gin.SetMode(gin.TestMode)
// Create a test handler
h := &Handler{}
// Create test request data
reqData := apiCallRequest{
Method: "GET",
URL: "https://httpbin.org/get",
Header: map[string]string{
"User-Agent": "test-client",
},
}
t.Run("JSON request and response", func(t *testing.T) {
// Marshal request as JSON
jsonData, err := json.Marshal(reqData)
if err != nil {
t.Fatalf("Failed to marshal JSON: %v", err)
}
// Create HTTP request
req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(jsonData))
req.Header.Set("Content-Type", "application/json")
// Create response recorder
w := httptest.NewRecorder()
// Create Gin context
c, _ := gin.CreateTestContext(w)
c.Request = req
// Call handler
h.APICall(c)
// Verify response
if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
t.Logf("Response status: %d", w.Code)
t.Logf("Response body: %s", w.Body.String())
}
// Check content type
contentType := w.Header().Get("Content-Type")
if w.Code == http.StatusOK && !contains(contentType, "application/json") {
t.Errorf("Expected JSON response, got: %s", contentType)
}
})
t.Run("CBOR request and response", func(t *testing.T) {
// Marshal request as CBOR
cborData, err := cbor.Marshal(reqData)
if err != nil {
t.Fatalf("Failed to marshal CBOR: %v", err)
}
// Create HTTP request
req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(cborData))
req.Header.Set("Content-Type", "application/cbor")
// Create response recorder
w := httptest.NewRecorder()
// Create Gin context
c, _ := gin.CreateTestContext(w)
c.Request = req
// Call handler
h.APICall(c)
// Verify response
if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
t.Logf("Response status: %d", w.Code)
t.Logf("Response body: %s", w.Body.String())
}
// Check content type
contentType := w.Header().Get("Content-Type")
if w.Code == http.StatusOK && !contains(contentType, "application/cbor") {
t.Errorf("Expected CBOR response, got: %s", contentType)
}
// Try to decode CBOR response
if w.Code == http.StatusOK {
var response apiCallResponse
if err := cbor.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Errorf("Failed to unmarshal CBOR response: %v", err)
} else {
t.Logf("CBOR response decoded successfully: status_code=%d", response.StatusCode)
}
}
})
t.Run("CBOR encoding and decoding consistency", func(t *testing.T) {
// Test data
testReq := apiCallRequest{
Method: "POST",
URL: "https://example.com/api",
Header: map[string]string{
"Authorization": "Bearer $TOKEN$",
"Content-Type": "application/json",
},
Data: `{"key":"value"}`,
}
// Encode to CBOR
cborData, err := cbor.Marshal(testReq)
if err != nil {
t.Fatalf("Failed to marshal to CBOR: %v", err)
}
// Decode from CBOR
var decoded apiCallRequest
if err := cbor.Unmarshal(cborData, &decoded); err != nil {
t.Fatalf("Failed to unmarshal from CBOR: %v", err)
}
// Verify fields
if decoded.Method != testReq.Method {
t.Errorf("Method mismatch: got %s, want %s", decoded.Method, testReq.Method)
}
if decoded.URL != testReq.URL {
t.Errorf("URL mismatch: got %s, want %s", decoded.URL, testReq.URL)
}
if decoded.Data != testReq.Data {
t.Errorf("Data mismatch: got %s, want %s", decoded.Data, testReq.Data)
}
if len(decoded.Header) != len(testReq.Header) {
t.Errorf("Header count mismatch: got %d, want %d", len(decoded.Header), len(testReq.Header))
}
})
}
func contains(s, substr string) bool {
return len(s) > 0 && len(substr) > 0 && (s == substr || len(s) >= len(substr) && s[:len(substr)] == substr || bytes.Contains([]byte(s), []byte(substr)))
}

View File

@@ -0,0 +1,173 @@
package management
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
type memoryAuthStore struct {
mu sync.Mutex
items map[string]*coreauth.Auth
}
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
_ = ctx
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*coreauth.Auth, 0, len(s.items))
for _, a := range s.items {
out = append(out, a.Clone())
}
return out, nil
}
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
_ = ctx
if auth == nil {
return "", nil
}
s.mu.Lock()
if s.items == nil {
s.items = make(map[string]*coreauth.Auth)
}
s.items[auth.ID] = auth.Clone()
s.mu.Unlock()
return auth.ID, nil
}
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
_ = ctx
s.mu.Lock()
delete(s.items, id)
s.mu.Unlock()
return nil
}
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
t.Fatalf("unexpected content-type: %s", ct)
}
bodyBytes, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
values, err := url.ParseQuery(string(bodyBytes))
if err != nil {
t.Fatalf("parse form: %v", err)
}
if values.Get("grant_type") != "refresh_token" {
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
}
if values.Get("refresh_token") != "rt" {
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
}
if values.Get("client_id") != antigravityOAuthClientID {
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
}
if values.Get("client_secret") != antigravityOAuthClientSecret {
t.Fatalf("unexpected client_secret")
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "new-token",
"refresh_token": "rt2",
"expires_in": int64(3600),
"token_type": "Bearer",
})
}))
t.Cleanup(srv.Close)
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
store := &memoryAuthStore{}
manager := coreauth.NewManager(store, nil, nil)
auth := &coreauth.Auth{
ID: "antigravity-test.json",
FileName: "antigravity-test.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "old-token",
"refresh_token": "rt",
"expires_in": int64(3600),
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
},
}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("register auth: %v", err)
}
h := &Handler{authManager: manager}
token, err := h.resolveTokenForAuth(context.Background(), auth)
if err != nil {
t.Fatalf("resolveTokenForAuth: %v", err)
}
if token != "new-token" {
t.Fatalf("expected refreshed token, got %q", token)
}
if callCount != 1 {
t.Fatalf("expected 1 refresh call, got %d", callCount)
}
updated, ok := manager.GetByID(auth.ID)
if !ok || updated == nil {
t.Fatalf("expected auth in manager after update")
}
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
t.Fatalf("expected manager metadata updated, got %q", got)
}
}
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusInternalServerError)
}))
t.Cleanup(srv.Close)
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
auth := &coreauth.Auth{
ID: "antigravity-valid.json",
FileName: "antigravity-valid.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "ok-token",
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
h := &Handler{}
token, err := h.resolveTokenForAuth(context.Background(), auth)
if err != nil {
t.Fatalf("resolveTokenForAuth: %v", err)
}
if token != "ok-token" {
t.Fatalf("expected existing token, got %q", token)
}
if callCount != 0 {
t.Fatalf("expected no refresh calls, got %d", callCount)
}
}

View File

@@ -6,6 +6,7 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -22,8 +23,10 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
@@ -234,14 +237,6 @@ func stopForwarderInstance(port int, forwarder *callbackForwarder) {
log.Infof("callback forwarder on port %d stopped", port)
}
func sanitizeAntigravityFileName(email string) string {
if strings.TrimSpace(email) == "" {
return "antigravity.json"
}
replacer := strings.NewReplacer("@", "_", ".", "_")
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
}
func (h *Handler) managementCallbackURL(path string) (string, error) {
if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
return "", fmt.Errorf("server port is not configured")
@@ -431,9 +426,52 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
log.WithError(err).Warnf("failed to stat auth file %s", path)
}
}
if claims := extractCodexIDTokenClaims(auth); claims != nil {
entry["id_token"] = claims
}
return entry
}
func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
if auth == nil || auth.Metadata == nil {
return nil
}
if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
return nil
}
idTokenRaw, ok := auth.Metadata["id_token"].(string)
if !ok {
return nil
}
idToken := strings.TrimSpace(idTokenRaw)
if idToken == "" {
return nil
}
claims, err := codex.ParseJWTToken(idToken)
if err != nil || claims == nil {
return nil
}
result := gin.H{}
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID); v != "" {
result["chatgpt_account_id"] = v
}
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" {
result["plan_type"] = v
}
if v := claims.CodexAuthInfo.ChatgptSubscriptionActiveStart; v != nil {
result["chatgpt_subscription_active_start"] = v
}
if v := claims.CodexAuthInfo.ChatgptSubscriptionActiveUntil; v != nil {
result["chatgpt_subscription_active_until"] = v
}
if len(result) == 0 {
return nil
}
return result
}
func authEmail(auth *coreauth.Auth) string {
if auth == nil {
return ""
@@ -708,6 +746,72 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
return err
}
// PatchAuthFileStatus toggles the disabled state of an auth file
func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
return
}
var req struct {
Name string `json:"name"`
Disabled *bool `json:"disabled"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
}
name := strings.TrimSpace(req.Name)
if name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
if req.Disabled == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "disabled is required"})
return
}
ctx := c.Request.Context()
// Find auth by name or ID
var targetAuth *coreauth.Auth
if auth, ok := h.authManager.GetByID(name); ok {
targetAuth = auth
} else {
auths := h.authManager.List()
for _, auth := range auths {
if auth.FileName == name {
targetAuth = auth
break
}
}
}
if targetAuth == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
return
}
// Update disabled state
targetAuth.Disabled = *req.Disabled
if *req.Disabled {
targetAuth.Status = coreauth.StatusDisabled
targetAuth.StatusMessage = "disabled via management API"
} else {
targetAuth.Status = coreauth.StatusActive
targetAuth.StatusMessage = ""
}
targetAuth.UpdatedAt = time.Now()
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
}
func (h *Handler) disableAuth(ctx context.Context, id string) {
if h == nil || h.authManager == nil {
return
@@ -874,67 +978,14 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
rawCode := resultMap["code"]
code := strings.Split(rawCode, "#")[0]
// Exchange code for tokens (replicate logic using updated redirect_uri)
// Extract client_id from the modified auth URL
clientID := ""
if u2, errP := url.Parse(authURL); errP == nil {
clientID = u2.Query().Get("client_id")
}
// Build request
bodyMap := map[string]any{
"code": code,
"state": state,
"grant_type": "authorization_code",
"client_id": clientID,
"redirect_uri": "http://localhost:54545/callback",
"code_verifier": pkceCodes.CodeVerifier,
}
bodyJSON, _ := json.Marshal(bodyMap)
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
req, _ := http.NewRequestWithContext(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", strings.NewReader(string(bodyJSON)))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, errDo := httpClient.Do(req)
if errDo != nil {
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
// Exchange code for tokens using internal auth service
bundle, errExchange := anthropicAuth.ExchangeCodeForTokens(ctx, code, state, pkceCodes)
if errExchange != nil {
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errExchange)
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("failed to close response body: %v", errClose)
}
}()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
return
}
var tResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
Account struct {
EmailAddress string `json:"email_address"`
} `json:"account"`
}
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
log.Errorf("failed to parse token response: %v", errU)
SetOAuthSessionError(state, "Failed to parse token response")
return
}
bundle := &claude.ClaudeAuthBundle{
TokenData: claude.ClaudeTokenData{
AccessToken: tResp.AccessToken,
RefreshToken: tResp.RefreshToken,
Email: tResp.Account.EmailAddress,
Expire: time.Now().Add(time.Duration(tResp.ExpiresIn) * time.Second).Format(time.RFC3339),
},
LastRefresh: time.Now().Format(time.RFC3339),
}
// Create token storage
tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
@@ -974,17 +1025,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
fmt.Println("Initializing Google authentication...")
// OAuth2 configuration (mirrors internal/auth/gemini)
// OAuth2 configuration using exported constants from internal/auth/gemini
conf := &oauth2.Config{
ClientID: "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com",
ClientSecret: "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",
RedirectURL: "http://localhost:8085/oauth2callback",
Scopes: []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
},
Endpoint: google.Endpoint,
ClientID: geminiAuth.ClientID,
ClientSecret: geminiAuth.ClientSecret,
RedirectURL: fmt.Sprintf("http://localhost:%d/oauth2callback", geminiAuth.DefaultCallbackPort),
Scopes: geminiAuth.Scopes,
Endpoint: google.Endpoint,
}
// Build authorization URL and return it immediately
@@ -1106,13 +1153,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
}
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
ifToken["client_id"] = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
ifToken["client_secret"] = "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"
ifToken["scopes"] = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
ifToken["client_id"] = geminiAuth.ClientID
ifToken["client_secret"] = geminiAuth.ClientSecret
ifToken["scopes"] = geminiAuth.Scopes
ifToken["universe_domain"] = "googleapis.com"
ts := geminiAuth.GeminiTokenStorage{
@@ -1299,74 +1342,34 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
}
log.Debug("Authorization code received, exchanging for tokens...")
// Extract client_id from authURL
clientID := ""
if u2, errP := url.Parse(authURL); errP == nil {
clientID = u2.Query().Get("client_id")
}
// Exchange code for tokens with redirect equal to mgmtRedirect
form := url.Values{
"grant_type": {"authorization_code"},
"client_id": {clientID},
"code": {code},
"redirect_uri": {"http://localhost:1455/auth/callback"},
"code_verifier": {pkceCodes.CodeVerifier},
}
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
req, _ := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, errDo := httpClient.Do(req)
if errDo != nil {
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
// Exchange code for tokens using internal auth service
bundle, errExchange := openaiAuth.ExchangeCodeForTokens(ctx, code, pkceCodes)
if errExchange != nil {
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errExchange)
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
return
}
defer func() { _ = resp.Body.Close() }()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
return
}
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
ExpiresIn int `json:"expires_in"`
}
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
SetOAuthSessionError(state, "Failed to parse token response")
log.Errorf("failed to parse token response: %v", errU)
return
}
claims, _ := codex.ParseJWTToken(tokenResp.IDToken)
email := ""
accountID := ""
// Extract additional info for filename generation
claims, _ := codex.ParseJWTToken(bundle.TokenData.IDToken)
planType := ""
hashAccountID := ""
if claims != nil {
email = claims.GetUserEmail()
accountID = claims.GetAccountID()
}
// Build bundle compatible with existing storage
bundle := &codex.CodexAuthBundle{
TokenData: codex.CodexTokenData{
IDToken: tokenResp.IDToken,
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
AccountID: accountID,
Email: email,
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
},
LastRefresh: time.Now().Format(time.RFC3339),
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
if accountID := claims.GetAccountID(); accountID != "" {
digest := sha256.Sum256([]byte(accountID))
hashAccountID = hex.EncodeToString(digest[:])[:8]
}
}
// Create token storage and persist
tokenStorage := openaiAuth.CreateTokenStorage(bundle)
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
record := &coreauth.Auth{
ID: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
ID: fileName,
Provider: "codex",
FileName: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
FileName: fileName,
Storage: tokenStorage,
Metadata: map[string]any{
"email": tokenStorage.Email,
@@ -1392,23 +1395,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
}
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
const (
antigravityCallbackPort = 51121
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
)
var antigravityScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/cclog",
"https://www.googleapis.com/auth/experimentsandconfigs",
}
ctx := context.Background()
fmt.Println("Initializing Antigravity authentication...")
authSvc := antigravity.NewAntigravityAuth(h.cfg, nil)
state, errState := misc.GenerateRandomState()
if errState != nil {
log.Errorf("Failed to generate state parameter: %v", errState)
@@ -1416,17 +1408,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
return
}
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravityCallbackPort)
params := url.Values{}
params.Set("access_type", "offline")
params.Set("client_id", antigravityClientID)
params.Set("prompt", "consent")
params.Set("redirect_uri", redirectURI)
params.Set("response_type", "code")
params.Set("scope", strings.Join(antigravityScopes, " "))
params.Set("state", state)
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravity.CallbackPort)
authURL := authSvc.BuildAuthURL(state, redirectURI)
RegisterOAuthSession(state, "antigravity")
@@ -1440,7 +1423,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
return
}
var errStart error
if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
if forwarder, errStart = startCallbackForwarder(antigravity.CallbackPort, "antigravity", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
@@ -1449,7 +1432,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
go func() {
if isWebUI {
defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder)
defer stopCallbackForwarderInstance(antigravity.CallbackPort, forwarder)
}
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
@@ -1489,93 +1472,36 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
time.Sleep(500 * time.Millisecond)
}
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
form := url.Values{}
form.Set("code", authCode)
form.Set("client_id", antigravityClientID)
form.Set("client_secret", antigravityClientSecret)
form.Set("redirect_uri", redirectURI)
form.Set("grant_type", "authorization_code")
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
if errNewRequest != nil {
log.Errorf("Failed to build token request: %v", errNewRequest)
SetOAuthSessionError(state, "Failed to build token request")
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, errDo := httpClient.Do(req)
if errDo != nil {
log.Errorf("Failed to execute token request: %v", errDo)
tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI)
if errToken != nil {
log.Errorf("Failed to exchange token: %v", errToken)
SetOAuthSessionError(state, "Failed to exchange token")
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity token exchange close error: %v", errClose)
}
}()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, _ := io.ReadAll(resp.Body)
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
accessToken := strings.TrimSpace(tokenResp.AccessToken)
if accessToken == "" {
log.Error("antigravity: token exchange returned empty access token")
SetOAuthSessionError(state, "Failed to exchange token")
return
}
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
}
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
log.Errorf("Failed to parse token response: %v", errDecode)
SetOAuthSessionError(state, "Failed to parse token response")
email, errInfo := authSvc.FetchUserInfo(ctx, accessToken)
if errInfo != nil {
log.Errorf("Failed to fetch user info: %v", errInfo)
SetOAuthSessionError(state, "Failed to fetch user info")
return
}
email := ""
if strings.TrimSpace(tokenResp.AccessToken) != "" {
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
if errInfoReq != nil {
log.Errorf("Failed to build user info request: %v", errInfoReq)
SetOAuthSessionError(state, "Failed to build user info request")
return
}
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
infoResp, errInfo := httpClient.Do(infoReq)
if errInfo != nil {
log.Errorf("Failed to execute user info request: %v", errInfo)
SetOAuthSessionError(state, "Failed to execute user info request")
return
}
defer func() {
if errClose := infoResp.Body.Close(); errClose != nil {
log.Errorf("antigravity user info close error: %v", errClose)
}
}()
if infoResp.StatusCode >= http.StatusOK && infoResp.StatusCode < http.StatusMultipleChoices {
var infoPayload struct {
Email string `json:"email"`
}
if errDecodeInfo := json.NewDecoder(infoResp.Body).Decode(&infoPayload); errDecodeInfo == nil {
email = strings.TrimSpace(infoPayload.Email)
}
} else {
bodyBytes, _ := io.ReadAll(infoResp.Body)
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
return
}
email = strings.TrimSpace(email)
if email == "" {
log.Error("antigravity: user info returned empty email")
SetOAuthSessionError(state, "Failed to fetch user info")
return
}
projectID := ""
if strings.TrimSpace(tokenResp.AccessToken) != "" {
fetchedProjectID, errProject := sdkAuth.FetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
if accessToken != "" {
fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken)
if errProject != nil {
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
} else {
@@ -1600,7 +1526,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
metadata["project_id"] = projectID
}
fileName := sanitizeAntigravityFileName(email)
fileName := antigravity.CredentialFileName(email)
label := strings.TrimSpace(email)
if label == "" {
label = "antigravity"
@@ -1664,7 +1590,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
// Create token storage
tokenStorage := qwenAuth.CreateTokenStorage(tokenData)
tokenStorage.Email = fmt.Sprintf("qwen-%d", time.Now().UnixMilli())
tokenStorage.Email = fmt.Sprintf("%d", time.Now().UnixMilli())
record := &coreauth.Auth{
ID: fmt.Sprintf("qwen-%s.json", tokenStorage.Email),
Provider: "qwen",
@@ -1769,7 +1695,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
tokenStorage := authSvc.CreateTokenStorage(tokenData)
identifier := strings.TrimSpace(tokenStorage.Email)
if identifier == "" {
identifier = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
identifier = fmt.Sprintf("%d", time.Now().UnixMilli())
tokenStorage.Email = identifier
}
record := &coreauth.Auth{
@@ -1800,6 +1726,89 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestGitHubToken(c *gin.Context) {
ctx := context.Background()
fmt.Println("Initializing GitHub Copilot authentication...")
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
// Initialize Copilot auth service
// We need to import "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" first if not present
// Assuming copilot package is imported as "copilot"
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
// Initiate device flow
deviceCode, err := deviceClient.RequestDeviceCode(ctx)
if err != nil {
log.Errorf("Failed to initiate device flow: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate device flow"})
return
}
authURL := deviceCode.VerificationURI
userCode := deviceCode.UserCode
RegisterOAuthSession(state, "github")
go func() {
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
tokenData, errPoll := deviceClient.PollForToken(ctx, deviceCode)
if errPoll != nil {
SetOAuthSessionError(state, "Authentication failed")
fmt.Printf("Authentication failed: %v\n", errPoll)
return
}
username, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
if errUser != nil {
log.Warnf("Failed to fetch user info: %v", errUser)
username = "github-user"
}
tokenStorage := &copilot.CopilotTokenStorage{
AccessToken: tokenData.AccessToken,
TokenType: tokenData.TokenType,
Scope: tokenData.Scope,
Username: username,
Type: "github-copilot",
}
fileName := fmt.Sprintf("github-%s.json", username)
record := &coreauth.Auth{
ID: fileName,
Provider: "github",
FileName: fileName,
Storage: tokenStorage,
Metadata: map[string]any{
"email": username,
"username": username,
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
fmt.Println("You can now use GitHub Copilot services through this CLI")
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("github")
}()
c.JSON(200, gin.H{
"status": "ok",
"url": authURL,
"state": state,
"user_code": userCode,
"verification_uri": authURL,
})
}
func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
ctx := context.Background()
@@ -1854,15 +1863,17 @@ func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
fileName := iflowauth.SanitizeIFlowFileName(email)
if fileName == "" {
fileName = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
} else {
fileName = fmt.Sprintf("iflow-%s", fileName)
}
tokenStorage.Email = email
timestamp := time.Now().Unix()
record := &coreauth.Auth{
ID: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
ID: fmt.Sprintf("%s-%d.json", fileName, timestamp),
Provider: "iflow",
FileName: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
FileName: fmt.Sprintf("%s-%d.json", fileName, timestamp),
Storage: tokenStorage,
Metadata: map[string]any{
"email": email,
@@ -2069,7 +2080,20 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
finalProjectID := projectID
if responseProjectID != "" {
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
strings.EqualFold(tierID, "FREE") ||
strings.EqualFold(tierID, "LEGACY")
if isFreeUser {
// For free users, use backend project ID for preview model access
log.Infof("Gemini onboarding: frontend project %s maps to backend project %s", projectID, responseProjectID)
log.Infof("Using backend project ID: %s (recommended for preview model access)", responseProjectID)
finalProjectID = responseProjectID
} else {
// Pro users: keep requested project ID (original behavior)
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
}
} else {
finalProjectID = responseProjectID
}

View File

@@ -202,6 +202,26 @@ func (h *Handler) PutLoggingToFile(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.LoggingToFile = v })
}
// LogsMaxTotalSizeMB
func (h *Handler) GetLogsMaxTotalSizeMB(c *gin.Context) {
c.JSON(200, gin.H{"logs-max-total-size-mb": h.cfg.LogsMaxTotalSizeMB})
}
func (h *Handler) PutLogsMaxTotalSizeMB(c *gin.Context) {
var body struct {
Value *int `json:"value"`
}
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
value := *body.Value
if value < 0 {
value = 0
}
h.cfg.LogsMaxTotalSizeMB = value
h.persist(c)
}
// Request log
func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) }
func (h *Handler) PutRequestLog(c *gin.Context) {
@@ -232,6 +252,52 @@ func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
}
// ForceModelPrefix
func (h *Handler) GetForceModelPrefix(c *gin.Context) {
c.JSON(200, gin.H{"force-model-prefix": h.cfg.ForceModelPrefix})
}
func (h *Handler) PutForceModelPrefix(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.ForceModelPrefix = v })
}
func normalizeRoutingStrategy(strategy string) (string, bool) {
normalized := strings.ToLower(strings.TrimSpace(strategy))
switch normalized {
case "", "round-robin", "roundrobin", "rr":
return "round-robin", true
case "fill-first", "fillfirst", "ff":
return "fill-first", true
default:
return "", false
}
}
// RoutingStrategy
func (h *Handler) GetRoutingStrategy(c *gin.Context) {
strategy, ok := normalizeRoutingStrategy(h.cfg.Routing.Strategy)
if !ok {
c.JSON(200, gin.H{"strategy": strings.TrimSpace(h.cfg.Routing.Strategy)})
return
}
c.JSON(200, gin.H{"strategy": strategy})
}
func (h *Handler) PutRoutingStrategy(c *gin.Context) {
var body struct {
Value *string `json:"value"`
}
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
normalized, ok := normalizeRoutingStrategy(*body.Value)
if !ok {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid strategy"})
return
}
h.cfg.Routing.Strategy = normalized
h.persist(c)
}
// Proxy URL
func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
func (h *Handler) PutProxyURL(c *gin.Context) {

View File

@@ -487,6 +487,137 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
c.JSON(400, gin.H{"error": "missing name or index"})
}
// vertex-api-key: []VertexCompatKey
func (h *Handler) GetVertexCompatKeys(c *gin.Context) {
c.JSON(200, gin.H{"vertex-api-key": h.cfg.VertexCompatAPIKey})
}
func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
data, err := c.GetRawData()
if err != nil {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
var arr []config.VertexCompatKey
if err = json.Unmarshal(data, &arr); err != nil {
var obj struct {
Items []config.VertexCompatKey `json:"items"`
}
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
arr = obj.Items
}
for i := range arr {
normalizeVertexCompatKey(&arr[i])
}
h.cfg.VertexCompatAPIKey = arr
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
}
func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
type vertexCompatPatch struct {
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Headers *map[string]string `json:"headers"`
Models *[]config.VertexCompatModel `json:"models"`
}
var body struct {
Index *int `json:"index"`
Match *string `json:"match"`
Value *vertexCompatPatch `json:"value"`
}
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
targetIndex := -1
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.VertexCompatAPIKey) {
targetIndex = *body.Index
}
if targetIndex == -1 && body.Match != nil {
match := strings.TrimSpace(*body.Match)
if match != "" {
for i := range h.cfg.VertexCompatAPIKey {
if h.cfg.VertexCompatAPIKey[i].APIKey == match {
targetIndex = i
break
}
}
}
}
if targetIndex == -1 {
c.JSON(404, gin.H{"error": "item not found"})
return
}
entry := h.cfg.VertexCompatAPIKey[targetIndex]
if body.Value.APIKey != nil {
trimmed := strings.TrimSpace(*body.Value.APIKey)
if trimmed == "" {
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...)
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
return
}
entry.APIKey = trimmed
}
if body.Value.Prefix != nil {
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
}
if body.Value.BaseURL != nil {
trimmed := strings.TrimSpace(*body.Value.BaseURL)
if trimmed == "" {
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...)
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
return
}
entry.BaseURL = trimmed
}
if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
}
if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
if body.Value.Models != nil {
entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...)
}
normalizeVertexCompatKey(&entry)
h.cfg.VertexCompatAPIKey[targetIndex] = entry
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
}
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
for _, v := range h.cfg.VertexCompatAPIKey {
if v.APIKey != val {
out = append(out, v)
}
}
h.cfg.VertexCompatAPIKey = out
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
return
}
if idxStr := c.Query("index"); idxStr != "" {
var idx int
_, errScan := fmt.Sscanf(idxStr, "%d", &idx)
if errScan == nil && idx >= 0 && idx < len(h.cfg.VertexCompatAPIKey) {
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:idx], h.cfg.VertexCompatAPIKey[idx+1:]...)
h.cfg.SanitizeVertexCompatKeys()
h.persist(c)
return
}
}
c.JSON(400, gin.H{"error": "missing api-key or index"})
}
// oauth-excluded-models: map[string][]string
func (h *Handler) GetOAuthExcludedModels(c *gin.Context) {
c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)})
@@ -572,6 +703,103 @@ func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) {
h.persist(c)
}
// oauth-model-alias: map[string][]OAuthModelAlias
func (h *Handler) GetOAuthModelAlias(c *gin.Context) {
c.JSON(200, gin.H{"oauth-model-alias": sanitizedOAuthModelAlias(h.cfg.OAuthModelAlias)})
}
func (h *Handler) PutOAuthModelAlias(c *gin.Context) {
data, err := c.GetRawData()
if err != nil {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
var entries map[string][]config.OAuthModelAlias
if err = json.Unmarshal(data, &entries); err != nil {
var wrapper struct {
Items map[string][]config.OAuthModelAlias `json:"items"`
}
if err2 := json.Unmarshal(data, &wrapper); err2 != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
entries = wrapper.Items
}
h.cfg.OAuthModelAlias = sanitizedOAuthModelAlias(entries)
h.persist(c)
}
func (h *Handler) PatchOAuthModelAlias(c *gin.Context) {
var body struct {
Provider *string `json:"provider"`
Channel *string `json:"channel"`
Aliases []config.OAuthModelAlias `json:"aliases"`
}
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
channelRaw := ""
if body.Channel != nil {
channelRaw = *body.Channel
} else if body.Provider != nil {
channelRaw = *body.Provider
}
channel := strings.ToLower(strings.TrimSpace(channelRaw))
if channel == "" {
c.JSON(400, gin.H{"error": "invalid channel"})
return
}
normalizedMap := sanitizedOAuthModelAlias(map[string][]config.OAuthModelAlias{channel: body.Aliases})
normalized := normalizedMap[channel]
if len(normalized) == 0 {
if h.cfg.OAuthModelAlias == nil {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
if _, ok := h.cfg.OAuthModelAlias[channel]; !ok {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
delete(h.cfg.OAuthModelAlias, channel)
if len(h.cfg.OAuthModelAlias) == 0 {
h.cfg.OAuthModelAlias = nil
}
h.persist(c)
return
}
if h.cfg.OAuthModelAlias == nil {
h.cfg.OAuthModelAlias = make(map[string][]config.OAuthModelAlias)
}
h.cfg.OAuthModelAlias[channel] = normalized
h.persist(c)
}
func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) {
channel := strings.ToLower(strings.TrimSpace(c.Query("channel")))
if channel == "" {
channel = strings.ToLower(strings.TrimSpace(c.Query("provider")))
}
if channel == "" {
c.JSON(400, gin.H{"error": "missing channel"})
return
}
if h.cfg.OAuthModelAlias == nil {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
if _, ok := h.cfg.OAuthModelAlias[channel]; !ok {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
delete(h.cfg.OAuthModelAlias, channel)
if len(h.cfg.OAuthModelAlias) == 0 {
h.cfg.OAuthModelAlias = nil
}
h.persist(c)
}
// codex-api-key: []CodexKey
func (h *Handler) GetCodexKeys(c *gin.Context) {
c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey})
@@ -597,11 +825,7 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
filtered := make([]config.CodexKey, 0, len(arr))
for i := range arr {
entry := arr[i]
entry.APIKey = strings.TrimSpace(entry.APIKey)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
normalizeCodexKey(&entry)
if entry.BaseURL == "" {
continue
}
@@ -613,12 +837,13 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
}
func (h *Handler) PatchCodexKey(c *gin.Context) {
type codexKeyPatch struct {
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Models *[]config.CodexModel `json:"models"`
Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
}
var body struct {
Index *int `json:"index"`
@@ -667,12 +892,16 @@ func (h *Handler) PatchCodexKey(c *gin.Context) {
if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
}
if body.Value.Models != nil {
entry.Models = append([]config.CodexModel(nil), (*body.Value.Models)...)
}
if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
if body.Value.ExcludedModels != nil {
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
}
normalizeCodexKey(&entry)
h.cfg.CodexKey[targetIndex] = entry
h.cfg.SanitizeCodexKeys()
h.persist(c)
@@ -762,6 +991,79 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
entry.Models = normalized
}
func normalizeCodexKey(entry *config.CodexKey) {
if entry == nil {
return
}
entry.APIKey = strings.TrimSpace(entry.APIKey)
entry.Prefix = strings.TrimSpace(entry.Prefix)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
if len(entry.Models) == 0 {
return
}
normalized := make([]config.CodexModel, 0, len(entry.Models))
for i := range entry.Models {
model := entry.Models[i]
model.Name = strings.TrimSpace(model.Name)
model.Alias = strings.TrimSpace(model.Alias)
if model.Name == "" && model.Alias == "" {
continue
}
normalized = append(normalized, model)
}
entry.Models = normalized
}
func normalizeVertexCompatKey(entry *config.VertexCompatKey) {
if entry == nil {
return
}
entry.APIKey = strings.TrimSpace(entry.APIKey)
entry.Prefix = strings.TrimSpace(entry.Prefix)
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.Headers = config.NormalizeHeaders(entry.Headers)
if len(entry.Models) == 0 {
return
}
normalized := make([]config.VertexCompatModel, 0, len(entry.Models))
for i := range entry.Models {
model := entry.Models[i]
model.Name = strings.TrimSpace(model.Name)
model.Alias = strings.TrimSpace(model.Alias)
if model.Name == "" || model.Alias == "" {
continue
}
normalized = append(normalized, model)
}
entry.Models = normalized
}
func sanitizedOAuthModelAlias(entries map[string][]config.OAuthModelAlias) map[string][]config.OAuthModelAlias {
if len(entries) == 0 {
return nil
}
copied := make(map[string][]config.OAuthModelAlias, len(entries))
for channel, aliases := range entries {
if len(aliases) == 0 {
continue
}
copied[channel] = append([]config.OAuthModelAlias(nil), aliases...)
}
if len(copied) == 0 {
return nil
}
cfg := config.Config{OAuthModelAlias: copied}
cfg.SanitizeOAuthModelAlias()
if len(cfg.OAuthModelAlias) == 0 {
return nil
}
return cfg.OAuthModelAlias
}
// GetAmpCode returns the complete ampcode configuration.
func (h *Handler) GetAmpCode(c *gin.Context) {
if h == nil || h.cfg == nil {
@@ -913,3 +1215,151 @@ func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
}
// GetAmpUpstreamAPIKeys returns the ampcode upstream API keys mapping.
func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}})
return
}
c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys})
}
// PutAmpUpstreamAPIKeys replaces all ampcode upstream API keys mappings.
func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
// Normalize entries: trim whitespace, filter empty
normalized := normalizeAmpUpstreamAPIKeyEntries(body.Value)
h.cfg.AmpCode.UpstreamAPIKeys = normalized
h.persist(c)
}
// PatchAmpUpstreamAPIKeys adds or updates upstream API keys entries.
// Matching is done by upstream-api-key value.
func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
existing := make(map[string]int)
for i, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i
}
for _, newEntry := range body.Value {
upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
normalizedEntry := config.AmpUpstreamAPIKeyEntry{
UpstreamAPIKey: upstreamKey,
APIKeys: normalizeAPIKeysList(newEntry.APIKeys),
}
if idx, ok := existing[upstreamKey]; ok {
h.cfg.AmpCode.UpstreamAPIKeys[idx] = normalizedEntry
} else {
h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, normalizedEntry)
existing[upstreamKey] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1
}
}
h.persist(c)
}
// DeleteAmpUpstreamAPIKeys removes specified upstream API keys entries.
// Body must be JSON: {"value": ["<upstream-api-key>", ...]}.
// If "value" is an empty array, clears all entries.
// If JSON is invalid or "value" is missing/null, returns 400 and does not persist any change.
func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) {
var body struct {
Value []string `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
if body.Value == nil {
c.JSON(400, gin.H{"error": "missing value"})
return
}
// Empty array means clear all
if len(body.Value) == 0 {
h.cfg.AmpCode.UpstreamAPIKeys = nil
h.persist(c)
return
}
toRemove := make(map[string]bool)
for _, key := range body.Value {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
continue
}
toRemove[trimmed] = true
}
if len(toRemove) == 0 {
c.JSON(400, gin.H{"error": "empty value"})
return
}
newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys))
for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] {
newEntries = append(newEntries, entry)
}
}
h.cfg.AmpCode.UpstreamAPIKeys = newEntries
h.persist(c)
}
// normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries.
func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry {
if len(entries) == 0 {
return nil
}
out := make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries))
for _, entry := range entries {
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
apiKeys := normalizeAPIKeysList(entry.APIKeys)
out = append(out, config.AmpUpstreamAPIKeyEntry{
UpstreamAPIKey: upstreamKey,
APIKeys: apiKeys,
})
}
if len(out) == 0 {
return nil
}
return out
}
// normalizeAPIKeysList trims and filters empty strings from a list of API keys.
func normalizeAPIKeysList(keys []string) []string {
if len(keys) == 0 {
return nil
}
out := make([]string, 0, len(keys))
for _, k := range keys {
trimmed := strings.TrimSpace(k)
if trimmed != "" {
out = append(out, trimmed)
}
}
if len(out) == 0 {
return nil
}
return out
}

View File

@@ -24,8 +24,15 @@ import (
type attemptInfo struct {
count int
blockedUntil time.Time
lastActivity time.Time // track last activity for cleanup
}
// attemptCleanupInterval controls how often stale IP entries are purged
const attemptCleanupInterval = 1 * time.Hour
// attemptMaxIdleTime controls how long an IP can be idle before cleanup
const attemptMaxIdleTime = 2 * time.Hour
// Handler aggregates config reference, persistence path and helpers.
type Handler struct {
cfg *config.Config
@@ -47,7 +54,7 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
envSecret, _ := os.LookupEnv("MANAGEMENT_PASSWORD")
envSecret = strings.TrimSpace(envSecret)
return &Handler{
h := &Handler{
cfg: cfg,
configFilePath: configFilePath,
failedAttempts: make(map[string]*attemptInfo),
@@ -57,6 +64,43 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
allowRemoteOverride: envSecret != "",
envSecret: envSecret,
}
h.startAttemptCleanup()
return h
}
// startAttemptCleanup launches a background goroutine that periodically
// removes stale IP entries from failedAttempts to prevent memory leaks.
func (h *Handler) startAttemptCleanup() {
go func() {
ticker := time.NewTicker(attemptCleanupInterval)
defer ticker.Stop()
for range ticker.C {
h.purgeStaleAttempts()
}
}()
}
// purgeStaleAttempts removes IP entries that have been idle beyond attemptMaxIdleTime
// and whose ban (if any) has expired.
func (h *Handler) purgeStaleAttempts() {
now := time.Now()
h.attemptsMu.Lock()
defer h.attemptsMu.Unlock()
for ip, ai := range h.failedAttempts {
// Skip if still banned
if !ai.blockedUntil.IsZero() && now.Before(ai.blockedUntil) {
continue
}
// Remove if idle too long
if now.Sub(ai.lastActivity) > attemptMaxIdleTime {
delete(h.failedAttempts, ip)
}
}
}
// NewHandler creates a new management handler instance.
func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manager) *Handler {
return NewHandler(cfg, "", manager)
}
// SetConfig updates the in-memory config reference when the server hot-reloads.
@@ -144,6 +188,7 @@ func (h *Handler) Middleware() gin.HandlerFunc {
h.failedAttempts[clientIP] = aip
}
aip.count++
aip.lastActivity = time.Now()
if aip.count >= maxFailures {
aip.blockedUntil = time.Now().Add(banDuration)
aip.count = 0

View File

@@ -13,7 +13,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
)
const (
@@ -209,6 +209,94 @@ func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"files": files})
}
// GetRequestLogByID finds and downloads a request log file by its request ID.
// The ID is matched against the suffix of log file names (format: *-{requestID}.log).
func (h *Handler) GetRequestLogByID(c *gin.Context) {
if h == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
return
}
if h.cfg == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
return
}
dir := h.logDirectory()
if strings.TrimSpace(dir) == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
return
}
requestID := strings.TrimSpace(c.Param("id"))
if requestID == "" {
requestID = strings.TrimSpace(c.Query("id"))
}
if requestID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing request ID"})
return
}
if strings.ContainsAny(requestID, "/\\") {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request ID"})
return
}
entries, err := os.ReadDir(dir)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
return
}
suffix := "-" + requestID + ".log"
var matchedFile string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if strings.HasSuffix(name, suffix) {
matchedFile = name
break
}
}
if matchedFile == "" {
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found for the given request ID"})
return
}
dirAbs, errAbs := filepath.Abs(dir)
if errAbs != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
return
}
fullPath := filepath.Clean(filepath.Join(dirAbs, matchedFile))
prefix := dirAbs + string(os.PathSeparator)
if !strings.HasPrefix(fullPath, prefix) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
return
}
info, errStat := os.Stat(fullPath)
if errStat != nil {
if os.IsNotExist(errStat) {
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
return
}
if info.IsDir() {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
return
}
c.FileAttachment(fullPath, matchedFile)
}
// DownloadRequestErrorLog downloads a specific error request log file by name.
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
if h == nil {
@@ -272,16 +360,7 @@ func (h *Handler) logDirectory() string {
if h.logDir != "" {
return h.logDir
}
if base := util.WritablePath(); base != "" {
return filepath.Join(base, "logs")
}
if h.configFilePath != "" {
dir := filepath.Dir(h.configFilePath)
if dir != "" && dir != "." {
return filepath.Join(dir, "logs")
}
}
return "logs"
return logging.ResolveLogDirectory(h.cfg)
}
func (h *Handler) collectLogFiles(dir string) ([]string, error) {

View File

@@ -0,0 +1,33 @@
package management
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
// GetStaticModelDefinitions returns static model metadata for a given channel.
// Channel is provided via path param (:channel) or query param (?channel=...).
func (h *Handler) GetStaticModelDefinitions(c *gin.Context) {
channel := strings.TrimSpace(c.Param("channel"))
if channel == "" {
channel = strings.TrimSpace(c.Query("channel"))
}
if channel == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "channel is required"})
return
}
models := registry.GetStaticModelDefinitionsByChannel(channel)
if models == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "unknown channel", "channel": channel})
return
}
c.JSON(http.StatusOK, gin.H{
"channel": strings.ToLower(strings.TrimSpace(channel)),
"models": models,
})
}

View File

@@ -238,6 +238,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
return "qwen", nil
case "kiro":
return "kiro", nil
case "github":
return "github", nil
default:
return "", errUnsupportedOAuthFlow
}

View File

@@ -1,12 +1,25 @@
package management
import (
"encoding/json"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
)
type usageExportPayload struct {
Version int `json:"version"`
ExportedAt time.Time `json:"exported_at"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
type usageImportPayload struct {
Version int `json:"version"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
// GetUsageStatistics returns the in-memory request statistics snapshot.
func (h *Handler) GetUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
@@ -18,3 +31,49 @@ func (h *Handler) GetUsageStatistics(c *gin.Context) {
"failed_requests": snapshot.FailureCount,
})
}
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
if h != nil && h.usageStats != nil {
snapshot = h.usageStats.Snapshot()
}
c.JSON(http.StatusOK, usageExportPayload{
Version: 1,
ExportedAt: time.Now().UTC(),
Usage: snapshot,
})
}
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
if h == nil || h.usageStats == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
return
}
data, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
return
}
var payload usageImportPayload
if err := json.Unmarshal(data, &payload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
return
}
if payload.Version != 0 && payload.Version != 1 {
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
return
}
result := h.usageStats.MergeSnapshot(payload.Usage)
snapshot := h.usageStats.Snapshot()
c.JSON(http.StatusOK, gin.H{
"added": result.Added,
"skipped": result.Skipped,
"total_requests": snapshot.TotalRequests,
"failed_requests": snapshot.FailureCount,
})
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
@@ -98,10 +99,12 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
}
return &RequestInfo{
URL: url,
Method: method,
Headers: headers,
Body: body,
URL: url,
Method: method,
Headers: headers,
Body: body,
RequestID: logging.GetGinRequestID(c),
Timestamp: time.Now(),
}, nil
}

View File

@@ -7,6 +7,7 @@ import (
"bytes"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
@@ -15,26 +16,29 @@ import (
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct {
URL string // URL is the request URL.
Method string // Method is the HTTP method (e.g., GET, POST).
Headers map[string][]string // Headers contains the request headers.
Body []byte // Body is the raw request body.
URL string // URL is the request URL.
Method string // Method is the HTTP method (e.g., GET, POST).
Headers map[string][]string // Headers contains the request headers.
Body []byte // Body is the raw request body.
RequestID string // RequestID is the unique identifier for the request.
Timestamp time.Time // Timestamp is when the request was received.
}
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
// It is designed to handle both standard and streaming responses, ensuring that logging operations do not block the client response.
type ResponseWriterWrapper struct {
gin.ResponseWriter
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
logger logging.RequestLogger // logger is the instance of the request logger service.
requestInfo *RequestInfo // requestInfo holds the details of the original request.
statusCode int // statusCode stores the HTTP status code of the response.
headers map[string][]string // headers stores the response headers.
logOnErrorOnly bool // logOnErrorOnly enables logging only when an error response is detected.
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
logger logging.RequestLogger // logger is the instance of the request logger service.
requestInfo *RequestInfo // requestInfo holds the details of the original request.
statusCode int // statusCode stores the HTTP status code of the response.
headers map[string][]string // headers stores the response headers.
logOnErrorOnly bool // logOnErrorOnly enables logging only when an error response is detected.
firstChunkTimestamp time.Time // firstChunkTimestamp captures TTFB for streaming responses.
}
// NewResponseWriterWrapper creates and initializes a new ResponseWriterWrapper.
@@ -72,6 +76,10 @@ func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
// THEN: Handle logging based on response type
if w.isStreaming && w.chunkChannel != nil {
// Capture TTFB on first chunk (synchronous, before async channel send)
if w.firstChunkTimestamp.IsZero() {
w.firstChunkTimestamp = time.Now()
}
// For streaming responses: Send to async logging channel (non-blocking)
select {
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
@@ -116,6 +124,10 @@ func (w *ResponseWriterWrapper) WriteString(data string) (int, error) {
// THEN: Capture for logging
if w.isStreaming && w.chunkChannel != nil {
// Capture TTFB on first chunk (synchronous, before async channel send)
if w.firstChunkTimestamp.IsZero() {
w.firstChunkTimestamp = time.Now()
}
select {
case w.chunkChannel <- []byte(data):
default:
@@ -149,6 +161,7 @@ func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
w.requestInfo.Method,
w.requestInfo.Headers,
w.requestInfo.Body,
w.requestInfo.RequestID,
)
if err == nil {
w.streamWriter = streamWriter
@@ -278,6 +291,8 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
w.streamDone = nil
}
w.streamWriter.SetFirstChunkTimestamp(w.firstChunkTimestamp)
// Write API Request and Response to the streaming log before closing
apiRequest := w.extractAPIRequest(c)
if len(apiRequest) > 0 {
@@ -295,7 +310,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
return nil
}
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
}
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
@@ -335,7 +350,18 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
return data
}
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
if !isExist {
return time.Time{}
}
if t, ok := ts.(time.Time); ok {
return t
}
return time.Time{}
}
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
if w.requestInfo == nil {
return nil
}
@@ -346,7 +372,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
}
if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool) error
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
}); ok {
return loggerWithOptions.LogRequestWithOptions(
w.requestInfo.URL,
@@ -360,6 +386,9 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
apiResponseBody,
apiResponseErrors,
forceLog,
w.requestInfo.RequestID,
w.requestInfo.Timestamp,
apiResponseTimestamp,
)
}
@@ -374,5 +403,8 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
apiRequestBody,
apiResponseBody,
apiResponseErrors,
w.requestInfo.RequestID,
w.requestInfo.Timestamp,
apiResponseTimestamp,
)
}

View File

@@ -227,11 +227,20 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
}
}
// Check API key change
// Check API key change (both default and per-client mappings)
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
if apiKeyChanged {
upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
if apiKeyChanged || upstreamAPIKeysChanged {
if m.secretSource != nil {
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
if ms, ok := m.secretSource.(*MappedSecretSource); ok {
if apiKeyChanged {
ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
ms.InvalidateCache()
}
if upstreamAPIKeysChanged {
ms.UpdateMappings(newSettings.UpstreamAPIKeys)
}
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
ms.InvalidateCache()
}
@@ -251,10 +260,22 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
if m.secretSource == nil {
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
// Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
mappedSource := NewMappedSecretSource(defaultSource)
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
m.secretSource = mappedSource
} else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
ms.InvalidateCache()
ms.UpdateMappings(settings.UpstreamAPIKeys)
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
// Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
ms.InvalidateCache()
mappedSource := NewMappedSecretSource(ms)
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
m.secretSource = mappedSource
}
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
@@ -279,16 +300,23 @@ func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.Amp
return true
}
// Build map for efficient comparison
oldMap := make(map[string]string, len(old.ModelMappings))
// Build map for efficient and robust comparison
type mappingInfo struct {
to string
regex bool
}
oldMap := make(map[string]mappingInfo, len(old.ModelMappings))
for _, mapping := range old.ModelMappings {
oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To)
oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{
to: strings.TrimSpace(mapping.To),
regex: mapping.Regex,
}
}
for _, mapping := range new.ModelMappings {
from := strings.TrimSpace(mapping.From)
to := strings.TrimSpace(mapping.To)
if oldTo, exists := oldMap[from]; !exists || oldTo != to {
if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex {
return true
}
}
@@ -306,6 +334,66 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b
return oldKey != newKey
}
// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings.
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
if old == nil {
return len(new.UpstreamAPIKeys) > 0
}
if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
return true
}
// Build map for comparison: upstreamKey -> set of clientKeys
type entryInfo struct {
upstreamKey string
clientKeys map[string]struct{}
}
oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys))
for i, entry := range old.UpstreamAPIKeys {
clientKeys := make(map[string]struct{}, len(entry.APIKeys))
for _, k := range entry.APIKeys {
trimmed := strings.TrimSpace(k)
if trimmed == "" {
continue
}
clientKeys[trimmed] = struct{}{}
}
oldEntries[i] = entryInfo{
upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey),
clientKeys: clientKeys,
}
}
for i, newEntry := range new.UpstreamAPIKeys {
if i >= len(oldEntries) {
return true
}
oldE := oldEntries[i]
if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey {
return true
}
newKeys := make(map[string]struct{}, len(newEntry.APIKeys))
for _, k := range newEntry.APIKeys {
trimmed := strings.TrimSpace(k)
if trimmed == "" {
continue
}
newKeys[trimmed] = struct{}{}
}
if len(newKeys) != len(oldE.clientKeys) {
return true
}
for k := range newKeys {
if _, ok := oldE.clientKeys[k]; !ok {
return true
}
}
}
return false
}
// GetModelMapper returns the model mapper instance (for testing/debugging).
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
return m.modelMapper

View File

@@ -312,3 +312,41 @@ func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
})
}
}
func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
m := &AmpModule{}
oldCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
},
}
newCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
},
}
if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
}
}
func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
m := &AmpModule{}
oldCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
},
}
newCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
},
}
if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
t.Fatal("expected no change when only whitespace/empty entries differ")
}
}

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -134,10 +135,11 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
}
// Normalize model (handles dynamic thinking suffixes)
normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName)
suffixResult := thinking.ParseSuffix(modelName)
normalizedModel := suffixResult.ModelName
thinkingSuffix := ""
if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) {
thinkingSuffix = modelName[len(normalizedModel):]
if suffixResult.HasSuffix {
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
}
resolveMappedModel := func() (string, []string) {
@@ -157,13 +159,13 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
// already specifies its own thinking suffix.
if thinkingSuffix != "" {
_, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel)
if mappedThinkingMetadata == nil {
mappedSuffixResult := thinking.ParseSuffix(mappedModel)
if !mappedSuffixResult.HasSuffix {
mappedModel += thinkingSuffix
}
}
mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel)
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
mappedProviders := util.GetProviderName(mappedBaseModel)
if len(mappedProviders) == 0 {
return "", nil

View File

@@ -3,10 +3,12 @@
package amp
import (
"regexp"
"strings"
"sync"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
@@ -26,13 +28,15 @@ type ModelMapper interface {
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
type DefaultModelMapper struct {
mu sync.RWMutex
mappings map[string]string // from -> to (normalized lowercase keys)
mappings map[string]string // exact: from -> to (normalized lowercase keys)
regexps []regexMapping // regex rules evaluated in order
}
// NewModelMapper creates a new model mapper with the given initial mappings.
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
m := &DefaultModelMapper{
mappings: make(map[string]string),
regexps: nil,
}
m.UpdateMappings(mappings)
return m
@@ -41,6 +45,11 @@ func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
// MapModel checks if a mapping exists for the requested model and if the
// target model has available local providers. Returns the mapped model name
// or empty string if no valid mapping exists.
//
// If the requested model contains a thinking suffix (e.g., "g25p(8192)"),
// the suffix is preserved in the returned model name (e.g., "gemini-2.5-pro(8192)").
// However, if the mapping target already contains a suffix, the config suffix
// takes priority over the user's suffix.
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
if requestedModel == "" {
return ""
@@ -49,23 +58,52 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
m.mu.RLock()
defer m.mu.RUnlock()
// Normalize the requested model for lookup
normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel))
// Extract thinking suffix from requested model using ParseSuffix
requestResult := thinking.ParseSuffix(requestedModel)
baseModel := requestResult.ModelName
// Check for direct mapping
targetModel, exists := m.mappings[normalizedRequest]
// Normalize the base model for lookup (case-insensitive)
normalizedBase := strings.ToLower(strings.TrimSpace(baseModel))
// Check for direct mapping using base model name
targetModel, exists := m.mappings[normalizedBase]
if !exists {
return ""
// Try regex mappings in order using base model only
// (suffix is handled separately via ParseSuffix)
for _, rm := range m.regexps {
if rm.re.MatchString(baseModel) {
targetModel = rm.to
exists = true
break
}
}
if !exists {
return ""
}
}
// Verify target model has available providers
normalizedTarget, _ := util.NormalizeThinkingModel(targetModel)
providers := util.GetProviderName(normalizedTarget)
// Check if target model already has a thinking suffix (config priority)
targetResult := thinking.ParseSuffix(targetModel)
// Verify target model has available providers (use base model for lookup)
providers := util.GetProviderName(targetResult.ModelName)
if len(providers) == 0 {
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
return ""
}
// Suffix handling: config suffix takes priority, otherwise preserve user suffix
if targetResult.HasSuffix {
// Config's "to" already contains a suffix - use it as-is (config priority)
return targetModel
}
// Preserve user's thinking suffix on the mapped model
// (skip empty suffixes to avoid returning "model()")
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return targetModel + "(" + requestResult.RawSuffix + ")"
}
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
return targetModel
}
@@ -78,6 +116,7 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
// Clear and rebuild mappings
m.mappings = make(map[string]string, len(mappings))
m.regexps = make([]regexMapping, 0, len(mappings))
for _, mapping := range mappings {
from := strings.TrimSpace(mapping.From)
@@ -88,16 +127,30 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
continue
}
// Store with normalized lowercase key for case-insensitive lookup
normalizedFrom := strings.ToLower(from)
m.mappings[normalizedFrom] = to
log.Debugf("amp model mapping registered: %s -> %s", from, to)
if mapping.Regex {
// Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups
pattern := "(?i)" + from
re, err := regexp.Compile(pattern)
if err != nil {
log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
continue
}
m.regexps = append(m.regexps, regexMapping{re: re, to: to})
log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
} else {
// Store with normalized lowercase key for case-insensitive lookup
normalizedFrom := strings.ToLower(from)
m.mappings[normalizedFrom] = to
log.Debugf("amp model mapping registered: %s -> %s", from, to)
}
}
if len(m.mappings) > 0 {
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
}
if n := len(m.regexps); n > 0 {
log.Infof("amp model mapping: loaded %d regex mapping(s)", n)
}
}
// GetMappings returns a copy of current mappings (for debugging/status).
@@ -111,3 +164,8 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
}
return result
}
type regexMapping struct {
re *regexp.Regexp
to string
}

View File

@@ -203,3 +203,173 @@ func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
t.Error("Original map was modified")
}
}
func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
defer reg.UnregisterClient("test-client-regex-1")
mappings := []config.AmpModelMapping{
{From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true},
}
mapper := NewModelMapper(mappings)
// Incoming model has reasoning suffix, regex matches base, suffix is preserved
result := mapper.MapModel("gpt-5(high)")
if result != "gemini-2.5-pro(high)" {
t.Errorf("Expected gemini-2.5-pro(high), got %s", result)
}
}
func TestModelMapper_Regex_ExactPrecedence(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
defer reg.UnregisterClient("test-client-regex-2")
defer reg.UnregisterClient("test-client-regex-3")
mappings := []config.AmpModelMapping{
{From: "gpt-5", To: "claude-sonnet-4"}, // exact
{From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex
}
mapper := NewModelMapper(mappings)
// Exact match should win over regex
result := mapper.MapModel("gpt-5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}
func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) {
// Invalid regex should be skipped and not cause panic
mappings := []config.AmpModelMapping{
{From: "(", To: "target", Regex: true},
}
mapper := NewModelMapper(mappings)
result := mapper.MapModel("anything")
if result != "" {
t.Errorf("Expected empty result due to invalid regex, got %s", result)
}
}
func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
defer reg.UnregisterClient("test-client-regex-4")
mappings := []config.AmpModelMapping{
{From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true},
}
mapper := NewModelMapper(mappings)
result := mapper.MapModel("claude-opus-4.5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}
func TestModelMapper_SuffixPreservation(t *testing.T) {
reg := registry.GetGlobalRegistry()
// Register test models
reg.RegisterClient("test-client-suffix", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
reg.RegisterClient("test-client-suffix-2", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
defer reg.UnregisterClient("test-client-suffix")
defer reg.UnregisterClient("test-client-suffix-2")
tests := []struct {
name string
mappings []config.AmpModelMapping
input string
want string
}{
{
name: "numeric suffix preserved",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p(8192)",
want: "gemini-2.5-pro(8192)",
},
{
name: "level suffix preserved",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p(high)",
want: "gemini-2.5-pro(high)",
},
{
name: "no suffix unchanged",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p",
want: "gemini-2.5-pro",
},
{
name: "config suffix takes priority",
mappings: []config.AmpModelMapping{{From: "alias", To: "gemini-2.5-pro(medium)"}},
input: "alias(high)",
want: "gemini-2.5-pro(medium)",
},
{
name: "regex with suffix preserved",
mappings: []config.AmpModelMapping{{From: "^g25.*", To: "gemini-2.5-pro", Regex: true}},
input: "g25p(8192)",
want: "gemini-2.5-pro(8192)",
},
{
name: "auto suffix preserved",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p(auto)",
want: "gemini-2.5-pro(auto)",
},
{
name: "none suffix preserved",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p(none)",
want: "gemini-2.5-pro(none)",
},
{
name: "case insensitive base lookup with suffix",
mappings: []config.AmpModelMapping{{From: "G25P", To: "gemini-2.5-pro"}},
input: "g25p(high)",
want: "gemini-2.5-pro(high)",
},
{
name: "empty suffix filtered out",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p()",
want: "gemini-2.5-pro",
},
{
name: "incomplete suffix treated as no suffix",
mappings: []config.AmpModelMapping{{From: "g25p(high", To: "gemini-2.5-pro"}},
input: "g25p(high",
want: "gemini-2.5-pro",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mapper := NewModelMapper(tt.mappings)
got := mapper.MapModel(tt.input)
if got != tt.want {
t.Errorf("MapModel(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}

View File

@@ -18,6 +18,33 @@ import (
log "github.com/sirupsen/logrus"
)
func removeQueryValuesMatching(req *http.Request, key string, match string) {
if req == nil || req.URL == nil || match == "" {
return
}
q := req.URL.Query()
values, ok := q[key]
if !ok || len(values) == 0 {
return
}
kept := make([]string, 0, len(values))
for _, v := range values {
if v == match {
continue
}
kept = append(kept, v)
}
if len(kept) == 0 {
q.Del(key)
} else {
q[key] = kept
}
req.URL.RawQuery = q.Encode()
}
// readCloser wraps a reader and forwards Close to a separate closer.
// Used to restore peeked bytes while preserving upstream body Close behavior.
type readCloser struct {
@@ -48,6 +75,14 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
// We will set our own Authorization using the configured upstream-api-key
req.Header.Del("Authorization")
req.Header.Del("X-Api-Key")
req.Header.Del("X-Goog-Api-Key")
// Remove query-based credentials if they match the authenticated client API key.
// This prevents leaking client auth material to the Amp upstream while avoiding
// breaking unrelated upstream query parameters.
clientKey := getClientAPIKeyFromContext(req.Context())
removeQueryValuesMatching(req, "key", clientKey)
removeQueryValuesMatching(req, "auth_token", clientKey)
// Preserve correlation headers for debugging
if req.Header.Get("X-Request-ID") == "" {

View File

@@ -3,11 +3,15 @@ package amp
import (
"bytes"
"compress/gzip"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// Helper: compress data with gzip
@@ -306,6 +310,159 @@ func TestReverseProxy_EmptySecret(t *testing.T) {
}
}
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
type captured struct {
headers http.Header
query string
}
got := make(chan captured, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate clientAPIKeyMiddleware injection (per-request)
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Authorization", "Bearer client-key")
req.Header.Set("X-Api-Key", "client-key")
req.Header.Set("X-Goog-Api-Key", "client-key")
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
c := <-got
// These are client-provided credentials and must not reach the upstream.
if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
}
// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
t.Fatalf("Authorization should be upstream-injected, got: %q", v)
}
if v := c.headers.Get("X-Api-Key"); v != "upstream" {
t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
}
// Query-based credentials should be stripped only when they match the authenticated client key.
// Should keep unrelated values and parameters.
if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
}
if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
}
}
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate clientAPIKeyMiddleware injection (per-request)
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "u1" {
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer u1" {
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "default" {
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer default" {
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_ErrorHandler(t *testing.T) {
// Point proxy to a non-routable address to trigger error
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))

View File

@@ -125,7 +125,30 @@ func (rw *ResponseRewriter) Flush() {
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
// The Amp client struggles when both thinking and tool_use blocks are present
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
if filtered.Exists() {
originalCount := gjson.GetBytes(data, "content.#").Int()
filteredCount := filtered.Get("#").Int()
if originalCount > filteredCount {
var err error
data, err = sjson.SetBytes(data, "content", filtered.Value())
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
} else {
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
// Log the result for verification
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
}
}
}
}
if rw.originalModel == "" {
return data
}

View File

@@ -1,6 +1,7 @@
package amp
import (
"context"
"errors"
"net"
"net/http"
@@ -16,6 +17,37 @@ import (
log "github.com/sirupsen/logrus"
)
// clientAPIKeyContextKey is the context key used to pass the client API key
// from gin.Context to the request context for SecretSource lookup.
type clientAPIKeyContextKey struct{}
// clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["apiKey"]
// into the request context so that SecretSource can look it up for per-client upstream routing.
func clientAPIKeyMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Extract the client API key from gin context (set by AuthMiddleware)
if apiKey, exists := c.Get("apiKey"); exists {
if keyStr, ok := apiKey.(string); ok && keyStr != "" {
// Inject into request context for SecretSource.Get(ctx) to read
ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr)
c.Request = c.Request.WithContext(ctx)
}
}
c.Next()
}
}
// getClientAPIKeyFromContext retrieves the client API key from request context.
// Returns empty string if not present.
func getClientAPIKeyFromContext(ctx context.Context) string {
if val := ctx.Value(clientAPIKeyContextKey{}); val != nil {
if keyStr, ok := val.(string); ok {
return keyStr
}
}
return ""
}
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
@@ -129,6 +161,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
}
// Inject client API key into request context for per-client upstream routing
ampAPI.Use(clientAPIKeyMiddleware())
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
proxyHandler := func(c *gin.Context) {
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
@@ -175,6 +210,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
if authWithBypass != nil {
rootMiddleware = append(rootMiddleware, authWithBypass)
}
// Add clientAPIKeyMiddleware after auth for per-client upstream routing
rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
@@ -244,6 +281,8 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
if auth != nil {
ampProviders.Use(auth)
}
// Inject client API key into request context for per-client upstream routing
ampProviders.Use(clientAPIKeyMiddleware())
provider := ampProviders.Group("/:provider")

View File

@@ -9,6 +9,9 @@ import (
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
)
// SecretSource provides Amp API keys with configurable precedence and caching
@@ -164,3 +167,82 @@ func NewStaticSecretSource(key string) *StaticSecretSource {
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
return s.key, nil
}
// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping.
// When a request context contains a client API key that matches a configured mapping,
// the corresponding upstream key is returned. Otherwise, falls back to the default source.
type MappedSecretSource struct {
defaultSource SecretSource
mu sync.RWMutex
lookup map[string]string // clientKey -> upstreamKey
}
// NewMappedSecretSource creates a MappedSecretSource wrapping the given default source.
func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
return &MappedSecretSource{
defaultSource: defaultSource,
lookup: make(map[string]string),
}
}
// Get retrieves the Amp API key, checking per-client mappings first.
// If the request context contains a client API key that matches a configured mapping,
// returns the corresponding upstream key. Otherwise, falls back to the default source.
func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
// Try to get client API key from request context
clientKey := getClientAPIKeyFromContext(ctx)
if clientKey != "" {
s.mu.RLock()
if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" {
s.mu.RUnlock()
return upstreamKey, nil
}
s.mu.RUnlock()
}
// Fall back to default source
return s.defaultSource.Get(ctx)
}
// UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries.
// If the same client key appears in multiple entries, logs a warning and uses the first one.
func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
newLookup := make(map[string]string)
for _, entry := range entries {
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
if upstreamKey == "" {
continue
}
for _, clientKey := range entry.APIKeys {
trimmedKey := strings.TrimSpace(clientKey)
if trimmedKey == "" {
continue
}
if _, exists := newLookup[trimmedKey]; exists {
// Log warning for duplicate client key, first one wins
log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
continue
}
newLookup[trimmedKey] = upstreamKey
}
}
s.mu.Lock()
s.lookup = newLookup
s.mu.Unlock()
}
// UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable).
func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(key)
}
}
// InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable).
func (s *MappedSecretSource) InvalidateCache() {
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
ms.InvalidateCache()
}
}

View File

@@ -8,6 +8,10 @@ import (
"sync"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
)
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
@@ -278,3 +282,85 @@ func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
}
}
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
got, err := s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "u1" {
t.Fatalf("want u1, got %q", got)
}
ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2")
got, err = s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "default" {
t.Fatalf("want default fallback, got %q", got)
}
}
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
{
UpstreamAPIKey: "u2",
APIKeys: []string{"k1"},
},
})
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
got, err := s.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "u1" {
t.Fatalf("want u1 (first wins), got %q", got)
}
}
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
hook := test.NewLocal(log.StandardLogger())
defer hook.Reset()
defaultSource := NewStaticSecretSource("default")
s := NewMappedSecretSource(defaultSource)
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
{
UpstreamAPIKey: "u2",
APIKeys: []string{"k1"},
},
})
foundWarning := false
for _, entry := range hook.AllEntries() {
if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
foundWarning = true
break
}
}
if !foundWarning {
t.Fatal("expected warning log for duplicate client key, but none was found")
}
}

View File

@@ -12,6 +12,7 @@ import (
"net/http"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"sync/atomic"
@@ -23,9 +24,11 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
@@ -33,6 +36,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
@@ -209,13 +213,15 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
// Resolve logs directory relative to the configuration file directory.
var requestLogger logging.RequestLogger
var toggle func(bool)
if optionState.requestLoggerFactory != nil {
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
}
if requestLogger != nil {
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
toggle = setter.SetEnabled
if !cfg.CommercialMode {
if optionState.requestLoggerFactory != nil {
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
}
if requestLogger != nil {
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
toggle = setter.SetEnabled
}
}
}
@@ -251,15 +257,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
}
managementasset.SetCurrentConfig(cfg)
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
misc.SetCodexInstructionsEnabled(cfg.CodexInstructionsEnabled)
// Initialize management handler
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
if optionState.localPassword != "" {
s.mgmt.SetLocalPassword(optionState.localPassword)
}
logDir := filepath.Join(s.currentPath, "logs")
if base := util.WritablePath(); base != "" {
logDir = filepath.Join(base, "logs")
}
logDir := logging.ResolveLogDirectory(cfg)
s.mgmt.SetLogDirectory(logDir)
s.localPassword = optionState.localPassword
@@ -290,6 +294,11 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
s.registerManagementRoutes()
}
// === CLIProxyAPIPlus 扩展: 注册 Kiro OAuth Web 路由 ===
kiroOAuthHandler := kiro.NewOAuthWebHandler(cfg)
kiroOAuthHandler.RegisterRoutes(engine)
log.Info("Kiro OAuth Web routes registered at /v0/oauth/kiro/*")
if optionState.keepAliveEnabled {
s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
}
@@ -323,6 +332,7 @@ func (s *Server) setupRoutes() {
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
v1.POST("/responses", openaiResponsesHandlers.Responses)
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
}
// Gemini compatible API routes
@@ -494,6 +504,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
{
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
mgmt.GET("/config", s.mgmt.GetConfig)
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
@@ -507,6 +519,10 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/logging-to-file", s.mgmt.PutLoggingToFile)
mgmt.PATCH("/logging-to-file", s.mgmt.PutLoggingToFile)
mgmt.GET("/logs-max-total-size-mb", s.mgmt.GetLogsMaxTotalSizeMB)
mgmt.PUT("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
mgmt.PATCH("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
@@ -516,6 +532,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
mgmt.POST("/api-call", s.mgmt.APICall)
mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
@@ -538,6 +556,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
@@ -564,6 +583,10 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
@@ -572,6 +595,14 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
mgmt.GET("/force-model-prefix", s.mgmt.GetForceModelPrefix)
mgmt.PUT("/force-model-prefix", s.mgmt.PutForceModelPrefix)
mgmt.PATCH("/force-model-prefix", s.mgmt.PutForceModelPrefix)
mgmt.GET("/routing/strategy", s.mgmt.GetRoutingStrategy)
mgmt.PUT("/routing/strategy", s.mgmt.PutRoutingStrategy)
mgmt.PATCH("/routing/strategy", s.mgmt.PutRoutingStrategy)
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey)
@@ -587,16 +618,28 @@ func (s *Server) registerManagementRoutes() {
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
mgmt.GET("/vertex-api-key", s.mgmt.GetVertexCompatKeys)
mgmt.PUT("/vertex-api-key", s.mgmt.PutVertexCompatKeys)
mgmt.PATCH("/vertex-api-key", s.mgmt.PatchVertexCompatKey)
mgmt.DELETE("/vertex-api-key", s.mgmt.DeleteVertexCompatKey)
mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
mgmt.GET("/oauth-model-alias", s.mgmt.GetOAuthModelAlias)
mgmt.PUT("/oauth-model-alias", s.mgmt.PutOAuthModelAlias)
mgmt.PATCH("/oauth-model-alias", s.mgmt.PatchOAuthModelAlias)
mgmt.DELETE("/oauth-model-alias", s.mgmt.DeleteOAuthModelAlias)
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
mgmt.GET("/model-definitions/:channel", s.mgmt.GetStaticModelDefinitions)
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
@@ -607,6 +650,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
}
@@ -866,7 +910,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
}
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
if err := logging.ConfigureLogOutput(cfg); err != nil {
log.Errorf("failed to reconfigure log output: %v", err)
} else {
if oldCfg == nil {
@@ -899,6 +943,16 @@ func (s *Server) UpdateClients(cfg *config.Config) {
log.Debugf("disable_cooling toggled to %t", cfg.DisableCooling)
}
}
if oldCfg == nil || oldCfg.CodexInstructionsEnabled != cfg.CodexInstructionsEnabled {
misc.SetCodexInstructionsEnabled(cfg.CodexInstructionsEnabled)
if oldCfg != nil {
log.Debugf("codex_instructions_enabled updated from %t to %t", oldCfg.CodexInstructionsEnabled, cfg.CodexInstructionsEnabled)
} else {
log.Debugf("codex_instructions_enabled toggled to %t", cfg.CodexInstructionsEnabled)
}
}
if s.handlers != nil && s.handlers.AuthManager != nil {
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
}
@@ -966,18 +1020,25 @@ func (s *Server) UpdateClients(cfg *config.Config) {
s.mgmt.SetAuthManager(s.handlers.AuthManager)
}
// Notify Amp module of config changes (for model mapping hot-reload)
if s.ampModule != nil {
log.Debugf("triggering amp module config update")
if err := s.ampModule.OnConfigUpdated(cfg); err != nil {
log.Errorf("failed to update Amp module config: %v", err)
// Notify Amp module only when Amp config has changed.
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode)
if ampConfigChanged {
if s.ampModule != nil {
log.Debugf("triggering amp module config update")
if err := s.ampModule.OnConfigUpdated(cfg); err != nil {
log.Errorf("failed to update Amp module config: %v", err)
}
} else {
log.Warnf("amp module is nil, skipping config update")
}
} else {
log.Warnf("amp module is nil, skipping config update")
}
// Count client sources from configuration and auth directory
authFiles := util.CountAuthFiles(cfg.AuthDir)
// Count client sources from configuration and auth store.
tokenStore := sdkAuth.GetTokenStore()
if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok {
dirSetter.SetBaseDir(cfg.AuthDir)
}
authEntries := util.CountAuthFiles(context.Background(), tokenStore)
geminiAPIKeyCount := len(cfg.GeminiKey)
claudeAPIKeyCount := len(cfg.ClaudeKey)
codexAPIKeyCount := len(cfg.CodexKey)
@@ -988,10 +1049,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
openAICompatCount += len(entry.APIKeyEntries)
}
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
total := authEntries + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
fmt.Printf("server clients and configuration updated: %d clients (%d auth entries + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
total,
authFiles,
authEntries,
geminiAPIKeyCount,
claudeAPIKeyCount,
codexAPIKeyCount,

View File

@@ -0,0 +1,344 @@
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
package antigravity
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// TokenResponse represents OAuth token response from Google
type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
}
// userInfo represents Google user profile
type userInfo struct {
Email string `json:"email"`
}
// AntigravityAuth handles Antigravity OAuth authentication
type AntigravityAuth struct {
httpClient *http.Client
}
// NewAntigravityAuth creates a new Antigravity auth service.
func NewAntigravityAuth(cfg *config.Config, httpClient *http.Client) *AntigravityAuth {
if httpClient != nil {
return &AntigravityAuth{httpClient: httpClient}
}
if cfg == nil {
cfg = &config.Config{}
}
return &AntigravityAuth{
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
}
}
// BuildAuthURL generates the OAuth authorization URL.
func (o *AntigravityAuth) BuildAuthURL(state, redirectURI string) string {
if strings.TrimSpace(redirectURI) == "" {
redirectURI = fmt.Sprintf("http://localhost:%d/oauth-callback", CallbackPort)
}
params := url.Values{}
params.Set("access_type", "offline")
params.Set("client_id", ClientID)
params.Set("prompt", "consent")
params.Set("redirect_uri", redirectURI)
params.Set("response_type", "code")
params.Set("scope", strings.Join(Scopes, " "))
params.Set("state", state)
return AuthEndpoint + "?" + params.Encode()
}
// ExchangeCodeForTokens exchanges authorization code for access and refresh tokens
func (o *AntigravityAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string) (*TokenResponse, error) {
data := url.Values{}
data.Set("code", code)
data.Set("client_id", ClientID)
data.Set("client_secret", ClientSecret)
data.Set("redirect_uri", redirectURI)
data.Set("grant_type", "authorization_code")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("antigravity token exchange: create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, errDo := o.httpClient.Do(req)
if errDo != nil {
return nil, fmt.Errorf("antigravity token exchange: execute request: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity token exchange: close body error: %v", errClose)
}
}()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
if errRead != nil {
return nil, fmt.Errorf("antigravity token exchange: read response: %w", errRead)
}
body := strings.TrimSpace(string(bodyBytes))
if body == "" {
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d", resp.StatusCode)
}
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d: %s", resp.StatusCode, body)
}
var token TokenResponse
if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil {
return nil, fmt.Errorf("antigravity token exchange: decode response: %w", errDecode)
}
return &token, nil
}
// FetchUserInfo retrieves user email from Google
func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
accessToken = strings.TrimSpace(accessToken)
if accessToken == "" {
return "", fmt.Errorf("antigravity userinfo: missing access token")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoEndpoint, nil)
if err != nil {
return "", fmt.Errorf("antigravity userinfo: create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, errDo := o.httpClient.Do(req)
if errDo != nil {
return "", fmt.Errorf("antigravity userinfo: execute request: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity userinfo: close body error: %v", errClose)
}
}()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
if errRead != nil {
return "", fmt.Errorf("antigravity userinfo: read response: %w", errRead)
}
body := strings.TrimSpace(string(bodyBytes))
if body == "" {
return "", fmt.Errorf("antigravity userinfo: request failed: status %d", resp.StatusCode)
}
return "", fmt.Errorf("antigravity userinfo: request failed: status %d: %s", resp.StatusCode, body)
}
var info userInfo
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
return "", fmt.Errorf("antigravity userinfo: decode response: %w", errDecode)
}
email := strings.TrimSpace(info.Email)
if email == "" {
return "", fmt.Errorf("antigravity userinfo: response missing email")
}
return email, nil
}
// FetchProjectID retrieves the project ID for the authenticated user via loadCodeAssist
func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) {
loadReqBody := map[string]any{
"metadata": map[string]string{
"ideType": "ANTIGRAVITY",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
}
rawBody, errMarshal := json.Marshal(loadReqBody)
if errMarshal != nil {
return "", fmt.Errorf("marshal request body: %w", errMarshal)
}
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", APIEndpoint, APIVersion)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", APIUserAgent)
req.Header.Set("X-Goog-Api-Client", APIClient)
req.Header.Set("Client-Metadata", ClientMetadata)
resp, errDo := o.httpClient.Do(req)
if errDo != nil {
return "", fmt.Errorf("execute request: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
}
}()
bodyBytes, errRead := io.ReadAll(resp.Body)
if errRead != nil {
return "", fmt.Errorf("read response: %w", errRead)
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
}
var loadResp map[string]any
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
return "", fmt.Errorf("decode response: %w", errDecode)
}
// Extract projectID from response
projectID := ""
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
projectID = strings.TrimSpace(id)
}
if projectID == "" {
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
if id, okID := projectMap["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
}
}
if projectID == "" {
tierID := "legacy-tier"
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
for _, rawTier := range tiers {
tier, okTier := rawTier.(map[string]any)
if !okTier {
continue
}
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
tierID = strings.TrimSpace(id)
break
}
}
}
}
projectID, err = o.OnboardUser(ctx, accessToken, tierID)
if err != nil {
return "", err
}
return projectID, nil
}
return projectID, nil
}
// OnboardUser attempts to fetch the project ID via onboardUser by polling for completion
func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
log.Infof("Antigravity: onboarding user with tier: %s", tierID)
requestBody := map[string]any{
"tierId": tierID,
"metadata": map[string]string{
"ideType": "ANTIGRAVITY",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
}
rawBody, errMarshal := json.Marshal(requestBody)
if errMarshal != nil {
return "", fmt.Errorf("marshal request body: %w", errMarshal)
}
maxAttempts := 5
for attempt := 1; attempt <= maxAttempts; attempt++ {
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
reqCtx := ctx
var cancel context.CancelFunc
if reqCtx == nil {
reqCtx = context.Background()
}
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion)
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if errRequest != nil {
cancel()
return "", fmt.Errorf("create request: %w", errRequest)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", APIUserAgent)
req.Header.Set("X-Goog-Api-Client", APIClient)
req.Header.Set("Client-Metadata", ClientMetadata)
resp, errDo := o.httpClient.Do(req)
if errDo != nil {
cancel()
return "", fmt.Errorf("execute request: %w", errDo)
}
bodyBytes, errRead := io.ReadAll(resp.Body)
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("close body error: %v", errClose)
}
cancel()
if errRead != nil {
return "", fmt.Errorf("read response: %w", errRead)
}
if resp.StatusCode == http.StatusOK {
var data map[string]any
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
return "", fmt.Errorf("decode response: %w", errDecode)
}
if done, okDone := data["done"].(bool); okDone && done {
projectID := ""
if responseData, okResp := data["response"].(map[string]any); okResp {
switch projectValue := responseData["cloudaicompanionProject"].(type) {
case map[string]any:
if id, okID := projectValue["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
case string:
projectID = strings.TrimSpace(projectValue)
}
}
if projectID != "" {
log.Infof("Successfully fetched project_id: %s", projectID)
return projectID, nil
}
return "", fmt.Errorf("no project_id in response")
}
time.Sleep(2 * time.Second)
continue
}
responsePreview := strings.TrimSpace(string(bodyBytes))
if len(responsePreview) > 500 {
responsePreview = responsePreview[:500]
}
responseErr := responsePreview
if len(responseErr) > 200 {
responseErr = responseErr[:200]
}
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
}
return "", nil
}

View File

@@ -0,0 +1,34 @@
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
package antigravity
// OAuth client credentials and configuration
const (
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
ClientSecret = "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"
CallbackPort = 51121
)
// Scopes defines the OAuth scopes required for Antigravity authentication
var Scopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/cclog",
"https://www.googleapis.com/auth/experimentsandconfigs",
}
// OAuth2 endpoints for Google authentication
const (
TokenEndpoint = "https://oauth2.googleapis.com/token"
AuthEndpoint = "https://accounts.google.com/o/oauth2/v2/auth"
UserInfoEndpoint = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json"
)
// Antigravity API configuration
const (
APIEndpoint = "https://cloudcode-pa.googleapis.com"
APIVersion = "v1internal"
APIUserAgent = "google-api-nodejs-client/9.15.1"
APIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
ClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
)

View File

@@ -0,0 +1,16 @@
package antigravity
import (
"fmt"
"strings"
)
// CredentialFileName returns the filename used to persist Antigravity credentials.
// It uses the email as a suffix to disambiguate accounts.
func CredentialFileName(email string) string {
email = strings.TrimSpace(email)
if email == "" {
return "antigravity.json"
}
return fmt.Sprintf("antigravity-%s.json", email)
}

View File

@@ -14,15 +14,15 @@ import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// OAuth configuration constants for Claude/Anthropic
const (
anthropicAuthURL = "https://claude.ai/oauth/authorize"
anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token"
anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
redirectURI = "http://localhost:54545/callback"
AuthURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
RedirectURI = "http://localhost:54545/callback"
)
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
@@ -50,7 +50,8 @@ type ClaudeAuth struct {
}
// NewClaudeAuth creates a new Anthropic authentication service.
// It initializes the HTTP client with proxy settings from the configuration.
// It initializes the HTTP client with a custom TLS transport that uses Firefox
// fingerprint to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
//
// Parameters:
// - cfg: The application configuration containing proxy settings
@@ -58,8 +59,10 @@ type ClaudeAuth struct {
// Returns:
// - *ClaudeAuth: A new Claude authentication service instance
func NewClaudeAuth(cfg *config.Config) *ClaudeAuth {
// Use custom HTTP client with Firefox TLS fingerprint to bypass
// Cloudflare's bot detection on Anthropic domains
return &ClaudeAuth{
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
httpClient: NewAnthropicHttpClient(&cfg.SDKConfig),
}
}
@@ -82,16 +85,16 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
params := url.Values{
"code": {"true"},
"client_id": {anthropicClientID},
"client_id": {ClientID},
"response_type": {"code"},
"redirect_uri": {redirectURI},
"redirect_uri": {RedirectURI},
"scope": {"org:create_api_key user:profile user:inference"},
"code_challenge": {pkceCodes.CodeChallenge},
"code_challenge_method": {"S256"},
"state": {state},
}
authURL := fmt.Sprintf("%s?%s", anthropicAuthURL, params.Encode())
authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
return authURL, state, nil
}
@@ -137,8 +140,8 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
"code": newCode,
"state": state,
"grant_type": "authorization_code",
"client_id": anthropicClientID,
"redirect_uri": redirectURI,
"client_id": ClientID,
"redirect_uri": RedirectURI,
"code_verifier": pkceCodes.CodeVerifier,
}
@@ -154,7 +157,7 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
// log.Debugf("Token exchange request: %s", string(jsonBody))
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
@@ -221,7 +224,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
}
reqBody := map[string]interface{}{
"client_id": anthropicClientID,
"client_id": ClientID,
"grant_type": "refresh_token",
"refresh_token": refreshToken,
}
@@ -231,7 +234,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
if err != nil {
return nil, fmt.Errorf("failed to create refresh request: %w", err)
}

View File

@@ -0,0 +1,165 @@
// Package claude provides authentication functionality for Anthropic's Claude API.
// This file implements a custom HTTP transport using utls to bypass TLS fingerprinting.
package claude
import (
"net/http"
"net/url"
"strings"
"sync"
tls "github.com/refraction-networking/utls"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/proxy"
)
// utlsRoundTripper implements http.RoundTripper using utls with Firefox fingerprint
// to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
type utlsRoundTripper struct {
// mu protects the connections map and pending map
mu sync.Mutex
// connections caches HTTP/2 client connections per host
connections map[string]*http2.ClientConn
// pending tracks hosts that are currently being connected to (prevents race condition)
pending map[string]*sync.Cond
// dialer is used to create network connections, supporting proxies
dialer proxy.Dialer
}
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
var dialer proxy.Dialer = proxy.Direct
if cfg != nil && cfg.ProxyURL != "" {
proxyURL, err := url.Parse(cfg.ProxyURL)
if err != nil {
log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err)
} else {
pDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
if err != nil {
log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err)
} else {
dialer = pDialer
}
}
}
return &utlsRoundTripper{
connections: make(map[string]*http2.ClientConn),
pending: make(map[string]*sync.Cond),
dialer: dialer,
}
}
// getOrCreateConnection gets an existing connection or creates a new one.
// It uses a per-host locking mechanism to prevent multiple goroutines from
// creating connections to the same host simultaneously.
func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
t.mu.Lock()
// Check if connection exists and is usable
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
t.mu.Unlock()
return h2Conn, nil
}
// Check if another goroutine is already creating a connection
if cond, ok := t.pending[host]; ok {
// Wait for the other goroutine to finish
cond.Wait()
// Check if connection is now available
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
t.mu.Unlock()
return h2Conn, nil
}
// Connection still not available, we'll create one
}
// Mark this host as pending
cond := sync.NewCond(&t.mu)
t.pending[host] = cond
t.mu.Unlock()
// Create connection outside the lock
h2Conn, err := t.createConnection(host, addr)
t.mu.Lock()
defer t.mu.Unlock()
// Remove pending marker and wake up waiting goroutines
delete(t.pending, host)
cond.Broadcast()
if err != nil {
return nil, err
}
// Store the new connection
t.connections[host] = h2Conn
return h2Conn, nil
}
// createConnection creates a new HTTP/2 connection with Firefox TLS fingerprint
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
conn, err := t.dialer.Dial("tcp", addr)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{ServerName: host}
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloFirefox_Auto)
if err := tlsConn.Handshake(); err != nil {
conn.Close()
return nil, err
}
tr := &http2.Transport{}
h2Conn, err := tr.NewClientConn(tlsConn)
if err != nil {
tlsConn.Close()
return nil, err
}
return h2Conn, nil
}
// RoundTrip implements http.RoundTripper
func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
host := req.URL.Host
addr := host
if !strings.Contains(addr, ":") {
addr += ":443"
}
// Get hostname without port for TLS ServerName
hostname := req.URL.Hostname()
h2Conn, err := t.getOrCreateConnection(hostname, addr)
if err != nil {
return nil, err
}
resp, err := h2Conn.RoundTrip(req)
if err != nil {
// Connection failed, remove it from cache
t.mu.Lock()
if cached, ok := t.connections[hostname]; ok && cached == h2Conn {
delete(t.connections, hostname)
}
t.mu.Unlock()
return nil, err
}
return resp, nil
}
// NewAnthropicHttpClient creates an HTTP client that bypasses TLS fingerprinting
// for Anthropic domains by using utls with Firefox fingerprint.
// It accepts optional SDK configuration for proxy settings.
func NewAnthropicHttpClient(cfg *config.SDKConfig) *http.Client {
return &http.Client{
Transport: newUtlsRoundTripper(cfg),
}
}

View File

@@ -0,0 +1,46 @@
package codex
import (
"fmt"
"strings"
"unicode"
)
// CredentialFileName returns the filename used to persist Codex OAuth credentials.
// When planType is available (e.g. "plus", "team"), it is appended after the email
// as a suffix to disambiguate subscriptions.
func CredentialFileName(email, planType, hashAccountID string, includeProviderPrefix bool) string {
email = strings.TrimSpace(email)
plan := normalizePlanTypeForFilename(planType)
prefix := ""
if includeProviderPrefix {
prefix = "codex"
}
if plan == "" {
return fmt.Sprintf("%s-%s.json", prefix, email)
} else if plan == "team" {
return fmt.Sprintf("%s-%s-%s-%s.json", prefix, hashAccountID, email, plan)
}
return fmt.Sprintf("%s-%s-%s.json", prefix, email, plan)
}
func normalizePlanTypeForFilename(planType string) string {
planType = strings.TrimSpace(planType)
if planType == "" {
return ""
}
parts := strings.FieldsFunc(planType, func(r rune) bool {
return !unicode.IsLetter(r) && !unicode.IsDigit(r)
})
if len(parts) == 0 {
return ""
}
for i, part := range parts {
parts[i] = strings.ToLower(strings.TrimSpace(part))
}
return strings.Join(parts, "-")
}

View File

@@ -19,11 +19,12 @@ import (
log "github.com/sirupsen/logrus"
)
// OAuth configuration constants for OpenAI Codex
const (
openaiAuthURL = "https://auth.openai.com/oauth/authorize"
openaiTokenURL = "https://auth.openai.com/oauth/token"
openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
redirectURI = "http://localhost:1455/auth/callback"
AuthURL = "https://auth.openai.com/oauth/authorize"
TokenURL = "https://auth.openai.com/oauth/token"
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
RedirectURI = "http://localhost:1455/auth/callback"
)
// CodexAuth handles the OpenAI OAuth2 authentication flow.
@@ -50,9 +51,9 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
}
params := url.Values{
"client_id": {openaiClientID},
"client_id": {ClientID},
"response_type": {"code"},
"redirect_uri": {redirectURI},
"redirect_uri": {RedirectURI},
"scope": {"openid email profile offline_access"},
"state": {state},
"code_challenge": {pkceCodes.CodeChallenge},
@@ -62,7 +63,7 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
"codex_cli_simplified_flow": {"true"},
}
authURL := fmt.Sprintf("%s?%s", openaiAuthURL, params.Encode())
authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
return authURL, nil
}
@@ -77,13 +78,13 @@ func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkce
// Prepare token exchange request
data := url.Values{
"grant_type": {"authorization_code"},
"client_id": {openaiClientID},
"client_id": {ClientID},
"code": {code},
"redirect_uri": {redirectURI},
"redirect_uri": {RedirectURI},
"code_verifier": {pkceCodes.CodeVerifier},
}
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
@@ -163,13 +164,13 @@ func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*Co
}
data := url.Values{
"client_id": {openaiClientID},
"client_id": {ClientID},
"grant_type": {"refresh_token"},
"refresh_token": {refreshToken},
"scope": {"openid profile email"},
}
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create refresh request: %w", err)
}

View File

@@ -28,18 +28,19 @@ import (
"golang.org/x/oauth2/google"
)
// OAuth configuration constants for Gemini
const (
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOauthClientSecret = "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"
ClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
ClientSecret = "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"
DefaultCallbackPort = 8085
)
var (
geminiOauthScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
)
// OAuth scopes for Gemini authentication
var Scopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
// GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow.
// It encapsulates the logic for obtaining, storing, and refreshing authentication tokens
@@ -49,8 +50,9 @@ type GeminiAuth struct {
// WebLoginOptions customizes the interactive OAuth flow.
type WebLoginOptions struct {
NoBrowser bool
Prompt func(string) (string, error)
NoBrowser bool
CallbackPort int
Prompt func(string) (string, error)
}
// NewGeminiAuth creates a new instance of GeminiAuth.
@@ -72,6 +74,12 @@ func NewGeminiAuth() *GeminiAuth {
// - *http.Client: An HTTP client configured with authentication
// - error: An error if the client configuration fails, nil otherwise
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
callbackPort := DefaultCallbackPort
if opts != nil && opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
// Configure proxy settings for the HTTP client if a proxy URL is provided.
proxyURL, err := url.Parse(cfg.ProxyURL)
if err == nil {
@@ -104,10 +112,10 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
// Configure the OAuth2 client.
conf := &oauth2.Config{
ClientID: geminiOauthClientID,
ClientSecret: geminiOauthClientSecret,
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
Scopes: geminiOauthScopes,
ClientID: ClientID,
ClientSecret: ClientSecret,
RedirectURL: callbackURL, // This will be used by the local server.
Scopes: Scopes,
Endpoint: google.Endpoint,
}
@@ -190,9 +198,9 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
}
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
ifToken["client_id"] = geminiOauthClientID
ifToken["client_secret"] = geminiOauthClientSecret
ifToken["scopes"] = geminiOauthScopes
ifToken["client_id"] = ClientID
ifToken["client_secret"] = ClientSecret
ifToken["scopes"] = Scopes
ifToken["universe_domain"] = "googleapis.com"
ts := GeminiTokenStorage{
@@ -218,14 +226,20 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
// - error: An error if the token acquisition fails, nil otherwise
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
callbackPort := DefaultCallbackPort
if opts != nil && opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
// Use a channel to pass the authorization code from the HTTP handler to the main function.
codeChan := make(chan string, 1)
errChan := make(chan error, 1)
// Create a new HTTP server with its own multiplexer.
mux := http.NewServeMux()
server := &http.Server{Addr: ":8085", Handler: mux}
config.RedirectURL = "http://localhost:8085/oauth2callback"
server := &http.Server{Addr: fmt.Sprintf(":%d", callbackPort), Handler: mux}
config.RedirectURL = callbackURL
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
if err := r.URL.Query().Get("error"); err != "" {
@@ -277,13 +291,13 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Check if browser is available
if !browser.IsAvailable() {
log.Warn("No browser available on this system")
util.PrintSSHTunnelInstructions(8085)
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
} else {
if err := browser.OpenURL(authURL); err != nil {
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
log.Warn(codex.GetUserFriendlyMessage(authErr))
util.PrintSSHTunnelInstructions(8085)
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
// Log platform info for debugging
@@ -294,7 +308,7 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
}
}
} else {
util.PrintSSHTunnelInstructions(8085)
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL)
}

View File

@@ -5,10 +5,12 @@ package kiro
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
)
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
@@ -30,16 +32,23 @@ type KiroTokenData struct {
ProfileArn string `json:"profileArn"`
// ExpiresAt is the timestamp when the token expires
ExpiresAt string `json:"expiresAt"`
// AuthMethod indicates the authentication method used (e.g., "builder-id", "social")
// AuthMethod indicates the authentication method used (e.g., "builder-id", "social", "idc")
AuthMethod string `json:"authMethod"`
// Provider indicates the OAuth provider (e.g., "AWS", "Google")
// Provider indicates the OAuth provider (e.g., "AWS", "Google", "Enterprise")
Provider string `json:"provider"`
// ClientID is the OIDC client ID (needed for token refresh)
ClientID string `json:"clientId,omitempty"`
// ClientSecret is the OIDC client secret (needed for token refresh)
ClientSecret string `json:"clientSecret,omitempty"`
// ClientIDHash is the hash of client ID used to locate device registration file
// (Enterprise Kiro IDE stores clientId/clientSecret in ~/.aws/sso/cache/{clientIdHash}.json)
ClientIDHash string `json:"clientIdHash,omitempty"`
// Email is the user's email address (used for file naming)
Email string `json:"email,omitempty"`
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
StartURL string `json:"startUrl,omitempty"`
// Region is the AWS region for IDC authentication (only for IDC auth method)
Region string `json:"region,omitempty"`
}
// KiroAuthBundle aggregates authentication data after OAuth flow completion
@@ -81,7 +90,90 @@ type KiroModel struct {
// KiroIDETokenFile is the default path to Kiro IDE's token file
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
// Default retry configuration for file reading
const (
defaultTokenReadMaxAttempts = 10 // Maximum retry attempts
defaultTokenReadBaseDelay = 50 * time.Millisecond // Base delay between retries
)
// isTransientFileError checks if the error is a transient file access error
// that may be resolved by retrying (e.g., file locked by another process on Windows).
func isTransientFileError(err error) bool {
if err == nil {
return false
}
// Check for OS-level file access errors (Windows sharing violation, etc.)
var pathErr *os.PathError
if errors.As(err, &pathErr) {
// Windows sharing violation (ERROR_SHARING_VIOLATION = 32)
// Windows lock violation (ERROR_LOCK_VIOLATION = 33)
errStr := pathErr.Err.Error()
if strings.Contains(errStr, "being used by another process") ||
strings.Contains(errStr, "sharing violation") ||
strings.Contains(errStr, "lock violation") {
return true
}
}
// Check error message for common transient patterns
errMsg := strings.ToLower(err.Error())
transientPatterns := []string{
"being used by another process",
"sharing violation",
"lock violation",
"access is denied",
"unexpected end of json",
"unexpected eof",
}
for _, pattern := range transientPatterns {
if strings.Contains(errMsg, pattern) {
return true
}
}
return false
}
// LoadKiroIDETokenWithRetry loads token data from Kiro IDE's token file with retry logic.
// This handles transient file access errors (e.g., file locked by Kiro IDE during write).
// maxAttempts: maximum number of retry attempts (default 10 if <= 0)
// baseDelay: base delay between retries with exponential backoff (default 50ms if <= 0)
func LoadKiroIDETokenWithRetry(maxAttempts int, baseDelay time.Duration) (*KiroTokenData, error) {
if maxAttempts <= 0 {
maxAttempts = defaultTokenReadMaxAttempts
}
if baseDelay <= 0 {
baseDelay = defaultTokenReadBaseDelay
}
var lastErr error
for attempt := 0; attempt < maxAttempts; attempt++ {
token, err := LoadKiroIDEToken()
if err == nil {
return token, nil
}
lastErr = err
// Only retry for transient errors
if !isTransientFileError(err) {
return nil, err
}
// Exponential backoff: delay * 2^attempt, capped at 500ms
delay := baseDelay * time.Duration(1<<uint(attempt))
if delay > 500*time.Millisecond {
delay = 500 * time.Millisecond
}
time.Sleep(delay)
}
return nil, fmt.Errorf("failed to read token file after %d attempts: %w", maxAttempts, lastErr)
}
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
// For Enterprise Kiro IDE (IDC auth), it also loads clientId and clientSecret
// from the device registration file referenced by clientIdHash.
func LoadKiroIDEToken() (*KiroTokenData, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
@@ -103,18 +195,72 @@ func LoadKiroIDEToken() (*KiroTokenData, error) {
return nil, fmt.Errorf("access token is empty in Kiro IDE token file")
}
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
token.AuthMethod = strings.ToLower(token.AuthMethod)
// For Enterprise Kiro IDE (IDC auth), load clientId and clientSecret from device registration
// The device registration file is located at ~/.aws/sso/cache/{clientIdHash}.json
if token.ClientIDHash != "" && token.ClientID == "" {
if err := loadDeviceRegistration(homeDir, token.ClientIDHash, &token); err != nil {
// Log warning but don't fail - token might still work for some operations
fmt.Printf("warning: failed to load device registration for clientIdHash %s: %v\n", token.ClientIDHash, err)
}
}
return &token, nil
}
// loadDeviceRegistration loads clientId and clientSecret from the device registration file.
// Enterprise Kiro IDE stores these in ~/.aws/sso/cache/{clientIdHash}.json
func loadDeviceRegistration(homeDir, clientIDHash string, token *KiroTokenData) error {
if clientIDHash == "" {
return fmt.Errorf("clientIdHash is empty")
}
// Sanitize clientIdHash to prevent path traversal
if strings.Contains(clientIDHash, "/") || strings.Contains(clientIDHash, "\\") || strings.Contains(clientIDHash, "..") {
return fmt.Errorf("invalid clientIdHash: contains path separator")
}
deviceRegPath := filepath.Join(homeDir, ".aws", "sso", "cache", clientIDHash+".json")
data, err := os.ReadFile(deviceRegPath)
if err != nil {
return fmt.Errorf("failed to read device registration file (%s): %w", deviceRegPath, err)
}
// Device registration file structure
var deviceReg struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
ExpiresAt string `json:"expiresAt"`
}
if err := json.Unmarshal(data, &deviceReg); err != nil {
return fmt.Errorf("failed to parse device registration: %w", err)
}
if deviceReg.ClientID == "" || deviceReg.ClientSecret == "" {
return fmt.Errorf("device registration missing clientId or clientSecret")
}
token.ClientID = deviceReg.ClientID
token.ClientSecret = deviceReg.ClientSecret
return nil
}
// LoadKiroTokenFromPath loads token data from a custom path.
// This supports multiple accounts by allowing different token files.
// For Enterprise Kiro IDE (IDC auth), it also loads clientId and clientSecret
// from the device registration file referenced by clientIdHash.
func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get home directory: %w", err)
}
// Expand ~ to home directory
if len(tokenPath) > 0 && tokenPath[0] == '~' {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get home directory: %w", err)
}
tokenPath = filepath.Join(homeDir, tokenPath[1:])
}
@@ -132,6 +278,17 @@ func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
return nil, fmt.Errorf("access token is empty in token file")
}
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
token.AuthMethod = strings.ToLower(token.AuthMethod)
// For Enterprise Kiro IDE (IDC auth), load clientId and clientSecret from device registration
if token.ClientIDHash != "" && token.ClientID == "" {
if err := loadDeviceRegistration(homeDir, token.ClientIDHash, &token); err != nil {
// Log warning but don't fail - token might still work for some operations
fmt.Printf("warning: failed to load device registration for clientIdHash %s: %v\n", token.ClientIDHash, err)
}
}
return &token, nil
}
@@ -267,7 +424,7 @@ func SanitizeEmailForFilename(email string) string {
}
result := email
// First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.)
// This prevents encoded characters from bypassing the sanitization.
// Note: We replace % last to catch any remaining encodings including double-encoding (%252F)
@@ -285,7 +442,7 @@ func SanitizeEmailForFilename(email string) string {
for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} {
result = strings.ReplaceAll(result, char, "_")
}
// Prevent path traversal: replace leading dots in each path component
// This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd"
parts := strings.Split(result, "_")
@@ -296,6 +453,65 @@ func SanitizeEmailForFilename(email string) string {
parts[i] = part
}
result = strings.Join(parts, "_")
return result
}
// ExtractIDCIdentifier extracts a unique identifier from IDC startUrl.
// Examples:
// - "https://d-1234567890.awsapps.com/start" -> "d-1234567890"
// - "https://my-company.awsapps.com/start" -> "my-company"
// - "https://acme-corp.awsapps.com/start" -> "acme-corp"
func ExtractIDCIdentifier(startURL string) string {
if startURL == "" {
return ""
}
// Remove protocol prefix
url := strings.TrimPrefix(startURL, "https://")
url = strings.TrimPrefix(url, "http://")
// Extract subdomain (first part before the first dot)
// Format: {identifier}.awsapps.com/start
parts := strings.Split(url, ".")
if len(parts) > 0 && parts[0] != "" {
identifier := parts[0]
// Sanitize for filename safety
identifier = strings.ReplaceAll(identifier, "/", "_")
identifier = strings.ReplaceAll(identifier, "\\", "_")
identifier = strings.ReplaceAll(identifier, ":", "_")
return identifier
}
return ""
}
// GenerateTokenFileName generates a unique filename for token storage.
// Priority: email > startUrl identifier (for IDC) > authMethod only
// Format: kiro-{authMethod}-{identifier}.json
func GenerateTokenFileName(tokenData *KiroTokenData) string {
authMethod := tokenData.AuthMethod
if authMethod == "" {
authMethod = "unknown"
}
// Priority 1: Use email if available
if tokenData.Email != "" {
// Sanitize email for filename (replace @ and . with -)
sanitizedEmail := tokenData.Email
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, "@", "-")
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
return fmt.Sprintf("kiro-%s-%s.json", authMethod, sanitizedEmail)
}
// Priority 2: For IDC, use startUrl identifier
if authMethod == "idc" && tokenData.StartURL != "" {
identifier := ExtractIDCIdentifier(tokenData.StartURL)
if identifier != "" {
return fmt.Sprintf("kiro-%s-%s.json", authMethod, identifier)
}
}
// Priority 3: Fallback to authMethod only
return fmt.Sprintf("kiro-%s.json", authMethod)
}

View File

@@ -280,6 +280,11 @@ func (k *KiroAuth) CreateTokenStorage(tokenData *KiroTokenData) *KiroTokenStorag
AuthMethod: tokenData.AuthMethod,
Provider: tokenData.Provider,
LastRefresh: time.Now().Format(time.RFC3339),
ClientID: tokenData.ClientID,
ClientSecret: tokenData.ClientSecret,
Region: tokenData.Region,
StartURL: tokenData.StartURL,
Email: tokenData.Email,
}
}
@@ -311,4 +316,19 @@ func (k *KiroAuth) UpdateTokenStorage(storage *KiroTokenStorage, tokenData *Kiro
storage.AuthMethod = tokenData.AuthMethod
storage.Provider = tokenData.Provider
storage.LastRefresh = time.Now().Format(time.RFC3339)
if tokenData.ClientID != "" {
storage.ClientID = tokenData.ClientID
}
if tokenData.ClientSecret != "" {
storage.ClientSecret = tokenData.ClientSecret
}
if tokenData.Region != "" {
storage.Region = tokenData.Region
}
if tokenData.StartURL != "" {
storage.StartURL = tokenData.StartURL
}
if tokenData.Email != "" {
storage.Email = tokenData.Email
}
}

View File

@@ -151,11 +151,161 @@ func TestSanitizeEmailForFilename(t *testing.T) {
// createTestJWT creates a test JWT token with the given claims
func createTestJWT(claims map[string]any) string {
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
payloadBytes, _ := json.Marshal(claims)
payload := base64.RawURLEncoding.EncodeToString(payloadBytes)
signature := base64.RawURLEncoding.EncodeToString([]byte("fake-signature"))
return header + "." + payload + "." + signature
}
func TestExtractIDCIdentifier(t *testing.T) {
tests := []struct {
name string
startURL string
expected string
}{
{
name: "Empty URL",
startURL: "",
expected: "",
},
{
name: "Standard IDC URL with d- prefix",
startURL: "https://d-1234567890.awsapps.com/start",
expected: "d-1234567890",
},
{
name: "IDC URL with company name",
startURL: "https://my-company.awsapps.com/start",
expected: "my-company",
},
{
name: "IDC URL with simple name",
startURL: "https://acme-corp.awsapps.com/start",
expected: "acme-corp",
},
{
name: "IDC URL without https",
startURL: "http://d-9876543210.awsapps.com/start",
expected: "d-9876543210",
},
{
name: "IDC URL with subdomain only",
startURL: "https://test.awsapps.com/start",
expected: "test",
},
{
name: "Builder ID URL",
startURL: "https://view.awsapps.com/start",
expected: "view",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ExtractIDCIdentifier(tt.startURL)
if result != tt.expected {
t.Errorf("ExtractIDCIdentifier() = %q, want %q", result, tt.expected)
}
})
}
}
func TestGenerateTokenFileName(t *testing.T) {
tests := []struct {
name string
tokenData *KiroTokenData
expected string
}{
{
name: "IDC with email",
tokenData: &KiroTokenData{
AuthMethod: "idc",
Email: "user@example.com",
StartURL: "https://d-1234567890.awsapps.com/start",
},
expected: "kiro-idc-user-example-com.json",
},
{
name: "IDC without email but with startUrl",
tokenData: &KiroTokenData{
AuthMethod: "idc",
Email: "",
StartURL: "https://d-1234567890.awsapps.com/start",
},
expected: "kiro-idc-d-1234567890.json",
},
{
name: "IDC with company name in startUrl",
tokenData: &KiroTokenData{
AuthMethod: "idc",
Email: "",
StartURL: "https://my-company.awsapps.com/start",
},
expected: "kiro-idc-my-company.json",
},
{
name: "IDC without email and without startUrl",
tokenData: &KiroTokenData{
AuthMethod: "idc",
Email: "",
StartURL: "",
},
expected: "kiro-idc.json",
},
{
name: "Builder ID with email",
tokenData: &KiroTokenData{
AuthMethod: "builder-id",
Email: "user@gmail.com",
StartURL: "https://view.awsapps.com/start",
},
expected: "kiro-builder-id-user-gmail-com.json",
},
{
name: "Builder ID without email",
tokenData: &KiroTokenData{
AuthMethod: "builder-id",
Email: "",
StartURL: "https://view.awsapps.com/start",
},
expected: "kiro-builder-id.json",
},
{
name: "Social auth with email",
tokenData: &KiroTokenData{
AuthMethod: "google",
Email: "user@gmail.com",
},
expected: "kiro-google-user-gmail-com.json",
},
{
name: "Empty auth method",
tokenData: &KiroTokenData{
AuthMethod: "",
Email: "",
},
expected: "kiro-unknown.json",
},
{
name: "Email with special characters",
tokenData: &KiroTokenData{
AuthMethod: "idc",
Email: "user.name+tag@sub.example.com",
StartURL: "https://d-1234567890.awsapps.com/start",
},
expected: "kiro-idc-user-name+tag-sub-example-com.json",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenerateTokenFileName(tt.tokenData)
if result != tt.expected {
t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.expected)
}
})
}
}

View File

@@ -0,0 +1,247 @@
package kiro
import (
"context"
"log"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"golang.org/x/sync/semaphore"
)
type Token struct {
ID string
AccessToken string
RefreshToken string
ExpiresAt time.Time
LastVerified time.Time
ClientID string
ClientSecret string
AuthMethod string
Provider string
StartURL string
Region string
}
type TokenRepository interface {
FindOldestUnverified(limit int) []*Token
UpdateToken(token *Token) error
}
type RefresherOption func(*BackgroundRefresher)
func WithInterval(interval time.Duration) RefresherOption {
return func(r *BackgroundRefresher) {
r.interval = interval
}
}
func WithBatchSize(size int) RefresherOption {
return func(r *BackgroundRefresher) {
r.batchSize = size
}
}
func WithConcurrency(concurrency int) RefresherOption {
return func(r *BackgroundRefresher) {
r.concurrency = concurrency
}
}
type BackgroundRefresher struct {
interval time.Duration
batchSize int
concurrency int
tokenRepo TokenRepository
stopCh chan struct{}
wg sync.WaitGroup
oauth *KiroOAuth
ssoClient *SSOOIDCClient
callbackMu sync.RWMutex // 保护回调函数的并发访问
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
}
func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher {
r := &BackgroundRefresher{
interval: time.Minute,
batchSize: 50,
concurrency: 10,
tokenRepo: repo,
stopCh: make(chan struct{}),
oauth: nil, // Lazy init - will be set when config available
ssoClient: nil, // Lazy init - will be set when config available
}
for _, opt := range opts {
opt(r)
}
return r
}
// WithConfig sets the configuration for OAuth and SSO clients.
func WithConfig(cfg *config.Config) RefresherOption {
return func(r *BackgroundRefresher) {
r.oauth = NewKiroOAuth(cfg)
r.ssoClient = NewSSOOIDCClient(cfg)
}
}
// WithOnTokenRefreshed sets the callback function to be called when a token is successfully refreshed.
// The callback receives the token ID (filename) and the new token data.
// This allows external components (e.g., Watcher) to be notified of token updates.
func WithOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) RefresherOption {
return func(r *BackgroundRefresher) {
r.callbackMu.Lock()
r.onTokenRefreshed = callback
r.callbackMu.Unlock()
}
}
func (r *BackgroundRefresher) Start(ctx context.Context) {
r.wg.Add(1)
go func() {
defer r.wg.Done()
ticker := time.NewTicker(r.interval)
defer ticker.Stop()
r.refreshBatch(ctx)
for {
select {
case <-ctx.Done():
return
case <-r.stopCh:
return
case <-ticker.C:
r.refreshBatch(ctx)
}
}
}()
}
func (r *BackgroundRefresher) Stop() {
close(r.stopCh)
r.wg.Wait()
}
func (r *BackgroundRefresher) refreshBatch(ctx context.Context) {
tokens := r.tokenRepo.FindOldestUnverified(r.batchSize)
if len(tokens) == 0 {
return
}
sem := semaphore.NewWeighted(int64(r.concurrency))
var wg sync.WaitGroup
for i, token := range tokens {
if i > 0 {
select {
case <-ctx.Done():
return
case <-r.stopCh:
return
case <-time.After(100 * time.Millisecond):
}
}
if err := sem.Acquire(ctx, 1); err != nil {
return
}
wg.Add(1)
go func(t *Token) {
defer wg.Done()
defer sem.Release(1)
r.refreshSingle(ctx, t)
}(token)
}
wg.Wait()
}
func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) {
// Normalize auth method to lowercase for case-insensitive matching
authMethod := strings.ToLower(token.AuthMethod)
// Create refresh function based on auth method
refreshFunc := func(ctx context.Context) (*KiroTokenData, error) {
switch authMethod {
case "idc":
return r.ssoClient.RefreshTokenWithRegion(
ctx,
token.ClientID,
token.ClientSecret,
token.RefreshToken,
token.Region,
token.StartURL,
)
case "builder-id":
return r.ssoClient.RefreshToken(
ctx,
token.ClientID,
token.ClientSecret,
token.RefreshToken,
)
default:
return r.oauth.RefreshTokenWithFingerprint(ctx, token.RefreshToken, token.ID)
}
}
// Use graceful degradation for better reliability
result := RefreshWithGracefulDegradation(
ctx,
refreshFunc,
token.AccessToken,
token.ExpiresAt,
)
if result.Error != nil {
log.Printf("failed to refresh token %s: %v", token.ID, result.Error)
return
}
newTokenData := result.TokenData
if result.UsedFallback {
log.Printf("token %s: using existing token as fallback (refresh failed but token still valid)", token.ID)
// Don't update the token file if we're using fallback
// Just update LastVerified to prevent immediate re-check
token.LastVerified = time.Now()
return
}
token.AccessToken = newTokenData.AccessToken
if newTokenData.RefreshToken != "" {
token.RefreshToken = newTokenData.RefreshToken
}
token.LastVerified = time.Now()
if newTokenData.ExpiresAt != "" {
if expTime, parseErr := time.Parse(time.RFC3339, newTokenData.ExpiresAt); parseErr == nil {
token.ExpiresAt = expTime
}
}
if err := r.tokenRepo.UpdateToken(token); err != nil {
log.Printf("failed to update token %s: %v", token.ID, err)
return
}
// 方案 A: 刷新成功后触发回调,通知 Watcher 更新内存中的 Auth 对象
r.callbackMu.RLock()
callback := r.onTokenRefreshed
r.callbackMu.RUnlock()
if callback != nil {
// 使用 defer recover 隔离回调 panic防止崩溃整个进程
func() {
defer func() {
if rec := recover(); rec != nil {
log.Printf("background refresh: callback panic for token %s: %v", token.ID, rec)
}
}()
log.Printf("background refresh: notifying token refresh callback for %s", token.ID)
callback(token.ID, newTokenData)
}()
}
}

View File

@@ -0,0 +1,112 @@
package kiro
import (
"sync"
"time"
)
const (
CooldownReason429 = "rate_limit_exceeded"
CooldownReasonSuspended = "account_suspended"
CooldownReasonQuotaExhausted = "quota_exhausted"
DefaultShortCooldown = 1 * time.Minute
MaxShortCooldown = 5 * time.Minute
LongCooldown = 24 * time.Hour
)
type CooldownManager struct {
mu sync.RWMutex
cooldowns map[string]time.Time
reasons map[string]string
}
func NewCooldownManager() *CooldownManager {
return &CooldownManager{
cooldowns: make(map[string]time.Time),
reasons: make(map[string]string),
}
}
func (cm *CooldownManager) SetCooldown(tokenKey string, duration time.Duration, reason string) {
cm.mu.Lock()
defer cm.mu.Unlock()
cm.cooldowns[tokenKey] = time.Now().Add(duration)
cm.reasons[tokenKey] = reason
}
func (cm *CooldownManager) IsInCooldown(tokenKey string) bool {
cm.mu.RLock()
defer cm.mu.RUnlock()
endTime, exists := cm.cooldowns[tokenKey]
if !exists {
return false
}
return time.Now().Before(endTime)
}
func (cm *CooldownManager) GetRemainingCooldown(tokenKey string) time.Duration {
cm.mu.RLock()
defer cm.mu.RUnlock()
endTime, exists := cm.cooldowns[tokenKey]
if !exists {
return 0
}
remaining := time.Until(endTime)
if remaining < 0 {
return 0
}
return remaining
}
func (cm *CooldownManager) GetCooldownReason(tokenKey string) string {
cm.mu.RLock()
defer cm.mu.RUnlock()
return cm.reasons[tokenKey]
}
func (cm *CooldownManager) ClearCooldown(tokenKey string) {
cm.mu.Lock()
defer cm.mu.Unlock()
delete(cm.cooldowns, tokenKey)
delete(cm.reasons, tokenKey)
}
func (cm *CooldownManager) CleanupExpired() {
cm.mu.Lock()
defer cm.mu.Unlock()
now := time.Now()
for tokenKey, endTime := range cm.cooldowns {
if now.After(endTime) {
delete(cm.cooldowns, tokenKey)
delete(cm.reasons, tokenKey)
}
}
}
func (cm *CooldownManager) StartCleanupRoutine(interval time.Duration, stopCh <-chan struct{}) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
cm.CleanupExpired()
case <-stopCh:
return
}
}
}
func CalculateCooldownFor429(retryCount int) time.Duration {
duration := DefaultShortCooldown * time.Duration(1<<retryCount)
if duration > MaxShortCooldown {
return MaxShortCooldown
}
return duration
}
func CalculateCooldownUntilNextDay() time.Duration {
now := time.Now()
nextDay := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
return time.Until(nextDay)
}

View File

@@ -0,0 +1,240 @@
package kiro
import (
"sync"
"testing"
"time"
)
func TestNewCooldownManager(t *testing.T) {
cm := NewCooldownManager()
if cm == nil {
t.Fatal("expected non-nil CooldownManager")
}
if cm.cooldowns == nil {
t.Error("expected non-nil cooldowns map")
}
if cm.reasons == nil {
t.Error("expected non-nil reasons map")
}
}
func TestSetCooldown(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
if !cm.IsInCooldown("token1") {
t.Error("expected token to be in cooldown")
}
if cm.GetCooldownReason("token1") != CooldownReason429 {
t.Errorf("expected reason %s, got %s", CooldownReason429, cm.GetCooldownReason("token1"))
}
}
func TestIsInCooldown_NotSet(t *testing.T) {
cm := NewCooldownManager()
if cm.IsInCooldown("nonexistent") {
t.Error("expected non-existent token to not be in cooldown")
}
}
func TestIsInCooldown_Expired(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
time.Sleep(10 * time.Millisecond)
if cm.IsInCooldown("token1") {
t.Error("expected expired cooldown to return false")
}
}
func TestGetRemainingCooldown(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Second, CooldownReason429)
remaining := cm.GetRemainingCooldown("token1")
if remaining <= 0 || remaining > 1*time.Second {
t.Errorf("expected remaining cooldown between 0 and 1s, got %v", remaining)
}
}
func TestGetRemainingCooldown_NotSet(t *testing.T) {
cm := NewCooldownManager()
remaining := cm.GetRemainingCooldown("nonexistent")
if remaining != 0 {
t.Errorf("expected 0 remaining for non-existent, got %v", remaining)
}
}
func TestGetRemainingCooldown_Expired(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
time.Sleep(10 * time.Millisecond)
remaining := cm.GetRemainingCooldown("token1")
if remaining != 0 {
t.Errorf("expected 0 remaining for expired, got %v", remaining)
}
}
func TestGetCooldownReason(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
reason := cm.GetCooldownReason("token1")
if reason != CooldownReasonSuspended {
t.Errorf("expected reason %s, got %s", CooldownReasonSuspended, reason)
}
}
func TestGetCooldownReason_NotSet(t *testing.T) {
cm := NewCooldownManager()
reason := cm.GetCooldownReason("nonexistent")
if reason != "" {
t.Errorf("expected empty reason for non-existent, got %s", reason)
}
}
func TestClearCooldown(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
cm.ClearCooldown("token1")
if cm.IsInCooldown("token1") {
t.Error("expected cooldown to be cleared")
}
if cm.GetCooldownReason("token1") != "" {
t.Error("expected reason to be cleared")
}
}
func TestClearCooldown_NonExistent(t *testing.T) {
cm := NewCooldownManager()
cm.ClearCooldown("nonexistent")
}
func TestCleanupExpired(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("expired1", 1*time.Millisecond, CooldownReason429)
cm.SetCooldown("expired2", 1*time.Millisecond, CooldownReason429)
cm.SetCooldown("active", 1*time.Hour, CooldownReason429)
time.Sleep(10 * time.Millisecond)
cm.CleanupExpired()
if cm.GetCooldownReason("expired1") != "" {
t.Error("expected expired1 to be cleaned up")
}
if cm.GetCooldownReason("expired2") != "" {
t.Error("expected expired2 to be cleaned up")
}
if cm.GetCooldownReason("active") != CooldownReason429 {
t.Error("expected active to remain")
}
}
func TestCalculateCooldownFor429_FirstRetry(t *testing.T) {
duration := CalculateCooldownFor429(0)
if duration != DefaultShortCooldown {
t.Errorf("expected %v for retry 0, got %v", DefaultShortCooldown, duration)
}
}
func TestCalculateCooldownFor429_Exponential(t *testing.T) {
d1 := CalculateCooldownFor429(1)
d2 := CalculateCooldownFor429(2)
if d2 <= d1 {
t.Errorf("expected d2 > d1, got d1=%v, d2=%v", d1, d2)
}
}
func TestCalculateCooldownFor429_MaxCap(t *testing.T) {
duration := CalculateCooldownFor429(10)
if duration > MaxShortCooldown {
t.Errorf("expected max %v, got %v", MaxShortCooldown, duration)
}
}
func TestCalculateCooldownUntilNextDay(t *testing.T) {
duration := CalculateCooldownUntilNextDay()
if duration <= 0 || duration > 24*time.Hour {
t.Errorf("expected duration between 0 and 24h, got %v", duration)
}
}
func TestCooldownManager_ConcurrentAccess(t *testing.T) {
cm := NewCooldownManager()
const numGoroutines = 50
const numOperations = 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
tokenKey := "token" + string(rune('a'+id%10))
for j := 0; j < numOperations; j++ {
switch j % 6 {
case 0:
cm.SetCooldown(tokenKey, time.Duration(j)*time.Millisecond, CooldownReason429)
case 1:
cm.IsInCooldown(tokenKey)
case 2:
cm.GetRemainingCooldown(tokenKey)
case 3:
cm.GetCooldownReason(tokenKey)
case 4:
cm.ClearCooldown(tokenKey)
case 5:
cm.CleanupExpired()
}
}
}(i)
}
wg.Wait()
}
func TestCooldownReasonConstants(t *testing.T) {
if CooldownReason429 != "rate_limit_exceeded" {
t.Errorf("unexpected CooldownReason429: %s", CooldownReason429)
}
if CooldownReasonSuspended != "account_suspended" {
t.Errorf("unexpected CooldownReasonSuspended: %s", CooldownReasonSuspended)
}
if CooldownReasonQuotaExhausted != "quota_exhausted" {
t.Errorf("unexpected CooldownReasonQuotaExhausted: %s", CooldownReasonQuotaExhausted)
}
}
func TestDefaultConstants(t *testing.T) {
if DefaultShortCooldown != 1*time.Minute {
t.Errorf("unexpected DefaultShortCooldown: %v", DefaultShortCooldown)
}
if MaxShortCooldown != 5*time.Minute {
t.Errorf("unexpected MaxShortCooldown: %v", MaxShortCooldown)
}
if LongCooldown != 24*time.Hour {
t.Errorf("unexpected LongCooldown: %v", LongCooldown)
}
}
func TestSetCooldown_OverwritesPrevious(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Hour, CooldownReason429)
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
reason := cm.GetCooldownReason("token1")
if reason != CooldownReasonSuspended {
t.Errorf("expected reason to be overwritten to %s, got %s", CooldownReasonSuspended, reason)
}
remaining := cm.GetRemainingCooldown("token1")
if remaining > 1*time.Minute {
t.Errorf("expected remaining <= 1 minute, got %v", remaining)
}
}

View File

@@ -0,0 +1,197 @@
package kiro
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"math/rand"
"net/http"
"sync"
"time"
)
// Fingerprint 多维度指纹信息
type Fingerprint struct {
SDKVersion string // 1.0.20-1.0.27
OSType string // darwin/windows/linux
OSVersion string // 10.0.22621
NodeVersion string // 18.x/20.x/22.x
KiroVersion string // 0.3.x-0.8.x
KiroHash string // SHA256
AcceptLanguage string
ScreenResolution string // 1920x1080
ColorDepth int // 24
HardwareConcurrency int // CPU 核心数
TimezoneOffset int
}
// FingerprintManager 指纹管理器
type FingerprintManager struct {
mu sync.RWMutex
fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
rng *rand.Rand
}
var (
sdkVersions = []string{
"1.0.20", "1.0.21", "1.0.22", "1.0.23",
"1.0.24", "1.0.25", "1.0.26", "1.0.27",
}
osTypes = []string{"darwin", "windows", "linux"}
osVersions = map[string][]string{
"darwin": {"14.0", "14.1", "14.2", "14.3", "14.4", "14.5", "15.0", "15.1"},
"windows": {"10.0.19041", "10.0.19042", "10.0.19043", "10.0.19044", "10.0.22621", "10.0.22631"},
"linux": {"5.15.0", "6.1.0", "6.2.0", "6.5.0", "6.6.0", "6.8.0"},
}
nodeVersions = []string{
"18.17.0", "18.18.0", "18.19.0", "18.20.0",
"20.9.0", "20.10.0", "20.11.0", "20.12.0", "20.13.0",
"22.0.0", "22.1.0", "22.2.0", "22.3.0",
}
kiroVersions = []string{
"0.3.0", "0.3.1", "0.4.0", "0.4.1", "0.5.0", "0.5.1",
"0.6.0", "0.6.1", "0.7.0", "0.7.1", "0.8.0", "0.8.1",
}
acceptLanguages = []string{
"en-US,en;q=0.9",
"en-GB,en;q=0.9",
"zh-CN,zh;q=0.9,en;q=0.8",
"zh-TW,zh;q=0.9,en;q=0.8",
"ja-JP,ja;q=0.9,en;q=0.8",
"ko-KR,ko;q=0.9,en;q=0.8",
"de-DE,de;q=0.9,en;q=0.8",
"fr-FR,fr;q=0.9,en;q=0.8",
}
screenResolutions = []string{
"1920x1080", "2560x1440", "3840x2160",
"1366x768", "1440x900", "1680x1050",
"2560x1600", "3440x1440",
}
colorDepths = []int{24, 32}
hardwareConcurrencies = []int{4, 6, 8, 10, 12, 16, 20, 24, 32}
timezoneOffsets = []int{-480, -420, -360, -300, -240, 0, 60, 120, 480, 540}
)
// NewFingerprintManager 创建指纹管理器
func NewFingerprintManager() *FingerprintManager {
return &FingerprintManager{
fingerprints: make(map[string]*Fingerprint),
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
// GetFingerprint 获取或生成 Token 关联的指纹
func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
fm.mu.RLock()
if fp, exists := fm.fingerprints[tokenKey]; exists {
fm.mu.RUnlock()
return fp
}
fm.mu.RUnlock()
fm.mu.Lock()
defer fm.mu.Unlock()
if fp, exists := fm.fingerprints[tokenKey]; exists {
return fp
}
fp := fm.generateFingerprint(tokenKey)
fm.fingerprints[tokenKey] = fp
return fp
}
// generateFingerprint 生成新的指纹
func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
osType := fm.randomChoice(osTypes)
osVersion := fm.randomChoice(osVersions[osType])
kiroVersion := fm.randomChoice(kiroVersions)
fp := &Fingerprint{
SDKVersion: fm.randomChoice(sdkVersions),
OSType: osType,
OSVersion: osVersion,
NodeVersion: fm.randomChoice(nodeVersions),
KiroVersion: kiroVersion,
AcceptLanguage: fm.randomChoice(acceptLanguages),
ScreenResolution: fm.randomChoice(screenResolutions),
ColorDepth: fm.randomIntChoice(colorDepths),
HardwareConcurrency: fm.randomIntChoice(hardwareConcurrencies),
TimezoneOffset: fm.randomIntChoice(timezoneOffsets),
}
fp.KiroHash = fm.generateKiroHash(tokenKey, kiroVersion, osType)
return fp
}
// generateKiroHash 生成 Kiro Hash
func (fm *FingerprintManager) generateKiroHash(tokenKey, kiroVersion, osType string) string {
data := fmt.Sprintf("%s:%s:%s:%d", tokenKey, kiroVersion, osType, time.Now().UnixNano())
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
}
// randomChoice 随机选择字符串
func (fm *FingerprintManager) randomChoice(choices []string) string {
return choices[fm.rng.Intn(len(choices))]
}
// randomIntChoice 随机选择整数
func (fm *FingerprintManager) randomIntChoice(choices []int) int {
return choices[fm.rng.Intn(len(choices))]
}
// ApplyToRequest 将指纹信息应用到 HTTP 请求头
func (fp *Fingerprint) ApplyToRequest(req *http.Request) {
req.Header.Set("X-Kiro-SDK-Version", fp.SDKVersion)
req.Header.Set("X-Kiro-OS-Type", fp.OSType)
req.Header.Set("X-Kiro-OS-Version", fp.OSVersion)
req.Header.Set("X-Kiro-Node-Version", fp.NodeVersion)
req.Header.Set("X-Kiro-Version", fp.KiroVersion)
req.Header.Set("X-Kiro-Hash", fp.KiroHash)
req.Header.Set("Accept-Language", fp.AcceptLanguage)
req.Header.Set("X-Screen-Resolution", fp.ScreenResolution)
req.Header.Set("X-Color-Depth", fmt.Sprintf("%d", fp.ColorDepth))
req.Header.Set("X-Hardware-Concurrency", fmt.Sprintf("%d", fp.HardwareConcurrency))
req.Header.Set("X-Timezone-Offset", fmt.Sprintf("%d", fp.TimezoneOffset))
}
// RemoveFingerprint 移除 Token 关联的指纹
func (fm *FingerprintManager) RemoveFingerprint(tokenKey string) {
fm.mu.Lock()
defer fm.mu.Unlock()
delete(fm.fingerprints, tokenKey)
}
// Count 返回当前管理的指纹数量
func (fm *FingerprintManager) Count() int {
fm.mu.RLock()
defer fm.mu.RUnlock()
return len(fm.fingerprints)
}
// BuildUserAgent 构建 User-Agent 字符串 (Kiro IDE 风格)
// 格式: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
func (fp *Fingerprint) BuildUserAgent() string {
return fmt.Sprintf(
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
fp.SDKVersion,
fp.OSType,
fp.OSVersion,
fp.NodeVersion,
fp.SDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}
// BuildAmzUserAgent 构建 X-Amz-User-Agent 字符串
// 格式: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
func (fp *Fingerprint) BuildAmzUserAgent() string {
return fmt.Sprintf(
"aws-sdk-js/%s KiroIDE-%s-%s",
fp.SDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}

View File

@@ -0,0 +1,227 @@
package kiro
import (
"net/http"
"sync"
"testing"
)
func TestNewFingerprintManager(t *testing.T) {
fm := NewFingerprintManager()
if fm == nil {
t.Fatal("expected non-nil FingerprintManager")
}
if fm.fingerprints == nil {
t.Error("expected non-nil fingerprints map")
}
if fm.rng == nil {
t.Error("expected non-nil rng")
}
}
func TestGetFingerprint_NewToken(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
if fp == nil {
t.Fatal("expected non-nil Fingerprint")
}
if fp.SDKVersion == "" {
t.Error("expected non-empty SDKVersion")
}
if fp.OSType == "" {
t.Error("expected non-empty OSType")
}
if fp.OSVersion == "" {
t.Error("expected non-empty OSVersion")
}
if fp.NodeVersion == "" {
t.Error("expected non-empty NodeVersion")
}
if fp.KiroVersion == "" {
t.Error("expected non-empty KiroVersion")
}
if fp.KiroHash == "" {
t.Error("expected non-empty KiroHash")
}
if fp.AcceptLanguage == "" {
t.Error("expected non-empty AcceptLanguage")
}
if fp.ScreenResolution == "" {
t.Error("expected non-empty ScreenResolution")
}
if fp.ColorDepth == 0 {
t.Error("expected non-zero ColorDepth")
}
if fp.HardwareConcurrency == 0 {
t.Error("expected non-zero HardwareConcurrency")
}
}
func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
fm := NewFingerprintManager()
fp1 := fm.GetFingerprint("token1")
fp2 := fm.GetFingerprint("token1")
if fp1 != fp2 {
t.Error("expected same fingerprint for same token")
}
}
func TestGetFingerprint_DifferentTokens(t *testing.T) {
fm := NewFingerprintManager()
fp1 := fm.GetFingerprint("token1")
fp2 := fm.GetFingerprint("token2")
if fp1 == fp2 {
t.Error("expected different fingerprints for different tokens")
}
}
func TestRemoveFingerprint(t *testing.T) {
fm := NewFingerprintManager()
fm.GetFingerprint("token1")
if fm.Count() != 1 {
t.Fatalf("expected count 1, got %d", fm.Count())
}
fm.RemoveFingerprint("token1")
if fm.Count() != 0 {
t.Errorf("expected count 0, got %d", fm.Count())
}
}
func TestRemoveFingerprint_NonExistent(t *testing.T) {
fm := NewFingerprintManager()
fm.RemoveFingerprint("nonexistent")
if fm.Count() != 0 {
t.Errorf("expected count 0, got %d", fm.Count())
}
}
func TestCount(t *testing.T) {
fm := NewFingerprintManager()
if fm.Count() != 0 {
t.Errorf("expected count 0, got %d", fm.Count())
}
fm.GetFingerprint("token1")
fm.GetFingerprint("token2")
fm.GetFingerprint("token3")
if fm.Count() != 3 {
t.Errorf("expected count 3, got %d", fm.Count())
}
}
func TestApplyToRequest(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
req, _ := http.NewRequest("GET", "http://example.com", nil)
fp.ApplyToRequest(req)
if req.Header.Get("X-Kiro-SDK-Version") != fp.SDKVersion {
t.Error("X-Kiro-SDK-Version header mismatch")
}
if req.Header.Get("X-Kiro-OS-Type") != fp.OSType {
t.Error("X-Kiro-OS-Type header mismatch")
}
if req.Header.Get("X-Kiro-OS-Version") != fp.OSVersion {
t.Error("X-Kiro-OS-Version header mismatch")
}
if req.Header.Get("X-Kiro-Node-Version") != fp.NodeVersion {
t.Error("X-Kiro-Node-Version header mismatch")
}
if req.Header.Get("X-Kiro-Version") != fp.KiroVersion {
t.Error("X-Kiro-Version header mismatch")
}
if req.Header.Get("X-Kiro-Hash") != fp.KiroHash {
t.Error("X-Kiro-Hash header mismatch")
}
if req.Header.Get("Accept-Language") != fp.AcceptLanguage {
t.Error("Accept-Language header mismatch")
}
if req.Header.Get("X-Screen-Resolution") != fp.ScreenResolution {
t.Error("X-Screen-Resolution header mismatch")
}
}
func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) {
fm := NewFingerprintManager()
for i := 0; i < 20; i++ {
fp := fm.GetFingerprint("token" + string(rune('a'+i)))
validVersions := osVersions[fp.OSType]
found := false
for _, v := range validVersions {
if v == fp.OSVersion {
found = true
break
}
}
if !found {
t.Errorf("OS version %s not valid for OS type %s", fp.OSVersion, fp.OSType)
}
}
}
func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
fm := NewFingerprintManager()
const numGoroutines = 100
const numOperations = 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
tokenKey := "token" + string(rune('a'+id%26))
switch j % 4 {
case 0:
fm.GetFingerprint(tokenKey)
case 1:
fm.Count()
case 2:
fp := fm.GetFingerprint(tokenKey)
req, _ := http.NewRequest("GET", "http://example.com", nil)
fp.ApplyToRequest(req)
case 3:
fm.RemoveFingerprint(tokenKey)
}
}
}(i)
}
wg.Wait()
}
func TestKiroHashUniqueness(t *testing.T) {
fm := NewFingerprintManager()
hashes := make(map[string]bool)
for i := 0; i < 100; i++ {
fp := fm.GetFingerprint("token" + string(rune(i)))
if hashes[fp.KiroHash] {
t.Errorf("duplicate KiroHash detected: %s", fp.KiroHash)
}
hashes[fp.KiroHash] = true
}
}
func TestKiroHashFormat(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
if len(fp.KiroHash) != 64 {
t.Errorf("expected KiroHash length 64 (SHA256 hex), got %d", len(fp.KiroHash))
}
for _, c := range fp.KiroHash {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("invalid hex character in KiroHash: %c", c)
}
}
}

View File

@@ -0,0 +1,174 @@
package kiro
import (
"math/rand"
"sync"
"time"
)
// Jitter configuration constants
const (
// JitterPercent is the default percentage of jitter to apply (±30%)
JitterPercent = 0.30
// Human-like delay ranges
ShortDelayMin = 50 * time.Millisecond // Minimum for rapid consecutive operations
ShortDelayMax = 200 * time.Millisecond // Maximum for rapid consecutive operations
NormalDelayMin = 1 * time.Second // Minimum for normal thinking time
NormalDelayMax = 3 * time.Second // Maximum for normal thinking time
LongDelayMin = 5 * time.Second // Minimum for reading/resting
LongDelayMax = 10 * time.Second // Maximum for reading/resting
// Probability thresholds for human-like behavior
ShortDelayProbability = 0.20 // 20% chance of short delay (consecutive ops)
LongDelayProbability = 0.05 // 5% chance of long delay (reading/resting)
NormalDelayProbability = 0.75 // 75% chance of normal delay (thinking)
)
var (
jitterRand *rand.Rand
jitterRandOnce sync.Once
jitterMu sync.Mutex
lastRequestTime time.Time
)
// initJitterRand initializes the random number generator for jitter calculations.
// Uses a time-based seed for unpredictable but reproducible randomness.
func initJitterRand() {
jitterRandOnce.Do(func() {
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
})
}
// RandomDelay generates a random delay between min and max duration.
// Thread-safe implementation using mutex protection.
func RandomDelay(min, max time.Duration) time.Duration {
initJitterRand()
jitterMu.Lock()
defer jitterMu.Unlock()
if min >= max {
return min
}
rangeMs := max.Milliseconds() - min.Milliseconds()
randomMs := jitterRand.Int63n(rangeMs)
return min + time.Duration(randomMs)*time.Millisecond
}
// JitterDelay adds jitter to a base delay.
// Applies ±jitterPercent variation to the base delay.
// For example, JitterDelay(1*time.Second, 0.30) returns a value between 700ms and 1300ms.
func JitterDelay(baseDelay time.Duration, jitterPercent float64) time.Duration {
initJitterRand()
jitterMu.Lock()
defer jitterMu.Unlock()
if jitterPercent <= 0 || jitterPercent > 1 {
jitterPercent = JitterPercent
}
// Calculate jitter range: base * jitterPercent
jitterRange := float64(baseDelay) * jitterPercent
// Generate random value in range [-jitterRange, +jitterRange]
jitter := (jitterRand.Float64()*2 - 1) * jitterRange
result := time.Duration(float64(baseDelay) + jitter)
if result < 0 {
return 0
}
return result
}
// JitterDelayDefault applies the default ±30% jitter to a base delay.
func JitterDelayDefault(baseDelay time.Duration) time.Duration {
return JitterDelay(baseDelay, JitterPercent)
}
// HumanLikeDelay generates a delay that mimics human behavior patterns.
// The delay is selected based on probability distribution:
// - 20% chance: Short delay (50-200ms) - simulates consecutive rapid operations
// - 75% chance: Normal delay (1-3s) - simulates thinking/reading time
// - 5% chance: Long delay (5-10s) - simulates breaks/reading longer content
//
// Returns the delay duration (caller should call time.Sleep with this value).
func HumanLikeDelay() time.Duration {
initJitterRand()
jitterMu.Lock()
defer jitterMu.Unlock()
// Track time since last request for adaptive behavior
now := time.Now()
timeSinceLastRequest := now.Sub(lastRequestTime)
lastRequestTime = now
// If requests are very close together, use short delay
if timeSinceLastRequest < 500*time.Millisecond && timeSinceLastRequest > 0 {
rangeMs := ShortDelayMax.Milliseconds() - ShortDelayMin.Milliseconds()
randomMs := jitterRand.Int63n(rangeMs)
return ShortDelayMin + time.Duration(randomMs)*time.Millisecond
}
// Otherwise, use probability-based selection
roll := jitterRand.Float64()
var min, max time.Duration
switch {
case roll < ShortDelayProbability:
// Short delay - consecutive operations
min, max = ShortDelayMin, ShortDelayMax
case roll < ShortDelayProbability+LongDelayProbability:
// Long delay - reading/resting
min, max = LongDelayMin, LongDelayMax
default:
// Normal delay - thinking time
min, max = NormalDelayMin, NormalDelayMax
}
rangeMs := max.Milliseconds() - min.Milliseconds()
randomMs := jitterRand.Int63n(rangeMs)
return min + time.Duration(randomMs)*time.Millisecond
}
// ApplyHumanLikeDelay applies human-like delay by sleeping.
// This is a convenience function that combines HumanLikeDelay with time.Sleep.
func ApplyHumanLikeDelay() {
delay := HumanLikeDelay()
if delay > 0 {
time.Sleep(delay)
}
}
// ExponentialBackoffWithJitter calculates retry delay using exponential backoff with jitter.
// Formula: min(baseDelay * 2^attempt + jitter, maxDelay)
// This helps prevent thundering herd problem when multiple clients retry simultaneously.
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
if attempt < 0 {
attempt = 0
}
// Calculate exponential backoff: baseDelay * 2^attempt
backoff := baseDelay * time.Duration(1<<uint(attempt))
if backoff > maxDelay {
backoff = maxDelay
}
// Add ±30% jitter
return JitterDelay(backoff, JitterPercent)
}
// ShouldSkipDelay determines if delay should be skipped based on context.
// Returns true for streaming responses, WebSocket connections, etc.
// This function can be extended to check additional skip conditions.
func ShouldSkipDelay(isStreaming bool) bool {
return isStreaming
}
// ResetLastRequestTime resets the last request time tracker.
// Useful for testing or when starting a new session.
func ResetLastRequestTime() {
jitterMu.Lock()
defer jitterMu.Unlock()
lastRequestTime = time.Time{}
}

View File

@@ -0,0 +1,187 @@
package kiro
import (
"math"
"sync"
"time"
)
// TokenMetrics holds performance metrics for a single token.
type TokenMetrics struct {
SuccessRate float64 // Success rate (0.0 - 1.0)
AvgLatency float64 // Average latency in milliseconds
QuotaRemaining float64 // Remaining quota (0.0 - 1.0)
LastUsed time.Time // Last usage timestamp
FailCount int // Consecutive failure count
TotalRequests int // Total request count
successCount int // Internal: successful request count
totalLatency float64 // Internal: cumulative latency
}
// TokenScorer manages token metrics and scoring.
type TokenScorer struct {
mu sync.RWMutex
metrics map[string]*TokenMetrics
// Scoring weights
successRateWeight float64
quotaWeight float64
latencyWeight float64
lastUsedWeight float64
failPenaltyMultiplier float64
}
// NewTokenScorer creates a new TokenScorer with default weights.
func NewTokenScorer() *TokenScorer {
return &TokenScorer{
metrics: make(map[string]*TokenMetrics),
successRateWeight: 0.4,
quotaWeight: 0.25,
latencyWeight: 0.2,
lastUsedWeight: 0.15,
failPenaltyMultiplier: 0.1,
}
}
// getOrCreateMetrics returns existing metrics or creates new ones.
func (s *TokenScorer) getOrCreateMetrics(tokenKey string) *TokenMetrics {
if m, ok := s.metrics[tokenKey]; ok {
return m
}
m := &TokenMetrics{
SuccessRate: 1.0,
QuotaRemaining: 1.0,
}
s.metrics[tokenKey] = m
return m
}
// RecordRequest records the result of a request for a token.
func (s *TokenScorer) RecordRequest(tokenKey string, success bool, latency time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
m := s.getOrCreateMetrics(tokenKey)
m.TotalRequests++
m.LastUsed = time.Now()
m.totalLatency += float64(latency.Milliseconds())
if success {
m.successCount++
m.FailCount = 0
} else {
m.FailCount++
}
// Update derived metrics
if m.TotalRequests > 0 {
m.SuccessRate = float64(m.successCount) / float64(m.TotalRequests)
m.AvgLatency = m.totalLatency / float64(m.TotalRequests)
}
}
// SetQuotaRemaining updates the remaining quota for a token.
func (s *TokenScorer) SetQuotaRemaining(tokenKey string, quota float64) {
s.mu.Lock()
defer s.mu.Unlock()
m := s.getOrCreateMetrics(tokenKey)
m.QuotaRemaining = quota
}
// GetMetrics returns a copy of the metrics for a token.
func (s *TokenScorer) GetMetrics(tokenKey string) *TokenMetrics {
s.mu.RLock()
defer s.mu.RUnlock()
if m, ok := s.metrics[tokenKey]; ok {
copy := *m
return &copy
}
return nil
}
// CalculateScore computes the score for a token (higher is better).
func (s *TokenScorer) CalculateScore(tokenKey string) float64 {
s.mu.RLock()
defer s.mu.RUnlock()
m, ok := s.metrics[tokenKey]
if !ok {
return 1.0 // New tokens get a high initial score
}
// Success rate component (0-1)
successScore := m.SuccessRate
// Quota component (0-1)
quotaScore := m.QuotaRemaining
// Latency component (normalized, lower is better)
// Using exponential decay: score = e^(-latency/1000)
// 1000ms latency -> ~0.37 score, 100ms -> ~0.90 score
latencyScore := math.Exp(-m.AvgLatency / 1000.0)
if m.TotalRequests == 0 {
latencyScore = 1.0
}
// Last used component (prefer tokens not recently used)
// Score increases as time since last use increases
timeSinceUse := time.Since(m.LastUsed).Seconds()
// Normalize: 60 seconds -> ~0.63 score, 0 seconds -> 0 score
lastUsedScore := 1.0 - math.Exp(-timeSinceUse/60.0)
if m.LastUsed.IsZero() {
lastUsedScore = 1.0
}
// Calculate weighted score
score := s.successRateWeight*successScore +
s.quotaWeight*quotaScore +
s.latencyWeight*latencyScore +
s.lastUsedWeight*lastUsedScore
// Apply consecutive failure penalty
if m.FailCount > 0 {
penalty := s.failPenaltyMultiplier * float64(m.FailCount)
score = score * math.Max(0, 1.0-penalty)
}
return score
}
// SelectBestToken selects the token with the highest score.
func (s *TokenScorer) SelectBestToken(tokens []string) string {
if len(tokens) == 0 {
return ""
}
if len(tokens) == 1 {
return tokens[0]
}
bestToken := tokens[0]
bestScore := s.CalculateScore(tokens[0])
for _, token := range tokens[1:] {
score := s.CalculateScore(token)
if score > bestScore {
bestScore = score
bestToken = token
}
}
return bestToken
}
// ResetMetrics clears all metrics for a token.
func (s *TokenScorer) ResetMetrics(tokenKey string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.metrics, tokenKey)
}
// ResetAllMetrics clears all stored metrics.
func (s *TokenScorer) ResetAllMetrics() {
s.mu.Lock()
defer s.mu.Unlock()
s.metrics = make(map[string]*TokenMetrics)
}

View File

@@ -0,0 +1,301 @@
package kiro
import (
"sync"
"testing"
"time"
)
func TestNewTokenScorer(t *testing.T) {
s := NewTokenScorer()
if s == nil {
t.Fatal("expected non-nil TokenScorer")
}
if s.metrics == nil {
t.Error("expected non-nil metrics map")
}
if s.successRateWeight != 0.4 {
t.Errorf("expected successRateWeight 0.4, got %f", s.successRateWeight)
}
if s.quotaWeight != 0.25 {
t.Errorf("expected quotaWeight 0.25, got %f", s.quotaWeight)
}
}
func TestRecordRequest_Success(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m == nil {
t.Fatal("expected non-nil metrics")
}
if m.TotalRequests != 1 {
t.Errorf("expected TotalRequests 1, got %d", m.TotalRequests)
}
if m.SuccessRate != 1.0 {
t.Errorf("expected SuccessRate 1.0, got %f", m.SuccessRate)
}
if m.FailCount != 0 {
t.Errorf("expected FailCount 0, got %d", m.FailCount)
}
if m.AvgLatency != 100 {
t.Errorf("expected AvgLatency 100, got %f", m.AvgLatency)
}
}
func TestRecordRequest_Failure(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", false, 200*time.Millisecond)
m := s.GetMetrics("token1")
if m.SuccessRate != 0.0 {
t.Errorf("expected SuccessRate 0.0, got %f", m.SuccessRate)
}
if m.FailCount != 1 {
t.Errorf("expected FailCount 1, got %d", m.FailCount)
}
}
func TestRecordRequest_MixedResults(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
s.RecordRequest("token1", true, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m.TotalRequests != 4 {
t.Errorf("expected TotalRequests 4, got %d", m.TotalRequests)
}
if m.SuccessRate != 0.75 {
t.Errorf("expected SuccessRate 0.75, got %f", m.SuccessRate)
}
if m.FailCount != 0 {
t.Errorf("expected FailCount 0 (reset on success), got %d", m.FailCount)
}
}
func TestRecordRequest_ConsecutiveFailures(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m.FailCount != 3 {
t.Errorf("expected FailCount 3, got %d", m.FailCount)
}
}
func TestSetQuotaRemaining(t *testing.T) {
s := NewTokenScorer()
s.SetQuotaRemaining("token1", 0.5)
m := s.GetMetrics("token1")
if m.QuotaRemaining != 0.5 {
t.Errorf("expected QuotaRemaining 0.5, got %f", m.QuotaRemaining)
}
}
func TestGetMetrics_NonExistent(t *testing.T) {
s := NewTokenScorer()
m := s.GetMetrics("nonexistent")
if m != nil {
t.Error("expected nil metrics for non-existent token")
}
}
func TestGetMetrics_ReturnsCopy(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
m1 := s.GetMetrics("token1")
m1.TotalRequests = 999
m2 := s.GetMetrics("token1")
if m2.TotalRequests == 999 {
t.Error("GetMetrics should return a copy")
}
}
func TestCalculateScore_NewToken(t *testing.T) {
s := NewTokenScorer()
score := s.CalculateScore("newtoken")
if score != 1.0 {
t.Errorf("expected score 1.0 for new token, got %f", score)
}
}
func TestCalculateScore_PerfectToken(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 50*time.Millisecond)
s.SetQuotaRemaining("token1", 1.0)
time.Sleep(100 * time.Millisecond)
score := s.CalculateScore("token1")
if score < 0.5 || score > 1.0 {
t.Errorf("expected high score for perfect token, got %f", score)
}
}
func TestCalculateScore_FailedToken(t *testing.T) {
s := NewTokenScorer()
for i := 0; i < 5; i++ {
s.RecordRequest("token1", false, 1000*time.Millisecond)
}
s.SetQuotaRemaining("token1", 0.1)
score := s.CalculateScore("token1")
if score > 0.5 {
t.Errorf("expected low score for failed token, got %f", score)
}
}
func TestCalculateScore_FailPenalty(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
scoreNoFail := s.CalculateScore("token1")
s.RecordRequest("token1", false, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
scoreWithFail := s.CalculateScore("token1")
if scoreWithFail >= scoreNoFail {
t.Errorf("expected lower score with consecutive failures: noFail=%f, withFail=%f", scoreNoFail, scoreWithFail)
}
}
func TestSelectBestToken_Empty(t *testing.T) {
s := NewTokenScorer()
best := s.SelectBestToken([]string{})
if best != "" {
t.Errorf("expected empty string for empty tokens, got %s", best)
}
}
func TestSelectBestToken_SingleToken(t *testing.T) {
s := NewTokenScorer()
best := s.SelectBestToken([]string{"token1"})
if best != "token1" {
t.Errorf("expected token1, got %s", best)
}
}
func TestSelectBestToken_MultipleTokens(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("bad", false, 1000*time.Millisecond)
s.RecordRequest("bad", false, 1000*time.Millisecond)
s.SetQuotaRemaining("bad", 0.1)
s.RecordRequest("good", true, 50*time.Millisecond)
s.SetQuotaRemaining("good", 0.9)
time.Sleep(50 * time.Millisecond)
best := s.SelectBestToken([]string{"bad", "good"})
if best != "good" {
t.Errorf("expected good token to be selected, got %s", best)
}
}
func TestResetMetrics(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.ResetMetrics("token1")
m := s.GetMetrics("token1")
if m != nil {
t.Error("expected nil metrics after reset")
}
}
func TestResetAllMetrics(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token2", true, 100*time.Millisecond)
s.RecordRequest("token3", true, 100*time.Millisecond)
s.ResetAllMetrics()
if s.GetMetrics("token1") != nil {
t.Error("expected nil metrics for token1 after reset all")
}
if s.GetMetrics("token2") != nil {
t.Error("expected nil metrics for token2 after reset all")
}
}
func TestTokenScorer_ConcurrentAccess(t *testing.T) {
s := NewTokenScorer()
const numGoroutines = 50
const numOperations = 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
tokenKey := "token" + string(rune('a'+id%10))
for j := 0; j < numOperations; j++ {
switch j % 6 {
case 0:
s.RecordRequest(tokenKey, j%2 == 0, time.Duration(j)*time.Millisecond)
case 1:
s.SetQuotaRemaining(tokenKey, float64(j%100)/100)
case 2:
s.GetMetrics(tokenKey)
case 3:
s.CalculateScore(tokenKey)
case 4:
s.SelectBestToken([]string{tokenKey, "token_x", "token_y"})
case 5:
if j%20 == 0 {
s.ResetMetrics(tokenKey)
}
}
}
}(i)
}
wg.Wait()
}
func TestAvgLatencyCalculation(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token1", true, 200*time.Millisecond)
s.RecordRequest("token1", true, 300*time.Millisecond)
m := s.GetMetrics("token1")
if m.AvgLatency != 200 {
t.Errorf("expected AvgLatency 200, got %f", m.AvgLatency)
}
}
func TestLastUsedUpdated(t *testing.T) {
s := NewTokenScorer()
before := time.Now()
s.RecordRequest("token1", true, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m.LastUsed.Before(before) {
t.Error("expected LastUsed to be after test start time")
}
if m.LastUsed.After(time.Now()) {
t.Error("expected LastUsed to be before or equal to now")
}
}
func TestDefaultQuotaForNewToken(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m.QuotaRemaining != 1.0 {
t.Errorf("expected default QuotaRemaining 1.0, got %f", m.QuotaRemaining)
}
}

View File

@@ -190,7 +190,7 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
req.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api")
resp, err := o.httpClient.Do(req)
if err != nil {
@@ -227,11 +227,19 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: "", // Caller should preserve original provider
Region: "us-east-1",
}, nil
}
// RefreshToken refreshes an expired access token.
// Uses KiroIDE-style User-Agent to match official Kiro IDE behavior.
func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
return o.RefreshTokenWithFingerprint(ctx, refreshToken, "")
}
// RefreshTokenWithFingerprint refreshes an expired access token with a specific fingerprint.
// tokenKey is used to generate a consistent fingerprint for the token.
func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToken, tokenKey string) (*KiroTokenData, error) {
payload := map[string]string{
"refreshToken": refreshToken,
}
@@ -248,7 +256,11 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
// Use KiroIDE-style User-Agent to match official Kiro IDE behavior
// This helps avoid 403 errors from server-side User-Agent validation
userAgent := buildKiroUserAgent(tokenKey)
req.Header.Set("User-Agent", userAgent)
resp, err := o.httpClient.Do(req)
if err != nil {
@@ -263,7 +275,7 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir
if resp.StatusCode != http.StatusOK {
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
}
var tokenResp KiroTokenResponse
@@ -285,9 +297,23 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: "", // Caller should preserve original provider
Region: "us-east-1",
}, nil
}
// buildKiroUserAgent builds a KiroIDE-style User-Agent string.
// If tokenKey is provided, uses fingerprint manager for consistent fingerprint.
// Otherwise generates a simple KiroIDE User-Agent.
func buildKiroUserAgent(tokenKey string) string {
if tokenKey != "" {
fm := NewFingerprintManager()
fp := fm.GetFingerprint(tokenKey)
return fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash[:16])
}
// Default KiroIDE User-Agent matching kiro-openai-gateway format
return "KiroIDE-0.7.45-cli-proxy-api"
}
// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
// This uses a custom protocol handler (kiro://) to receive the callback.
func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {

View File

@@ -0,0 +1,969 @@
// Package kiro provides OAuth Web authentication for Kiro.
package kiro
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"html/template"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
defaultSessionExpiry = 10 * time.Minute
pollIntervalSeconds = 5
)
type authSessionStatus string
const (
statusPending authSessionStatus = "pending"
statusSuccess authSessionStatus = "success"
statusFailed authSessionStatus = "failed"
)
type webAuthSession struct {
stateID string
deviceCode string
userCode string
authURL string
verificationURI string
expiresIn int
interval int
status authSessionStatus
startedAt time.Time
completedAt time.Time
expiresAt time.Time
error string
tokenData *KiroTokenData
ssoClient *SSOOIDCClient
clientID string
clientSecret string
region string
cancelFunc context.CancelFunc
authMethod string // "google", "github", "builder-id", "idc"
startURL string // Used for IDC
codeVerifier string // Used for social auth PKCE
codeChallenge string // Used for social auth PKCE
}
type OAuthWebHandler struct {
cfg *config.Config
sessions map[string]*webAuthSession
mu sync.RWMutex
onTokenObtained func(*KiroTokenData)
}
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
return &OAuthWebHandler{
cfg: cfg,
sessions: make(map[string]*webAuthSession),
}
}
func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) {
h.onTokenObtained = callback
}
func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) {
oauth := router.Group("/v0/oauth/kiro")
{
oauth.GET("", h.handleSelect)
oauth.GET("/start", h.handleStart)
oauth.GET("/callback", h.handleCallback)
oauth.GET("/social/callback", h.handleSocialCallback)
oauth.GET("/status", h.handleStatus)
oauth.POST("/import", h.handleImportToken)
oauth.POST("/refresh", h.handleManualRefresh)
}
}
func generateStateID() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func (h *OAuthWebHandler) handleSelect(c *gin.Context) {
h.renderSelectPage(c)
}
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
method := c.Query("method")
if method == "" {
c.Redirect(http.StatusFound, "/v0/oauth/kiro")
return
}
switch method {
case "google", "github":
// Google/GitHub social login is not supported for third-party apps
// due to AWS Cognito redirect_uri restrictions
h.renderError(c, "Google/GitHub login is not available for third-party applications. Please use AWS Builder ID or import your token from Kiro IDE.")
case "builder-id":
h.startBuilderIDAuth(c)
case "idc":
h.startIDCAuth(c)
default:
h.renderError(c, fmt.Sprintf("Unknown authentication method: %s", method))
}
}
func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) {
stateID, err := generateStateID()
if err != nil {
h.renderError(c, "Failed to generate state parameter")
return
}
codeVerifier, codeChallenge, err := generatePKCE()
if err != nil {
h.renderError(c, "Failed to generate PKCE parameters")
return
}
socialClient := NewSocialAuthClient(h.cfg)
var provider string
if method == "google" {
provider = string(ProviderGoogle)
} else {
provider = string(ProviderGitHub)
}
redirectURI := h.getSocialCallbackURL(c)
authURL := socialClient.buildLoginURL(provider, redirectURI, codeChallenge, stateID)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
session := &webAuthSession{
stateID: stateID,
authMethod: method,
authURL: authURL,
status: statusPending,
startedAt: time.Now(),
expiresIn: 600,
codeVerifier: codeVerifier,
codeChallenge: codeChallenge,
region: "us-east-1",
cancelFunc: cancel,
}
h.mu.Lock()
h.sessions[stateID] = session
h.mu.Unlock()
go func() {
<-ctx.Done()
h.mu.Lock()
if session.status == statusPending {
session.status = statusFailed
session.error = "Authentication timed out"
}
h.mu.Unlock()
}()
c.Redirect(http.StatusFound, authURL)
}
func (h *OAuthWebHandler) getSocialCallbackURL(c *gin.Context) string {
scheme := "http"
if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" {
scheme = "https"
}
return fmt.Sprintf("%s://%s/v0/oauth/kiro/social/callback", scheme, c.Request.Host)
}
func (h *OAuthWebHandler) startBuilderIDAuth(c *gin.Context) {
stateID, err := generateStateID()
if err != nil {
h.renderError(c, "Failed to generate state parameter")
return
}
region := defaultIDCRegion
startURL := builderIDStartURL
ssoClient := NewSSOOIDCClient(h.cfg)
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
if err != nil {
log.Errorf("OAuth Web: failed to register client: %v", err)
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
return
}
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
c.Request.Context(),
regResp.ClientID,
regResp.ClientSecret,
startURL,
region,
)
if err != nil {
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
return
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
session := &webAuthSession{
stateID: stateID,
deviceCode: authResp.DeviceCode,
userCode: authResp.UserCode,
authURL: authResp.VerificationURIComplete,
verificationURI: authResp.VerificationURI,
expiresIn: authResp.ExpiresIn,
interval: authResp.Interval,
status: statusPending,
startedAt: time.Now(),
ssoClient: ssoClient,
clientID: regResp.ClientID,
clientSecret: regResp.ClientSecret,
region: region,
authMethod: "builder-id",
startURL: startURL,
cancelFunc: cancel,
}
h.mu.Lock()
h.sessions[stateID] = session
h.mu.Unlock()
go h.pollForToken(ctx, session)
h.renderStartPage(c, session)
}
func (h *OAuthWebHandler) startIDCAuth(c *gin.Context) {
startURL := c.Query("startUrl")
region := c.Query("region")
if startURL == "" {
h.renderError(c, "Missing startUrl parameter for IDC authentication")
return
}
if region == "" {
region = defaultIDCRegion
}
stateID, err := generateStateID()
if err != nil {
h.renderError(c, "Failed to generate state parameter")
return
}
ssoClient := NewSSOOIDCClient(h.cfg)
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
if err != nil {
log.Errorf("OAuth Web: failed to register client: %v", err)
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
return
}
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
c.Request.Context(),
regResp.ClientID,
regResp.ClientSecret,
startURL,
region,
)
if err != nil {
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
return
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
session := &webAuthSession{
stateID: stateID,
deviceCode: authResp.DeviceCode,
userCode: authResp.UserCode,
authURL: authResp.VerificationURIComplete,
verificationURI: authResp.VerificationURI,
expiresIn: authResp.ExpiresIn,
interval: authResp.Interval,
status: statusPending,
startedAt: time.Now(),
ssoClient: ssoClient,
clientID: regResp.ClientID,
clientSecret: regResp.ClientSecret,
region: region,
authMethod: "idc",
startURL: startURL,
cancelFunc: cancel,
}
h.mu.Lock()
h.sessions[stateID] = session
h.mu.Unlock()
go h.pollForToken(ctx, session)
h.renderStartPage(c, session)
}
func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) {
defer session.cancelFunc()
interval := time.Duration(session.interval) * time.Second
if interval < time.Duration(pollIntervalSeconds)*time.Second {
interval = time.Duration(pollIntervalSeconds) * time.Second
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
h.mu.Lock()
if session.status == statusPending {
session.status = statusFailed
session.error = "Authentication timed out"
}
h.mu.Unlock()
return
case <-ticker.C:
tokenResp, err := h.ssoClient(session).CreateTokenWithRegion(
ctx,
session.clientID,
session.clientSecret,
session.deviceCode,
session.region,
)
if err != nil {
errStr := err.Error()
if errStr == ErrAuthorizationPending.Error() {
continue
}
if errStr == ErrSlowDown.Error() {
interval += 5 * time.Second
ticker.Reset(interval)
continue
}
h.mu.Lock()
session.status = statusFailed
session.error = errStr
session.completedAt = time.Now()
h.mu.Unlock()
log.Errorf("OAuth Web: token polling failed: %v", err)
return
}
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken)
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken)
tokenData := &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: profileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: session.authMethod,
Provider: "AWS",
ClientID: session.clientID,
ClientSecret: session.clientSecret,
Email: email,
Region: session.region,
StartURL: session.startURL,
}
h.mu.Lock()
session.status = statusSuccess
session.completedAt = time.Now()
session.expiresAt = expiresAt
session.tokenData = tokenData
h.mu.Unlock()
if h.onTokenObtained != nil {
h.onTokenObtained(tokenData)
}
// Save token to file
h.saveTokenToFile(tokenData)
log.Infof("OAuth Web: authentication successful for %s", email)
return
}
}
}
// saveTokenToFile saves the token data to the auth directory
func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
// Get auth directory from config or use default
authDir := ""
if h.cfg != nil && h.cfg.AuthDir != "" {
var err error
authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
if err != nil {
log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
}
}
// Fall back to default location
if authDir == "" {
home, err := os.UserHomeDir()
if err != nil {
log.Errorf("OAuth Web: failed to get home directory: %v", err)
return
}
authDir = filepath.Join(home, ".cli-proxy-api")
}
// Create directory if not exists
if err := os.MkdirAll(authDir, 0700); err != nil {
log.Errorf("OAuth Web: failed to create auth directory: %v", err)
return
}
// Generate filename using the unified function
fileName := GenerateTokenFileName(tokenData)
authFilePath := filepath.Join(authDir, fileName)
// Convert to storage format and save
storage := &KiroTokenStorage{
Type: "kiro",
AccessToken: tokenData.AccessToken,
RefreshToken: tokenData.RefreshToken,
ProfileArn: tokenData.ProfileArn,
ExpiresAt: tokenData.ExpiresAt,
AuthMethod: tokenData.AuthMethod,
Provider: tokenData.Provider,
LastRefresh: time.Now().Format(time.RFC3339),
ClientID: tokenData.ClientID,
ClientSecret: tokenData.ClientSecret,
Region: tokenData.Region,
StartURL: tokenData.StartURL,
Email: tokenData.Email,
}
if err := storage.SaveTokenToFile(authFilePath); err != nil {
log.Errorf("OAuth Web: failed to save token to file: %v", err)
return
}
log.Infof("OAuth Web: token saved to %s", authFilePath)
}
func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient {
return session.ssoClient
}
func (h *OAuthWebHandler) handleCallback(c *gin.Context) {
stateID := c.Query("state")
errParam := c.Query("error")
if errParam != "" {
h.renderError(c, errParam)
return
}
if stateID == "" {
h.renderError(c, "Missing state parameter")
return
}
h.mu.RLock()
session, exists := h.sessions[stateID]
h.mu.RUnlock()
if !exists {
h.renderError(c, "Invalid or expired session")
return
}
if session.status == statusSuccess {
h.renderSuccess(c, session)
} else if session.status == statusFailed {
h.renderError(c, session.error)
} else {
c.Redirect(http.StatusFound, "/v0/oauth/kiro/start")
}
}
func (h *OAuthWebHandler) handleSocialCallback(c *gin.Context) {
stateID := c.Query("state")
code := c.Query("code")
errParam := c.Query("error")
if errParam != "" {
h.renderError(c, errParam)
return
}
if stateID == "" {
h.renderError(c, "Missing state parameter")
return
}
if code == "" {
h.renderError(c, "Missing authorization code")
return
}
h.mu.RLock()
session, exists := h.sessions[stateID]
h.mu.RUnlock()
if !exists {
h.renderError(c, "Invalid or expired session")
return
}
if session.authMethod != "google" && session.authMethod != "github" {
h.renderError(c, "Invalid session type for social callback")
return
}
socialClient := NewSocialAuthClient(h.cfg)
redirectURI := h.getSocialCallbackURL(c)
tokenReq := &CreateTokenRequest{
Code: code,
CodeVerifier: session.codeVerifier,
RedirectURI: redirectURI,
}
tokenResp, err := socialClient.CreateToken(c.Request.Context(), tokenReq)
if err != nil {
log.Errorf("OAuth Web: social token exchange failed: %v", err)
h.mu.Lock()
session.status = statusFailed
session.error = fmt.Sprintf("Token exchange failed: %v", err)
session.completedAt = time.Now()
h.mu.Unlock()
h.renderError(c, session.error)
return
}
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
email := ExtractEmailFromJWT(tokenResp.AccessToken)
var provider string
if session.authMethod == "google" {
provider = string(ProviderGoogle)
} else {
provider = string(ProviderGitHub)
}
tokenData := &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: session.authMethod,
Provider: provider,
Email: email,
Region: "us-east-1",
}
h.mu.Lock()
session.status = statusSuccess
session.completedAt = time.Now()
session.expiresAt = expiresAt
session.tokenData = tokenData
h.mu.Unlock()
if session.cancelFunc != nil {
session.cancelFunc()
}
if h.onTokenObtained != nil {
h.onTokenObtained(tokenData)
}
// Save token to file
h.saveTokenToFile(tokenData)
log.Infof("OAuth Web: social authentication successful for %s via %s", email, provider)
h.renderSuccess(c, session)
}
func (h *OAuthWebHandler) handleStatus(c *gin.Context) {
stateID := c.Query("state")
if stateID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"})
return
}
h.mu.RLock()
session, exists := h.sessions[stateID]
h.mu.RUnlock()
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
return
}
response := gin.H{
"status": string(session.status),
}
switch session.status {
case statusPending:
elapsed := time.Since(session.startedAt).Seconds()
remaining := float64(session.expiresIn) - elapsed
if remaining < 0 {
remaining = 0
}
response["remaining_seconds"] = int(remaining)
case statusSuccess:
response["completed_at"] = session.completedAt.Format(time.RFC3339)
response["expires_at"] = session.expiresAt.Format(time.RFC3339)
case statusFailed:
response["error"] = session.error
response["failed_at"] = session.completedAt.Format(time.RFC3339)
}
c.JSON(http.StatusOK, response)
}
func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) {
tmpl, err := template.New("start").Parse(oauthWebStartPageHTML)
if err != nil {
log.Errorf("OAuth Web: failed to parse template: %v", err)
c.String(http.StatusInternalServerError, "Template error")
return
}
data := map[string]interface{}{
"AuthURL": session.authURL,
"UserCode": session.userCode,
"ExpiresIn": session.expiresIn,
"StateID": session.stateID,
}
c.Header("Content-Type", "text/html; charset=utf-8")
if err := tmpl.Execute(c.Writer, data); err != nil {
log.Errorf("OAuth Web: failed to render template: %v", err)
}
}
func (h *OAuthWebHandler) renderSelectPage(c *gin.Context) {
tmpl, err := template.New("select").Parse(oauthWebSelectPageHTML)
if err != nil {
log.Errorf("OAuth Web: failed to parse select template: %v", err)
c.String(http.StatusInternalServerError, "Template error")
return
}
c.Header("Content-Type", "text/html; charset=utf-8")
if err := tmpl.Execute(c.Writer, nil); err != nil {
log.Errorf("OAuth Web: failed to render select template: %v", err)
}
}
func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) {
tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML)
if err != nil {
log.Errorf("OAuth Web: failed to parse error template: %v", err)
c.String(http.StatusInternalServerError, "Template error")
return
}
data := map[string]interface{}{
"Error": errMsg,
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.Status(http.StatusBadRequest)
if err := tmpl.Execute(c.Writer, data); err != nil {
log.Errorf("OAuth Web: failed to render error template: %v", err)
}
}
func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) {
tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML)
if err != nil {
log.Errorf("OAuth Web: failed to parse success template: %v", err)
c.String(http.StatusInternalServerError, "Template error")
return
}
data := map[string]interface{}{
"ExpiresAt": session.expiresAt.Format(time.RFC3339),
}
c.Header("Content-Type", "text/html; charset=utf-8")
if err := tmpl.Execute(c.Writer, data); err != nil {
log.Errorf("OAuth Web: failed to render success template: %v", err)
}
}
func (h *OAuthWebHandler) CleanupExpiredSessions() {
h.mu.Lock()
defer h.mu.Unlock()
now := time.Now()
for id, session := range h.sessions {
if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute {
delete(h.sessions, id)
} else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry {
session.cancelFunc()
delete(h.sessions, id)
}
}
}
func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) {
h.mu.RLock()
defer h.mu.RUnlock()
session, exists := h.sessions[stateID]
return session, exists
}
// ImportTokenRequest represents the request body for token import
type ImportTokenRequest struct {
RefreshToken string `json:"refreshToken"`
}
// handleImportToken handles manual refresh token import from Kiro IDE
func (h *OAuthWebHandler) handleImportToken(c *gin.Context) {
var req ImportTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": "Invalid request body",
})
return
}
refreshToken := strings.TrimSpace(req.RefreshToken)
if refreshToken == "" {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": "Refresh token is required",
})
return
}
// Validate token format
if !strings.HasPrefix(refreshToken, "aorAAAAAG") {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": "Invalid token format. Token should start with aorAAAAAG...",
})
return
}
// Create social auth client to refresh and validate the token
socialClient := NewSocialAuthClient(h.cfg)
// Refresh the token to validate it and get access token
tokenData, err := socialClient.RefreshSocialToken(c.Request.Context(), refreshToken)
if err != nil {
log.Errorf("OAuth Web: token refresh failed during import: %v", err)
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": fmt.Sprintf("Token validation failed: %v", err),
})
return
}
// Set the original refresh token (the refreshed one might be empty)
if tokenData.RefreshToken == "" {
tokenData.RefreshToken = refreshToken
}
tokenData.AuthMethod = "social"
tokenData.Provider = "imported"
// Notify callback if set
if h.onTokenObtained != nil {
h.onTokenObtained(tokenData)
}
// Save token to file
h.saveTokenToFile(tokenData)
// Generate filename for response using the unified function
fileName := GenerateTokenFileName(tokenData)
log.Infof("OAuth Web: token imported successfully")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Token imported successfully",
"fileName": fileName,
})
}
// handleManualRefresh handles manual token refresh requests from the web UI.
// This allows users to trigger a token refresh when needed, without waiting
// for the automatic 30-second check and 20-minute-before-expiry refresh cycle.
// Uses the same refresh logic as kiro_executor.Refresh for consistency.
func (h *OAuthWebHandler) handleManualRefresh(c *gin.Context) {
authDir := ""
if h.cfg != nil && h.cfg.AuthDir != "" {
var err error
authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
if err != nil {
log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
}
}
if authDir == "" {
home, err := os.UserHomeDir()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"error": "Failed to get home directory",
})
return
}
authDir = filepath.Join(home, ".cli-proxy-api")
}
// Find all kiro token files in the auth directory
files, err := os.ReadDir(authDir)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"error": fmt.Sprintf("Failed to read auth directory: %v", err),
})
return
}
var refreshedCount int
var errors []string
for _, file := range files {
if file.IsDir() {
continue
}
name := file.Name()
if !strings.HasPrefix(name, "kiro-") || !strings.HasSuffix(name, ".json") {
continue
}
filePath := filepath.Join(authDir, name)
data, err := os.ReadFile(filePath)
if err != nil {
errors = append(errors, fmt.Sprintf("%s: read error - %v", name, err))
continue
}
var storage KiroTokenStorage
if err := json.Unmarshal(data, &storage); err != nil {
errors = append(errors, fmt.Sprintf("%s: parse error - %v", name, err))
continue
}
if storage.RefreshToken == "" {
errors = append(errors, fmt.Sprintf("%s: no refresh token", name))
continue
}
// Refresh token using the same logic as kiro_executor.Refresh
tokenData, err := h.refreshTokenData(c.Request.Context(), &storage)
if err != nil {
errors = append(errors, fmt.Sprintf("%s: refresh failed - %v", name, err))
continue
}
// Update storage with new token data
storage.AccessToken = tokenData.AccessToken
if tokenData.RefreshToken != "" {
storage.RefreshToken = tokenData.RefreshToken
}
storage.ExpiresAt = tokenData.ExpiresAt
storage.LastRefresh = time.Now().Format(time.RFC3339)
if tokenData.ProfileArn != "" {
storage.ProfileArn = tokenData.ProfileArn
}
// Write updated token back to file
updatedData, err := json.MarshalIndent(storage, "", " ")
if err != nil {
errors = append(errors, fmt.Sprintf("%s: marshal error - %v", name, err))
continue
}
tmpFile := filePath + ".tmp"
if err := os.WriteFile(tmpFile, updatedData, 0600); err != nil {
errors = append(errors, fmt.Sprintf("%s: write error - %v", name, err))
continue
}
if err := os.Rename(tmpFile, filePath); err != nil {
errors = append(errors, fmt.Sprintf("%s: rename error - %v", name, err))
continue
}
log.Infof("OAuth Web: manually refreshed token in %s, expires at %s", name, tokenData.ExpiresAt)
refreshedCount++
// Notify callback if set
if h.onTokenObtained != nil {
h.onTokenObtained(tokenData)
}
}
if refreshedCount == 0 && len(errors) > 0 {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": fmt.Sprintf("All refresh attempts failed: %v", errors),
})
return
}
response := gin.H{
"success": true,
"message": fmt.Sprintf("Refreshed %d token(s)", refreshedCount),
"refreshedCount": refreshedCount,
}
if len(errors) > 0 {
response["warnings"] = errors
}
c.JSON(http.StatusOK, response)
}
// refreshTokenData refreshes a token using the appropriate method based on auth type.
// This mirrors the logic in kiro_executor.Refresh for consistency.
func (h *OAuthWebHandler) refreshTokenData(ctx context.Context, storage *KiroTokenStorage) (*KiroTokenData, error) {
ssoClient := NewSSOOIDCClient(h.cfg)
switch {
case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "idc" && storage.Region != "":
// IDC refresh with region-specific endpoint
log.Debugf("OAuth Web: using SSO OIDC refresh for IDC (region=%s)", storage.Region)
return ssoClient.RefreshTokenWithRegion(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken, storage.Region, storage.StartURL)
case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "builder-id":
// Builder ID refresh with default endpoint
log.Debugf("OAuth Web: using SSO OIDC refresh for AWS Builder ID")
return ssoClient.RefreshToken(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken)
default:
// Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub)
log.Debugf("OAuth Web: using Kiro OAuth refresh endpoint")
oauth := NewKiroOAuth(h.cfg)
return oauth.RefreshToken(ctx, storage.RefreshToken)
}
}

View File

@@ -0,0 +1,779 @@
// Package kiro provides OAuth Web authentication templates.
package kiro
const (
oauthWebStartPageHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AWS SSO Authentication</title>
<style>
* { box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
margin: 0;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
justify-content: center;
align-items: center;
}
.container {
max-width: 500px;
width: 100%;
background: #fff;
padding: 40px;
border-radius: 12px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
}
h1 {
margin: 0 0 10px;
color: #333;
font-size: 24px;
text-align: center;
}
.subtitle {
text-align: center;
color: #666;
margin-bottom: 30px;
}
.step {
background: #f8f9fa;
padding: 20px;
border-radius: 8px;
margin-bottom: 15px;
}
.step-title {
display: flex;
align-items: center;
font-weight: 600;
color: #333;
margin-bottom: 10px;
}
.step-number {
width: 28px;
height: 28px;
background: #667eea;
color: white;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
font-size: 14px;
margin-right: 12px;
}
.user-code {
background: #e7f3ff;
border: 2px dashed #2196F3;
border-radius: 8px;
padding: 20px;
text-align: center;
margin-top: 10px;
}
.user-code-label {
font-size: 12px;
color: #666;
text-transform: uppercase;
letter-spacing: 1px;
margin-bottom: 8px;
}
.user-code-value {
font-size: 32px;
font-weight: bold;
font-family: monospace;
color: #2196F3;
letter-spacing: 4px;
}
.auth-btn {
display: block;
width: 100%;
padding: 15px;
background: #667eea;
color: white;
text-align: center;
text-decoration: none;
border-radius: 8px;
font-weight: 600;
font-size: 16px;
transition: all 0.3s;
border: none;
cursor: pointer;
margin-top: 20px;
}
.auth-btn:hover {
background: #5568d3;
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
}
.status {
margin-top: 30px;
padding: 20px;
background: #f8f9fa;
border-radius: 8px;
text-align: center;
}
.status-pending { border-left: 4px solid #ffc107; }
.status-success { border-left: 4px solid #28a745; }
.status-failed { border-left: 4px solid #dc3545; }
.spinner {
border: 3px solid #f3f3f3;
border-top: 3px solid #667eea;
border-radius: 50%;
width: 40px;
height: 40px;
animation: spin 1s linear infinite;
margin: 0 auto 15px;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.timer {
font-size: 24px;
font-weight: bold;
color: #667eea;
margin: 10px 0;
}
.timer.warning { color: #ffc107; }
.timer.danger { color: #dc3545; }
.status-message { color: #666; line-height: 1.6; }
.success-icon, .error-icon { font-size: 48px; margin-bottom: 15px; }
.info-box {
background: #e7f3ff;
border-left: 4px solid #2196F3;
padding: 15px;
margin-top: 20px;
border-radius: 4px;
font-size: 14px;
color: #666;
}
</style>
</head>
<body>
<div class="container">
<h1>🔐 AWS SSO Authentication</h1>
<p class="subtitle">Follow the steps below to complete authentication</p>
<div class="step">
<div class="step-title">
<span class="step-number">1</span>
Click the button below to open the authorization page
</div>
<a href="{{.AuthURL}}" target="_blank" class="auth-btn" id="authBtn">
🚀 Open Authorization Page
</a>
</div>
<div class="step">
<div class="step-title">
<span class="step-number">2</span>
Enter the verification code below
</div>
<div class="user-code">
<div class="user-code-label">Verification Code</div>
<div class="user-code-value">{{.UserCode}}</div>
</div>
</div>
<div class="step">
<div class="step-title">
<span class="step-number">3</span>
Complete AWS SSO login
</div>
<p style="color: #666; font-size: 14px; margin-top: 10px;">
Use your AWS SSO account to login and authorize
</p>
</div>
<div class="status status-pending" id="statusBox">
<div class="spinner" id="spinner"></div>
<div class="timer" id="timer">{{.ExpiresIn}}s</div>
<div class="status-message" id="statusMessage">
Waiting for authorization...
</div>
</div>
<div class="info-box">
💡 <strong>Tip:</strong> The authorization page will open in a new tab. This page will automatically update once authorization is complete.
</div>
</div>
<script>
let pollInterval;
let timerInterval;
let remainingSeconds = {{.ExpiresIn}};
const stateID = "{{.StateID}}";
setTimeout(() => {
document.getElementById('authBtn').click();
}, 500);
function pollStatus() {
fetch('/v0/oauth/kiro/status?state=' + stateID)
.then(response => response.json())
.then(data => {
console.log('Status:', data);
if (data.status === 'success') {
clearInterval(pollInterval);
clearInterval(timerInterval);
showSuccess(data);
} else if (data.status === 'failed') {
clearInterval(pollInterval);
clearInterval(timerInterval);
showError(data);
} else {
remainingSeconds = data.remaining_seconds || 0;
}
})
.catch(error => {
console.error('Poll error:', error);
});
}
function updateTimer() {
const timerEl = document.getElementById('timer');
const minutes = Math.floor(remainingSeconds / 60);
const seconds = remainingSeconds % 60;
timerEl.textContent = minutes + ':' + seconds.toString().padStart(2, '0');
if (remainingSeconds < 60) {
timerEl.className = 'timer danger';
} else if (remainingSeconds < 180) {
timerEl.className = 'timer warning';
} else {
timerEl.className = 'timer';
}
remainingSeconds--;
if (remainingSeconds < 0) {
clearInterval(timerInterval);
clearInterval(pollInterval);
showError({ error: 'Authentication timed out. Please refresh and try again.' });
}
}
function showSuccess(data) {
const statusBox = document.getElementById('statusBox');
statusBox.className = 'status status-success';
statusBox.innerHTML = '<div class="success-icon">✅</div>' +
'<div class="status-message">' +
'<strong>Authentication Successful!</strong><br>' +
'Token expires: ' + new Date(data.expires_at).toLocaleString() +
'</div>';
}
function showError(data) {
const statusBox = document.getElementById('statusBox');
statusBox.className = 'status status-failed';
statusBox.innerHTML = '<div class="error-icon">❌</div>' +
'<div class="status-message">' +
'<strong>Authentication Failed</strong><br>' +
(data.error || 'Unknown error') +
'</div>' +
'<button class="auth-btn" onclick="location.reload()" style="margin-top: 15px;">' +
'🔄 Retry' +
'</button>';
}
pollInterval = setInterval(pollStatus, 3000);
timerInterval = setInterval(updateTimer, 1000);
pollStatus();
</script>
</body>
</html>`
oauthWebErrorPageHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authentication Failed</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
max-width: 600px;
margin: 50px auto;
padding: 20px;
background: #f5f5f5;
}
.error {
background: #fff;
padding: 30px;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
border-left: 4px solid #dc3545;
}
h1 { color: #dc3545; margin-top: 0; }
.error-message { color: #666; line-height: 1.6; }
.retry-btn {
display: inline-block;
margin-top: 20px;
padding: 10px 20px;
background: #007bff;
color: white;
text-decoration: none;
border-radius: 4px;
}
.retry-btn:hover { background: #0056b3; }
</style>
</head>
<body>
<div class="error">
<h1>❌ Authentication Failed</h1>
<div class="error-message">
<p><strong>Error:</strong></p>
<p>{{.Error}}</p>
</div>
<a href="/v0/oauth/kiro/start" class="retry-btn">🔄 Retry</a>
</div>
</body>
</html>`
oauthWebSuccessPageHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authentication Successful</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
max-width: 600px;
margin: 50px auto;
padding: 20px;
background: #f5f5f5;
}
.success {
background: #fff;
padding: 30px;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
border-left: 4px solid #28a745;
text-align: center;
}
h1 { color: #28a745; margin-top: 0; }
.success-message { color: #666; line-height: 1.6; }
.icon { font-size: 48px; margin-bottom: 15px; }
.expires { font-size: 14px; color: #999; margin-top: 15px; }
</style>
</head>
<body>
<div class="success">
<div class="icon">✅</div>
<h1>Authentication Successful!</h1>
<div class="success-message">
<p>You can close this window.</p>
</div>
<div class="expires">Token expires: {{.ExpiresAt}}</div>
</div>
</body>
</html>`
oauthWebSelectPageHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Select Authentication Method</title>
<style>
* { box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
margin: 0;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
justify-content: center;
align-items: center;
}
.container {
max-width: 500px;
width: 100%;
background: #fff;
padding: 40px;
border-radius: 12px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
}
h1 {
margin: 0 0 10px;
color: #333;
font-size: 24px;
text-align: center;
}
.subtitle {
text-align: center;
color: #666;
margin-bottom: 30px;
}
.auth-methods {
display: flex;
flex-direction: column;
gap: 15px;
}
.auth-btn {
display: flex;
align-items: center;
width: 100%;
padding: 15px 20px;
background: #667eea;
color: white;
text-decoration: none;
border-radius: 8px;
font-weight: 600;
font-size: 16px;
transition: all 0.3s;
border: none;
cursor: pointer;
}
.auth-btn:hover {
background: #5568d3;
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
}
.auth-btn .icon {
font-size: 24px;
margin-right: 15px;
width: 32px;
text-align: center;
}
.auth-btn.google { background: #4285F4; }
.auth-btn.google:hover { background: #3367D6; }
.auth-btn.github { background: #24292e; }
.auth-btn.github:hover { background: #1a1e22; }
.auth-btn.aws { background: #FF9900; }
.auth-btn.aws:hover { background: #E68A00; }
.auth-btn.idc { background: #232F3E; }
.auth-btn.idc:hover { background: #1a242f; }
.idc-form {
background: #f8f9fa;
padding: 20px;
border-radius: 8px;
margin-top: 15px;
display: none;
}
.idc-form.show {
display: block;
}
.form-group {
margin-bottom: 15px;
}
.form-group label {
display: block;
font-weight: 600;
color: #333;
margin-bottom: 8px;
font-size: 14px;
}
.form-group input {
width: 100%;
padding: 12px;
border: 2px solid #e0e0e0;
border-radius: 6px;
font-size: 14px;
transition: border-color 0.3s;
}
.form-group input:focus {
outline: none;
border-color: #667eea;
}
.form-group .hint {
font-size: 12px;
color: #999;
margin-top: 5px;
}
.submit-btn {
display: block;
width: 100%;
padding: 15px;
background: #232F3E;
color: white;
text-align: center;
text-decoration: none;
border-radius: 8px;
font-weight: 600;
font-size: 16px;
transition: all 0.3s;
border: none;
cursor: pointer;
}
.submit-btn:hover {
background: #1a242f;
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(35, 47, 62, 0.4);
}
.divider {
display: flex;
align-items: center;
margin: 20px 0;
}
.divider::before,
.divider::after {
content: "";
flex: 1;
border-bottom: 1px solid #e0e0e0;
}
.divider span {
padding: 0 15px;
color: #999;
font-size: 14px;
}
.info-box {
background: #e7f3ff;
border-left: 4px solid #2196F3;
padding: 15px;
margin-top: 20px;
border-radius: 4px;
font-size: 14px;
color: #666;
}
.warning-box {
background: #fff3cd;
border-left: 4px solid #ffc107;
padding: 15px;
margin-top: 20px;
border-radius: 4px;
font-size: 14px;
color: #856404;
}
.auth-btn.manual { background: #6c757d; }
.auth-btn.manual:hover { background: #5a6268; }
.auth-btn.refresh { background: #17a2b8; }
.auth-btn.refresh:hover { background: #138496; }
.auth-btn.refresh:disabled { background: #7fb3bd; cursor: not-allowed; }
.manual-form {
background: #f8f9fa;
padding: 20px;
border-radius: 8px;
margin-top: 15px;
display: none;
}
.manual-form.show {
display: block;
}
.form-group textarea {
width: 100%;
padding: 12px;
border: 2px solid #e0e0e0;
border-radius: 6px;
font-size: 14px;
font-family: monospace;
transition: border-color 0.3s;
resize: vertical;
min-height: 80px;
}
.form-group textarea:focus {
outline: none;
border-color: #667eea;
}
.status-message {
padding: 15px;
border-radius: 6px;
margin-top: 15px;
display: none;
}
.status-message.success {
background: #d4edda;
color: #155724;
display: block;
}
.status-message.error {
background: #f8d7da;
color: #721c24;
display: block;
}
</style>
</head>
<body>
<div class="container">
<h1>🔐 Select Authentication Method</h1>
<p class="subtitle">Choose how you want to authenticate with Kiro</p>
<div class="auth-methods">
<a href="/v0/oauth/kiro/start?method=builder-id" class="auth-btn aws">
<span class="icon">🔶</span>
AWS Builder ID (Recommended)
</a>
<button type="button" class="auth-btn idc" onclick="toggleIdcForm()">
<span class="icon">🏢</span>
AWS Identity Center (IDC)
</button>
<div class="divider"><span>or</span></div>
<button type="button" class="auth-btn manual" onclick="toggleManualForm()">
<span class="icon">📋</span>
Import RefreshToken from Kiro IDE
</button>
<button type="button" class="auth-btn refresh" onclick="manualRefresh()" id="refreshBtn">
<span class="icon">🔄</span>
Manual Refresh All Tokens
</button>
<div class="status-message" id="refreshStatus"></div>
</div>
<div class="idc-form" id="idcForm">
<form action="/v0/oauth/kiro/start" method="get">
<input type="hidden" name="method" value="idc">
<div class="form-group">
<label for="startUrl">Start URL</label>
<input type="url" id="startUrl" name="startUrl" placeholder="https://your-org.awsapps.com/start" required>
<div class="hint">Your AWS Identity Center Start URL</div>
</div>
<div class="form-group">
<label for="region">Region</label>
<input type="text" id="region" name="region" value="us-east-1" placeholder="us-east-1">
<div class="hint">AWS Region for your Identity Center</div>
</div>
<button type="submit" class="submit-btn">
🚀 Continue with IDC
</button>
</form>
</div>
<div class="manual-form" id="manualForm">
<form id="importForm" onsubmit="submitImport(event)">
<div class="form-group">
<label for="refreshToken">Refresh Token</label>
<textarea id="refreshToken" name="refreshToken" placeholder="Paste your refreshToken here (starts with aorAAAAAG...)" required></textarea>
<div class="hint">Copy from Kiro IDE: ~/.kiro/kiro-auth-token.json → refreshToken field</div>
</div>
<button type="submit" class="submit-btn" id="importBtn">
📥 Import Token
</button>
<div class="status-message" id="importStatus"></div>
</form>
</div>
<div class="warning-box">
⚠️ <strong>Note:</strong> Google and GitHub login are not available for third-party applications due to AWS Cognito restrictions. Please use AWS Builder ID or import your token from Kiro IDE.
</div>
<div class="info-box">
💡 <strong>How to get RefreshToken:</strong><br>
1. Open Kiro IDE and login with Google/GitHub<br>
2. Find the token file: <code>~/.kiro/kiro-auth-token.json</code><br>
3. Copy the <code>refreshToken</code> value and paste it above
</div>
</div>
<script>
function toggleIdcForm() {
const idcForm = document.getElementById('idcForm');
const manualForm = document.getElementById('manualForm');
manualForm.classList.remove('show');
idcForm.classList.toggle('show');
if (idcForm.classList.contains('show')) {
document.getElementById('startUrl').focus();
}
}
function toggleManualForm() {
const idcForm = document.getElementById('idcForm');
const manualForm = document.getElementById('manualForm');
idcForm.classList.remove('show');
manualForm.classList.toggle('show');
if (manualForm.classList.contains('show')) {
document.getElementById('refreshToken').focus();
}
}
async function submitImport(event) {
event.preventDefault();
const refreshToken = document.getElementById('refreshToken').value.trim();
const statusEl = document.getElementById('importStatus');
const btn = document.getElementById('importBtn');
if (!refreshToken) {
statusEl.className = 'status-message error';
statusEl.textContent = 'Please enter a refresh token';
return;
}
if (!refreshToken.startsWith('aorAAAAAG')) {
statusEl.className = 'status-message error';
statusEl.textContent = 'Invalid token format. Token should start with aorAAAAAG...';
return;
}
btn.disabled = true;
btn.textContent = '⏳ Importing...';
statusEl.className = 'status-message';
statusEl.style.display = 'none';
try {
const response = await fetch('/v0/oauth/kiro/import', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ refreshToken: refreshToken })
});
const data = await response.json();
if (response.ok && data.success) {
statusEl.className = 'status-message success';
statusEl.textContent = '✅ Token imported successfully! File: ' + (data.fileName || 'kiro-token.json');
} else {
statusEl.className = 'status-message error';
statusEl.textContent = '❌ ' + (data.error || data.message || 'Import failed');
}
} catch (error) {
statusEl.className = 'status-message error';
statusEl.textContent = '❌ Network error: ' + error.message;
} finally {
btn.disabled = false;
btn.textContent = '📥 Import Token';
}
}
async function manualRefresh() {
const btn = document.getElementById('refreshBtn');
const statusEl = document.getElementById('refreshStatus');
btn.disabled = true;
btn.innerHTML = '<span class="icon">⏳</span> Refreshing...';
statusEl.className = 'status-message';
statusEl.style.display = 'none';
try {
const response = await fetch('/v0/oauth/kiro/refresh', {
method: 'POST',
headers: { 'Content-Type': 'application/json' }
});
const data = await response.json();
if (response.ok && data.success) {
statusEl.className = 'status-message success';
let msg = '✅ ' + data.message;
if (data.warnings && data.warnings.length > 0) {
msg += ' (Warnings: ' + data.warnings.join('; ') + ')';
}
statusEl.textContent = msg;
} else {
statusEl.className = 'status-message error';
statusEl.textContent = '❌ ' + (data.error || data.message || 'Refresh failed');
}
} catch (error) {
statusEl.className = 'status-message error';
statusEl.textContent = '❌ Network error: ' + error.message;
} finally {
btn.disabled = false;
btn.innerHTML = '<span class="icon">🔄</span> Manual Refresh All Tokens';
}
}
</script>
</body>
</html>`
)

View File

@@ -0,0 +1,316 @@
package kiro
import (
"math"
"math/rand"
"strings"
"sync"
"time"
)
const (
DefaultMinTokenInterval = 1 * time.Second
DefaultMaxTokenInterval = 2 * time.Second
DefaultDailyMaxRequests = 500
DefaultJitterPercent = 0.3
DefaultBackoffBase = 30 * time.Second
DefaultBackoffMax = 5 * time.Minute
DefaultBackoffMultiplier = 1.5
DefaultSuspendCooldown = 1 * time.Hour
)
// TokenState Token 状态
type TokenState struct {
LastRequest time.Time
RequestCount int
CooldownEnd time.Time
FailCount int
DailyRequests int
DailyResetTime time.Time
IsSuspended bool
SuspendedAt time.Time
SuspendReason string
}
// RateLimiter 频率限制器
type RateLimiter struct {
mu sync.RWMutex
states map[string]*TokenState
minTokenInterval time.Duration
maxTokenInterval time.Duration
dailyMaxRequests int
jitterPercent float64
backoffBase time.Duration
backoffMax time.Duration
backoffMultiplier float64
suspendCooldown time.Duration
rng *rand.Rand
}
// NewRateLimiter 创建默认配置的频率限制器
func NewRateLimiter() *RateLimiter {
return &RateLimiter{
states: make(map[string]*TokenState),
minTokenInterval: DefaultMinTokenInterval,
maxTokenInterval: DefaultMaxTokenInterval,
dailyMaxRequests: DefaultDailyMaxRequests,
jitterPercent: DefaultJitterPercent,
backoffBase: DefaultBackoffBase,
backoffMax: DefaultBackoffMax,
backoffMultiplier: DefaultBackoffMultiplier,
suspendCooldown: DefaultSuspendCooldown,
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
// RateLimiterConfig 频率限制器配置
type RateLimiterConfig struct {
MinTokenInterval time.Duration
MaxTokenInterval time.Duration
DailyMaxRequests int
JitterPercent float64
BackoffBase time.Duration
BackoffMax time.Duration
BackoffMultiplier float64
SuspendCooldown time.Duration
}
// NewRateLimiterWithConfig 使用自定义配置创建频率限制器
func NewRateLimiterWithConfig(cfg RateLimiterConfig) *RateLimiter {
rl := NewRateLimiter()
if cfg.MinTokenInterval > 0 {
rl.minTokenInterval = cfg.MinTokenInterval
}
if cfg.MaxTokenInterval > 0 {
rl.maxTokenInterval = cfg.MaxTokenInterval
}
if cfg.DailyMaxRequests > 0 {
rl.dailyMaxRequests = cfg.DailyMaxRequests
}
if cfg.JitterPercent > 0 {
rl.jitterPercent = cfg.JitterPercent
}
if cfg.BackoffBase > 0 {
rl.backoffBase = cfg.BackoffBase
}
if cfg.BackoffMax > 0 {
rl.backoffMax = cfg.BackoffMax
}
if cfg.BackoffMultiplier > 0 {
rl.backoffMultiplier = cfg.BackoffMultiplier
}
if cfg.SuspendCooldown > 0 {
rl.suspendCooldown = cfg.SuspendCooldown
}
return rl
}
// getOrCreateState 获取或创建 Token 状态
func (rl *RateLimiter) getOrCreateState(tokenKey string) *TokenState {
state, exists := rl.states[tokenKey]
if !exists {
state = &TokenState{
DailyResetTime: time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour),
}
rl.states[tokenKey] = state
}
return state
}
// resetDailyIfNeeded 如果需要则重置每日计数
func (rl *RateLimiter) resetDailyIfNeeded(state *TokenState) {
now := time.Now()
if now.After(state.DailyResetTime) {
state.DailyRequests = 0
state.DailyResetTime = now.Truncate(24 * time.Hour).Add(24 * time.Hour)
}
}
// calculateInterval 计算带抖动的随机间隔
func (rl *RateLimiter) calculateInterval() time.Duration {
baseInterval := rl.minTokenInterval + time.Duration(rl.rng.Int63n(int64(rl.maxTokenInterval-rl.minTokenInterval)))
jitter := time.Duration(float64(baseInterval) * rl.jitterPercent * (rl.rng.Float64()*2 - 1))
return baseInterval + jitter
}
// WaitForToken 等待 Token 可用(带抖动的随机间隔)
func (rl *RateLimiter) WaitForToken(tokenKey string) {
rl.mu.Lock()
state := rl.getOrCreateState(tokenKey)
rl.resetDailyIfNeeded(state)
now := time.Now()
// 检查是否在冷却期
if now.Before(state.CooldownEnd) {
waitTime := state.CooldownEnd.Sub(now)
rl.mu.Unlock()
time.Sleep(waitTime)
rl.mu.Lock()
state = rl.getOrCreateState(tokenKey)
now = time.Now()
}
// 计算距离上次请求的间隔
interval := rl.calculateInterval()
nextAllowedTime := state.LastRequest.Add(interval)
if now.Before(nextAllowedTime) {
waitTime := nextAllowedTime.Sub(now)
rl.mu.Unlock()
time.Sleep(waitTime)
rl.mu.Lock()
state = rl.getOrCreateState(tokenKey)
}
state.LastRequest = time.Now()
state.RequestCount++
state.DailyRequests++
rl.mu.Unlock()
}
// MarkTokenFailed 标记 Token 失败
func (rl *RateLimiter) MarkTokenFailed(tokenKey string) {
rl.mu.Lock()
defer rl.mu.Unlock()
state := rl.getOrCreateState(tokenKey)
state.FailCount++
state.CooldownEnd = time.Now().Add(rl.calculateBackoff(state.FailCount))
}
// MarkTokenSuccess 标记 Token 成功
func (rl *RateLimiter) MarkTokenSuccess(tokenKey string) {
rl.mu.Lock()
defer rl.mu.Unlock()
state := rl.getOrCreateState(tokenKey)
state.FailCount = 0
state.CooldownEnd = time.Time{}
}
// CheckAndMarkSuspended 检测暂停错误并标记
func (rl *RateLimiter) CheckAndMarkSuspended(tokenKey string, errorMsg string) bool {
suspendKeywords := []string{
"suspended",
"banned",
"disabled",
"account has been",
"access denied",
"rate limit exceeded",
"too many requests",
"quota exceeded",
}
lowerMsg := strings.ToLower(errorMsg)
for _, keyword := range suspendKeywords {
if strings.Contains(lowerMsg, keyword) {
rl.mu.Lock()
defer rl.mu.Unlock()
state := rl.getOrCreateState(tokenKey)
state.IsSuspended = true
state.SuspendedAt = time.Now()
state.SuspendReason = errorMsg
state.CooldownEnd = time.Now().Add(rl.suspendCooldown)
return true
}
}
return false
}
// IsTokenAvailable 检查 Token 是否可用
func (rl *RateLimiter) IsTokenAvailable(tokenKey string) bool {
rl.mu.RLock()
defer rl.mu.RUnlock()
state, exists := rl.states[tokenKey]
if !exists {
return true
}
now := time.Now()
// 检查是否被暂停
if state.IsSuspended {
if now.After(state.SuspendedAt.Add(rl.suspendCooldown)) {
return true
}
return false
}
// 检查是否在冷却期
if now.Before(state.CooldownEnd) {
return false
}
// 检查每日请求限制
rl.mu.RUnlock()
rl.mu.Lock()
rl.resetDailyIfNeeded(state)
dailyRequests := state.DailyRequests
dailyMax := rl.dailyMaxRequests
rl.mu.Unlock()
rl.mu.RLock()
if dailyRequests >= dailyMax {
return false
}
return true
}
// calculateBackoff 计算指数退避时间
func (rl *RateLimiter) calculateBackoff(failCount int) time.Duration {
if failCount <= 0 {
return 0
}
backoff := float64(rl.backoffBase) * math.Pow(rl.backoffMultiplier, float64(failCount-1))
// 添加抖动
jitter := backoff * rl.jitterPercent * (rl.rng.Float64()*2 - 1)
backoff += jitter
if time.Duration(backoff) > rl.backoffMax {
return rl.backoffMax
}
return time.Duration(backoff)
}
// GetTokenState 获取 Token 状态(只读)
func (rl *RateLimiter) GetTokenState(tokenKey string) *TokenState {
rl.mu.RLock()
defer rl.mu.RUnlock()
state, exists := rl.states[tokenKey]
if !exists {
return nil
}
// 返回副本以防止外部修改
stateCopy := *state
return &stateCopy
}
// ClearTokenState 清除 Token 状态
func (rl *RateLimiter) ClearTokenState(tokenKey string) {
rl.mu.Lock()
defer rl.mu.Unlock()
delete(rl.states, tokenKey)
}
// ResetSuspension 重置暂停状态
func (rl *RateLimiter) ResetSuspension(tokenKey string) {
rl.mu.Lock()
defer rl.mu.Unlock()
state, exists := rl.states[tokenKey]
if exists {
state.IsSuspended = false
state.SuspendedAt = time.Time{}
state.SuspendReason = ""
state.CooldownEnd = time.Time{}
state.FailCount = 0
}
}

View File

@@ -0,0 +1,46 @@
package kiro
import (
"sync"
"time"
log "github.com/sirupsen/logrus"
)
var (
globalRateLimiter *RateLimiter
globalRateLimiterOnce sync.Once
globalCooldownManager *CooldownManager
globalCooldownManagerOnce sync.Once
cooldownStopCh chan struct{}
)
// GetGlobalRateLimiter returns the singleton RateLimiter instance.
func GetGlobalRateLimiter() *RateLimiter {
globalRateLimiterOnce.Do(func() {
globalRateLimiter = NewRateLimiter()
log.Info("kiro: global RateLimiter initialized")
})
return globalRateLimiter
}
// GetGlobalCooldownManager returns the singleton CooldownManager instance.
func GetGlobalCooldownManager() *CooldownManager {
globalCooldownManagerOnce.Do(func() {
globalCooldownManager = NewCooldownManager()
cooldownStopCh = make(chan struct{})
go globalCooldownManager.StartCleanupRoutine(5*time.Minute, cooldownStopCh)
log.Info("kiro: global CooldownManager initialized with cleanup routine")
})
return globalCooldownManager
}
// ShutdownRateLimiters stops the cooldown cleanup routine.
// Should be called during application shutdown.
func ShutdownRateLimiters() {
if cooldownStopCh != nil {
close(cooldownStopCh)
log.Info("kiro: rate limiter cleanup routine stopped")
}
}

View File

@@ -0,0 +1,304 @@
package kiro
import (
"sync"
"testing"
"time"
)
func TestNewRateLimiter(t *testing.T) {
rl := NewRateLimiter()
if rl == nil {
t.Fatal("expected non-nil RateLimiter")
}
if rl.states == nil {
t.Error("expected non-nil states map")
}
if rl.minTokenInterval != DefaultMinTokenInterval {
t.Errorf("expected minTokenInterval %v, got %v", DefaultMinTokenInterval, rl.minTokenInterval)
}
if rl.maxTokenInterval != DefaultMaxTokenInterval {
t.Errorf("expected maxTokenInterval %v, got %v", DefaultMaxTokenInterval, rl.maxTokenInterval)
}
if rl.dailyMaxRequests != DefaultDailyMaxRequests {
t.Errorf("expected dailyMaxRequests %d, got %d", DefaultDailyMaxRequests, rl.dailyMaxRequests)
}
}
func TestNewRateLimiterWithConfig(t *testing.T) {
cfg := RateLimiterConfig{
MinTokenInterval: 5 * time.Second,
MaxTokenInterval: 15 * time.Second,
DailyMaxRequests: 100,
JitterPercent: 0.2,
BackoffBase: 1 * time.Minute,
BackoffMax: 30 * time.Minute,
BackoffMultiplier: 1.5,
SuspendCooldown: 12 * time.Hour,
}
rl := NewRateLimiterWithConfig(cfg)
if rl.minTokenInterval != 5*time.Second {
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
}
if rl.maxTokenInterval != 15*time.Second {
t.Errorf("expected maxTokenInterval 15s, got %v", rl.maxTokenInterval)
}
if rl.dailyMaxRequests != 100 {
t.Errorf("expected dailyMaxRequests 100, got %d", rl.dailyMaxRequests)
}
}
func TestNewRateLimiterWithConfig_PartialConfig(t *testing.T) {
cfg := RateLimiterConfig{
MinTokenInterval: 5 * time.Second,
}
rl := NewRateLimiterWithConfig(cfg)
if rl.minTokenInterval != 5*time.Second {
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
}
if rl.maxTokenInterval != DefaultMaxTokenInterval {
t.Errorf("expected default maxTokenInterval, got %v", rl.maxTokenInterval)
}
}
func TestGetTokenState_NonExistent(t *testing.T) {
rl := NewRateLimiter()
state := rl.GetTokenState("nonexistent")
if state != nil {
t.Error("expected nil state for non-existent token")
}
}
func TestIsTokenAvailable_NewToken(t *testing.T) {
rl := NewRateLimiter()
if !rl.IsTokenAvailable("newtoken") {
t.Error("expected new token to be available")
}
}
func TestMarkTokenFailed(t *testing.T) {
rl := NewRateLimiter()
rl.MarkTokenFailed("token1")
state := rl.GetTokenState("token1")
if state == nil {
t.Fatal("expected non-nil state")
}
if state.FailCount != 1 {
t.Errorf("expected FailCount 1, got %d", state.FailCount)
}
if state.CooldownEnd.IsZero() {
t.Error("expected non-zero CooldownEnd")
}
}
func TestMarkTokenSuccess(t *testing.T) {
rl := NewRateLimiter()
rl.MarkTokenFailed("token1")
rl.MarkTokenFailed("token1")
rl.MarkTokenSuccess("token1")
state := rl.GetTokenState("token1")
if state == nil {
t.Fatal("expected non-nil state")
}
if state.FailCount != 0 {
t.Errorf("expected FailCount 0, got %d", state.FailCount)
}
if !state.CooldownEnd.IsZero() {
t.Error("expected zero CooldownEnd after success")
}
}
func TestCheckAndMarkSuspended_Suspended(t *testing.T) {
rl := NewRateLimiter()
testCases := []string{
"Account has been suspended",
"You are banned from this service",
"Account disabled",
"Access denied permanently",
"Rate limit exceeded",
"Too many requests",
"Quota exceeded for today",
}
for i, msg := range testCases {
tokenKey := "token" + string(rune('a'+i))
if !rl.CheckAndMarkSuspended(tokenKey, msg) {
t.Errorf("expected suspension detected for: %s", msg)
}
state := rl.GetTokenState(tokenKey)
if !state.IsSuspended {
t.Errorf("expected IsSuspended true for: %s", msg)
}
}
}
func TestCheckAndMarkSuspended_NotSuspended(t *testing.T) {
rl := NewRateLimiter()
normalErrors := []string{
"connection timeout",
"internal server error",
"bad request",
"invalid token format",
}
for i, msg := range normalErrors {
tokenKey := "token" + string(rune('a'+i))
if rl.CheckAndMarkSuspended(tokenKey, msg) {
t.Errorf("unexpected suspension for: %s", msg)
}
}
}
func TestIsTokenAvailable_Suspended(t *testing.T) {
rl := NewRateLimiter()
rl.CheckAndMarkSuspended("token1", "Account suspended")
if rl.IsTokenAvailable("token1") {
t.Error("expected suspended token to be unavailable")
}
}
func TestClearTokenState(t *testing.T) {
rl := NewRateLimiter()
rl.MarkTokenFailed("token1")
rl.ClearTokenState("token1")
state := rl.GetTokenState("token1")
if state != nil {
t.Error("expected nil state after clear")
}
}
func TestResetSuspension(t *testing.T) {
rl := NewRateLimiter()
rl.CheckAndMarkSuspended("token1", "Account suspended")
rl.ResetSuspension("token1")
state := rl.GetTokenState("token1")
if state.IsSuspended {
t.Error("expected IsSuspended false after reset")
}
if state.FailCount != 0 {
t.Errorf("expected FailCount 0, got %d", state.FailCount)
}
}
func TestResetSuspension_NonExistent(t *testing.T) {
rl := NewRateLimiter()
rl.ResetSuspension("nonexistent")
}
func TestCalculateBackoff_ZeroFailCount(t *testing.T) {
rl := NewRateLimiter()
backoff := rl.calculateBackoff(0)
if backoff != 0 {
t.Errorf("expected 0 backoff for 0 fails, got %v", backoff)
}
}
func TestCalculateBackoff_Exponential(t *testing.T) {
cfg := RateLimiterConfig{
BackoffBase: 1 * time.Minute,
BackoffMax: 60 * time.Minute,
BackoffMultiplier: 2.0,
JitterPercent: 0.3,
}
rl := NewRateLimiterWithConfig(cfg)
backoff1 := rl.calculateBackoff(1)
if backoff1 < 40*time.Second || backoff1 > 80*time.Second {
t.Errorf("expected ~1min (with jitter) for fail 1, got %v", backoff1)
}
backoff2 := rl.calculateBackoff(2)
if backoff2 < 80*time.Second || backoff2 > 160*time.Second {
t.Errorf("expected ~2min (with jitter) for fail 2, got %v", backoff2)
}
}
func TestCalculateBackoff_MaxCap(t *testing.T) {
cfg := RateLimiterConfig{
BackoffBase: 1 * time.Minute,
BackoffMax: 10 * time.Minute,
BackoffMultiplier: 2.0,
JitterPercent: 0,
}
rl := NewRateLimiterWithConfig(cfg)
backoff := rl.calculateBackoff(10)
if backoff > 10*time.Minute {
t.Errorf("expected backoff capped at 10min, got %v", backoff)
}
}
func TestGetTokenState_ReturnsCopy(t *testing.T) {
rl := NewRateLimiter()
rl.MarkTokenFailed("token1")
state1 := rl.GetTokenState("token1")
state1.FailCount = 999
state2 := rl.GetTokenState("token1")
if state2.FailCount == 999 {
t.Error("GetTokenState should return a copy")
}
}
func TestRateLimiter_ConcurrentAccess(t *testing.T) {
rl := NewRateLimiter()
const numGoroutines = 50
const numOperations = 50
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
tokenKey := "token" + string(rune('a'+id%10))
for j := 0; j < numOperations; j++ {
switch j % 6 {
case 0:
rl.IsTokenAvailable(tokenKey)
case 1:
rl.MarkTokenFailed(tokenKey)
case 2:
rl.MarkTokenSuccess(tokenKey)
case 3:
rl.GetTokenState(tokenKey)
case 4:
rl.CheckAndMarkSuspended(tokenKey, "test error")
case 5:
rl.ResetSuspension(tokenKey)
}
}
}(i)
}
wg.Wait()
}
func TestCalculateInterval_WithinRange(t *testing.T) {
cfg := RateLimiterConfig{
MinTokenInterval: 10 * time.Second,
MaxTokenInterval: 30 * time.Second,
JitterPercent: 0.3,
}
rl := NewRateLimiterWithConfig(cfg)
minAllowed := 7 * time.Second
maxAllowed := 40 * time.Second
for i := 0; i < 100; i++ {
interval := rl.calculateInterval()
if interval < minAllowed || interval > maxAllowed {
t.Errorf("interval %v outside expected range [%v, %v]", interval, minAllowed, maxAllowed)
}
}
}

View File

@@ -0,0 +1,180 @@
package kiro
import (
"context"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// RefreshManager 是后台刷新器的单例管理器
type RefreshManager struct {
mu sync.Mutex
refresher *BackgroundRefresher
ctx context.Context
cancel context.CancelFunc
started bool
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
}
var (
globalRefreshManager *RefreshManager
managerOnce sync.Once
)
// GetRefreshManager 获取全局刷新管理器实例
func GetRefreshManager() *RefreshManager {
managerOnce.Do(func() {
globalRefreshManager = &RefreshManager{}
})
return globalRefreshManager
}
// Initialize 初始化后台刷新器
// baseDir: token 文件所在的目录
// cfg: 应用配置
func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.started {
log.Debug("refresh manager: already initialized")
return nil
}
if baseDir == "" {
log.Warn("refresh manager: base directory not provided, skipping initialization")
return nil
}
resolvedBaseDir, err := util.ResolveAuthDir(baseDir)
if err != nil {
log.Warnf("refresh manager: failed to resolve auth directory %s: %v", baseDir, err)
}
if resolvedBaseDir != "" {
baseDir = resolvedBaseDir
}
// 创建 token 存储库
repo := NewFileTokenRepository(baseDir)
// 创建后台刷新器,配置参数
opts := []RefresherOption{
WithInterval(time.Minute), // 每分钟检查一次
WithBatchSize(50), // 每批最多处理 50 个 token
WithConcurrency(10), // 最多 10 个并发刷新
WithConfig(cfg), // 设置 OAuth 和 SSO 客户端
}
// 如果已设置回调,传递给 BackgroundRefresher
if m.onTokenRefreshed != nil {
opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed))
}
m.refresher = NewBackgroundRefresher(repo, opts...)
log.Infof("refresh manager: initialized with base directory %s", baseDir)
return nil
}
// Start 启动后台刷新
func (m *RefreshManager) Start() {
m.mu.Lock()
defer m.mu.Unlock()
if m.started {
log.Debug("refresh manager: already started")
return
}
if m.refresher == nil {
log.Warn("refresh manager: not initialized, cannot start")
return
}
m.ctx, m.cancel = context.WithCancel(context.Background())
m.refresher.Start(m.ctx)
m.started = true
log.Info("refresh manager: background refresh started")
}
// Stop 停止后台刷新
func (m *RefreshManager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.started {
return
}
if m.cancel != nil {
m.cancel()
}
if m.refresher != nil {
m.refresher.Stop()
}
m.started = false
log.Info("refresh manager: background refresh stopped")
}
// IsRunning 检查后台刷新是否正在运行
func (m *RefreshManager) IsRunning() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.started
}
// UpdateBaseDir 更新 token 目录(用于运行时配置更改)
func (m *RefreshManager) UpdateBaseDir(baseDir string) {
m.mu.Lock()
defer m.mu.Unlock()
if m.refresher != nil && m.refresher.tokenRepo != nil {
if repo, ok := m.refresher.tokenRepo.(*FileTokenRepository); ok {
repo.SetBaseDir(baseDir)
log.Infof("refresh manager: updated base directory to %s", baseDir)
}
}
}
// SetOnTokenRefreshed 设置 token 刷新成功后的回调函数
// 可以在任何时候调用,支持运行时更新回调
// callback: 回调函数,接收 tokenID文件名和新的 token 数据
func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) {
m.mu.Lock()
defer m.mu.Unlock()
m.onTokenRefreshed = callback
// 如果 refresher 已经创建,使用并发安全的方式更新它的回调
if m.refresher != nil {
m.refresher.callbackMu.Lock()
m.refresher.onTokenRefreshed = callback
m.refresher.callbackMu.Unlock()
}
log.Debug("refresh manager: token refresh callback registered")
}
// InitializeAndStart 初始化并启动后台刷新(便捷方法)
func InitializeAndStart(baseDir string, cfg *config.Config) {
manager := GetRefreshManager()
if err := manager.Initialize(baseDir, cfg); err != nil {
log.Errorf("refresh manager: initialization failed: %v", err)
return
}
manager.Start()
}
// StopGlobalRefreshManager 停止全局刷新管理器
func StopGlobalRefreshManager() {
if globalRefreshManager != nil {
globalRefreshManager.Stop()
}
}

View File

@@ -0,0 +1,159 @@
// Package kiro provides refresh utilities for Kiro token management.
package kiro
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
)
// RefreshResult contains the result of a token refresh attempt.
type RefreshResult struct {
TokenData *KiroTokenData
Error error
UsedFallback bool // True if we used the existing token as fallback
}
// RefreshWithGracefulDegradation attempts to refresh a token with graceful degradation.
// If refresh fails but the existing access token is still valid, it returns the existing token.
// This matches kiro-openai-gateway's behavior for better reliability.
//
// Parameters:
// - ctx: Context for the request
// - refreshFunc: Function to perform the actual refresh
// - existingAccessToken: Current access token (for fallback)
// - expiresAt: Expiration time of the existing token
//
// Returns:
// - RefreshResult containing the new or existing token data
func RefreshWithGracefulDegradation(
ctx context.Context,
refreshFunc func(ctx context.Context) (*KiroTokenData, error),
existingAccessToken string,
expiresAt time.Time,
) RefreshResult {
// Try to refresh the token
newTokenData, err := refreshFunc(ctx)
if err == nil {
return RefreshResult{
TokenData: newTokenData,
Error: nil,
UsedFallback: false,
}
}
// Refresh failed - check if we can use the existing token
log.Warnf("kiro: token refresh failed: %v", err)
// Check if existing token is still valid (not expired)
if existingAccessToken != "" && time.Now().Before(expiresAt) {
remainingTime := time.Until(expiresAt)
log.Warnf("kiro: using existing access token (expires in %v). Will retry refresh later.", remainingTime.Round(time.Second))
return RefreshResult{
TokenData: &KiroTokenData{
AccessToken: existingAccessToken,
ExpiresAt: expiresAt.Format(time.RFC3339),
},
Error: nil,
UsedFallback: true,
}
}
// Token is expired and refresh failed - return the error
return RefreshResult{
TokenData: nil,
Error: fmt.Errorf("token refresh failed and existing token is expired: %w", err),
UsedFallback: false,
}
}
// IsTokenExpiringSoon checks if a token is expiring within the given threshold.
// Default threshold is 5 minutes if not specified.
func IsTokenExpiringSoon(expiresAt time.Time, threshold time.Duration) bool {
if threshold == 0 {
threshold = 5 * time.Minute
}
return time.Now().Add(threshold).After(expiresAt)
}
// IsTokenExpired checks if a token has already expired.
func IsTokenExpired(expiresAt time.Time) bool {
return time.Now().After(expiresAt)
}
// ParseExpiresAt parses an expiration time string in RFC3339 format.
// Returns zero time if parsing fails.
func ParseExpiresAt(expiresAtStr string) time.Time {
if expiresAtStr == "" {
return time.Time{}
}
t, err := time.Parse(time.RFC3339, expiresAtStr)
if err != nil {
log.Debugf("kiro: failed to parse expiresAt '%s': %v", expiresAtStr, err)
return time.Time{}
}
return t
}
// RefreshConfig contains configuration for token refresh behavior.
type RefreshConfig struct {
// MaxRetries is the maximum number of refresh attempts (default: 1)
MaxRetries int
// RetryDelay is the delay between retry attempts (default: 1 second)
RetryDelay time.Duration
// RefreshThreshold is how early to refresh before expiration (default: 5 minutes)
RefreshThreshold time.Duration
// EnableGracefulDegradation allows using existing token if refresh fails (default: true)
EnableGracefulDegradation bool
}
// DefaultRefreshConfig returns the default refresh configuration.
func DefaultRefreshConfig() RefreshConfig {
return RefreshConfig{
MaxRetries: 1,
RetryDelay: time.Second,
RefreshThreshold: 5 * time.Minute,
EnableGracefulDegradation: true,
}
}
// RefreshWithRetry attempts to refresh a token with retry logic.
func RefreshWithRetry(
ctx context.Context,
refreshFunc func(ctx context.Context) (*KiroTokenData, error),
config RefreshConfig,
) (*KiroTokenData, error) {
var lastErr error
maxAttempts := config.MaxRetries + 1
if maxAttempts < 1 {
maxAttempts = 1
}
for attempt := 1; attempt <= maxAttempts; attempt++ {
tokenData, err := refreshFunc(ctx)
if err == nil {
if attempt > 1 {
log.Infof("kiro: token refresh succeeded on attempt %d", attempt)
}
return tokenData, nil
}
lastErr = err
log.Warnf("kiro: token refresh attempt %d/%d failed: %v", attempt, maxAttempts, err)
// Don't sleep after the last attempt
if attempt < maxAttempts {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(config.RetryDelay):
}
}
}
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxAttempts, lastErr)
}

View File

@@ -9,7 +9,9 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"html"
"io"
"net"
"net/http"
"net/url"
"os"
@@ -31,6 +33,9 @@ const (
// OAuth timeout
socialAuthTimeout = 10 * time.Minute
// Default callback port for social auth HTTP server
socialAuthCallbackPort = 9876
)
// SocialProvider represents the social login provider.
@@ -67,6 +72,13 @@ type RefreshTokenRequest struct {
RefreshToken string `json:"refreshToken"`
}
// WebCallbackResult contains the OAuth callback result from HTTP server.
type WebCallbackResult struct {
Code string
State string
Error string
}
// SocialAuthClient handles social authentication with Kiro.
type SocialAuthClient struct {
httpClient *http.Client
@@ -87,6 +99,83 @@ func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
}
}
// startWebCallbackServer starts a local HTTP server to receive the OAuth callback.
// This is used instead of the kiro:// protocol handler to avoid redirect_mismatch errors.
func (c *SocialAuthClient) startWebCallbackServer(ctx context.Context, expectedState string) (string, <-chan WebCallbackResult, error) {
// Try to find an available port - use localhost like Kiro does
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", socialAuthCallbackPort))
if err != nil {
// Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
log.Warnf("kiro social auth: default port %d is busy, falling back to dynamic port", socialAuthCallbackPort)
listener, err = net.Listen("tcp", "localhost:0")
if err != nil {
return "", nil, fmt.Errorf("failed to start callback server: %w", err)
}
}
port := listener.Addr().(*net.TCPAddr).Port
// Use http scheme for local callback server
redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
resultChan := make(chan WebCallbackResult, 1)
server := &http.Server{
ReadHeaderTimeout: 10 * time.Second,
}
mux := http.NewServeMux()
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
errParam := r.URL.Query().Get("error")
if errParam != "" {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, `<!DOCTYPE html>
<html><head><title>Login Failed</title></head>
<body><h1>Login Failed</h1><p>%s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
resultChan <- WebCallbackResult{Error: errParam}
return
}
if state != expectedState {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, `<!DOCTYPE html>
<html><head><title>Login Failed</title></head>
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
resultChan <- WebCallbackResult{Error: "state mismatch"}
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprint(w, `<!DOCTYPE html>
<html><head><title>Login Successful</title></head>
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
<script>window.close();</script></body></html>`)
resultChan <- WebCallbackResult{Code: code, State: state}
})
server.Handler = mux
go func() {
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
log.Debugf("kiro social auth callback server error: %v", err)
}
}()
go func() {
select {
case <-ctx.Done():
case <-time.After(socialAuthTimeout):
case <-resultChan:
}
_ = server.Shutdown(context.Background())
}()
return redirectURI, resultChan, nil
}
// generatePKCE generates PKCE code verifier and challenge.
func generatePKCE() (verifier, challenge string, err error) {
// Generate 32 bytes of random data for verifier
@@ -140,7 +229,7 @@ func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequ
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
httpReq.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api")
resp, err := c.httpClient.Do(httpReq)
if err != nil {
@@ -217,10 +306,12 @@ func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: "", // Caller should preserve original provider
Region: "us-east-1",
}, nil
}
// LoginWithSocial performs OAuth login with Google.
// LoginWithSocial performs OAuth login with Google or GitHub.
// Uses local HTTP callback server instead of custom protocol handler to avoid redirect_mismatch errors.
func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) {
providerName := string(provider)
@@ -228,28 +319,10 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName)
fmt.Println("╚══════════════════════════════════════════════════════════╝")
// Step 1: Setup protocol handler
// Step 1: Start local HTTP callback server (instead of kiro:// protocol handler)
// This avoids redirect_mismatch errors with AWS Cognito
fmt.Println("\nSetting up authentication...")
// Start the local callback server
handlerPort, err := c.protocolHandler.Start(ctx)
if err != nil {
return nil, fmt.Errorf("failed to start callback server: %w", err)
}
defer c.protocolHandler.Stop()
// Ensure protocol handler is installed and set as default
if err := SetupProtocolHandlerIfNeeded(handlerPort); err != nil {
fmt.Println("\n⚠ Protocol handler setup failed. Trying alternative method...")
fmt.Println(" If you see a browser 'Open with' dialog, select your default browser.")
fmt.Println(" For manual setup instructions, run: cliproxy kiro --help-protocol")
log.Debugf("kiro: protocol handler setup error: %v", err)
// Continue anyway - user might have set it up manually or select browser manually
} else {
// Force set our handler as default (prevents "Open with" dialog)
forceDefaultProtocolHandler()
}
// Step 2: Generate PKCE codes
codeVerifier, codeChallenge, err := generatePKCE()
if err != nil {
@@ -262,8 +335,15 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
return nil, fmt.Errorf("failed to generate state: %w", err)
}
// Step 4: Build the login URL (Kiro uses GET request with query params)
authURL := c.buildLoginURL(providerName, KiroRedirectURI, codeChallenge, state)
// Step 4: Start local HTTP callback server
redirectURI, resultChan, err := c.startWebCallbackServer(ctx, state)
if err != nil {
return nil, fmt.Errorf("failed to start callback server: %w", err)
}
log.Debugf("kiro social auth: callback server started at %s", redirectURI)
// Step 5: Build the login URL using HTTP redirect URI
authURL := c.buildLoginURL(providerName, redirectURI, codeChallenge, state)
// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
// Incognito mode enables multi-account support by bypassing cached sessions
@@ -279,7 +359,7 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
log.Debug("kiro: using incognito mode for multi-account support (default)")
}
// Step 5: Open browser for user authentication
// Step 6: Open browser for user authentication
fmt.Println("\n════════════════════════════════════════════════════════════")
fmt.Printf(" Opening browser for %s authentication...\n", providerName)
fmt.Println("════════════════════════════════════════════════════════════")
@@ -295,80 +375,78 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
fmt.Println("\n Waiting for authentication callback...")
// Step 6: Wait for callback
callback, err := c.protocolHandler.WaitForCallback(ctx)
if err != nil {
return nil, fmt.Errorf("failed to receive callback: %w", err)
}
if callback.Error != "" {
return nil, fmt.Errorf("authentication error: %s", callback.Error)
}
if callback.State != state {
// Log state values for debugging, but don't expose in user-facing error
log.Debugf("kiro: OAuth state mismatch - expected %s, got %s", state, callback.State)
return nil, fmt.Errorf("OAuth state validation failed - please try again")
}
if callback.Code == "" {
return nil, fmt.Errorf("no authorization code received")
}
fmt.Println("\n✓ Authorization received!")
// Step 7: Exchange code for tokens
fmt.Println("Exchanging code for tokens...")
tokenReq := &CreateTokenRequest{
Code: callback.Code,
CodeVerifier: codeVerifier,
RedirectURI: KiroRedirectURI,
}
tokenResp, err := c.CreateToken(ctx, tokenReq)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
}
fmt.Println("\n✓ Authentication successful!")
// Close the browser window
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser: %v", err)
}
// Validate ExpiresIn - use default 1 hour if invalid
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
// Try to extract email from JWT access token first
email := ExtractEmailFromJWT(tokenResp.AccessToken)
// If no email in JWT, ask user for account label (only in interactive mode)
if email == "" && isInteractiveTerminal() {
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
reader := bufio.NewReader(os.Stdin)
var err error
email, err = reader.ReadString('\n')
if err != nil {
log.Debugf("Failed to read account label: %v", err)
// Step 7: Wait for callback from HTTP server
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(socialAuthTimeout):
return nil, fmt.Errorf("authentication timed out")
case callback := <-resultChan:
if callback.Error != "" {
return nil, fmt.Errorf("authentication error: %s", callback.Error)
}
email = strings.TrimSpace(email)
}
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: providerName,
Email: email, // JWT email or user-provided label
}, nil
// State is already validated by the callback server
if callback.Code == "" {
return nil, fmt.Errorf("no authorization code received")
}
fmt.Println("\n✓ Authorization received!")
// Step 8: Exchange code for tokens
fmt.Println("Exchanging code for tokens...")
tokenReq := &CreateTokenRequest{
Code: callback.Code,
CodeVerifier: codeVerifier,
RedirectURI: redirectURI, // Use HTTP redirect URI, not kiro:// protocol
}
tokenResp, err := c.CreateToken(ctx, tokenReq)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
}
fmt.Println("\n✓ Authentication successful!")
// Close the browser window
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser: %v", err)
}
// Validate ExpiresIn - use default 1 hour if invalid
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
// Try to extract email from JWT access token first
email := ExtractEmailFromJWT(tokenResp.AccessToken)
// If no email in JWT, ask user for account label (only in interactive mode)
if email == "" && isInteractiveTerminal() {
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
reader := bufio.NewReader(os.Stdin)
var err error
email, err = reader.ReadString('\n')
if err != nil {
log.Debugf("Failed to read account label: %v", err)
}
email = strings.TrimSpace(email)
}
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: providerName,
Email: email, // JWT email or user-provided label
Region: "us-east-1",
}, nil
}
}
// LoginWithGoogle performs OAuth login with Google.

View File

@@ -2,16 +2,19 @@
package kiro
import (
"bufio"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"html"
"io"
"net"
"net/http"
"os"
"strings"
"time"
@@ -24,19 +27,31 @@ import (
const (
// AWS SSO OIDC endpoints
ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com"
// Kiro's start URL for Builder ID
builderIDStartURL = "https://view.awsapps.com/start"
// Default region for IDC
defaultIDCRegion = "us-east-1"
// Polling interval
pollInterval = 5 * time.Second
// Authorization code flow callback
authCodeCallbackPath = "/oauth/callback"
authCodeCallbackPort = 19877
// User-Agent to match official Kiro IDE
kiroUserAgent = "KiroIDE"
// IDC token refresh headers (matching Kiro IDE behavior)
idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE"
)
// Sentinel errors for OIDC token polling
var (
ErrAuthorizationPending = errors.New("authorization_pending")
ErrSlowDown = errors.New("slow_down")
)
// SSOOIDCClient handles AWS SSO OIDC authentication.
@@ -83,6 +98,440 @@ type CreateTokenResponse struct {
RefreshToken string `json:"refreshToken"`
}
// getOIDCEndpoint returns the OIDC endpoint for the given region.
func getOIDCEndpoint(region string) string {
if region == "" {
region = defaultIDCRegion
}
return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
}
// promptInput prompts the user for input with an optional default value.
func promptInput(prompt, defaultValue string) string {
reader := bufio.NewReader(os.Stdin)
if defaultValue != "" {
fmt.Printf("%s [%s]: ", prompt, defaultValue)
} else {
fmt.Printf("%s: ", prompt)
}
input, err := reader.ReadString('\n')
if err != nil {
log.Warnf("Error reading input: %v", err)
return defaultValue
}
input = strings.TrimSpace(input)
if input == "" {
return defaultValue
}
return input
}
// promptSelect prompts the user to select from options using number input.
func promptSelect(prompt string, options []string) int {
reader := bufio.NewReader(os.Stdin)
for {
fmt.Println(prompt)
for i, opt := range options {
fmt.Printf(" %d) %s\n", i+1, opt)
}
fmt.Printf("Enter selection (1-%d): ", len(options))
input, err := reader.ReadString('\n')
if err != nil {
log.Warnf("Error reading input: %v", err)
return 0 // Default to first option on error
}
input = strings.TrimSpace(input)
// Parse the selection
var selection int
if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) {
fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options))
continue
}
return selection - 1
}
}
// RegisterClientWithRegion registers a new OIDC client with AWS using a specific region.
func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region string) (*RegisterClientResponse, error) {
endpoint := getOIDCEndpoint(region)
payload := map[string]interface{}{
"clientName": "Kiro IDE",
"clientType": "public",
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
"grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"},
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
}
var result RegisterClientResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
return &result, nil
}
// StartDeviceAuthorizationWithIDC starts the device authorization flow for IDC.
func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, clientID, clientSecret, startURL, region string) (*StartDeviceAuthResponse, error) {
endpoint := getOIDCEndpoint(region)
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"startUrl": startURL,
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/device_authorization", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode)
}
var result StartDeviceAuthResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
return &result, nil
}
// CreateTokenWithRegion polls for the access token after user authorization using a specific region.
func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, clientSecret, deviceCode, region string) (*CreateTokenResponse, error) {
endpoint := getOIDCEndpoint(region)
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"deviceCode": deviceCode,
"grantType": "urn:ietf:params:oauth:grant-type:device_code",
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// Check for pending authorization
if resp.StatusCode == http.StatusBadRequest {
var errResp struct {
Error string `json:"error"`
}
if json.Unmarshal(respBody, &errResp) == nil {
if errResp.Error == "authorization_pending" {
return nil, ErrAuthorizationPending
}
if errResp.Error == "slow_down" {
return nil, ErrSlowDown
}
}
log.Debugf("create token failed: %s", string(respBody))
return nil, fmt.Errorf("create token failed")
}
if resp.StatusCode != http.StatusOK {
log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
}
var result CreateTokenResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
return &result, nil
}
// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region.
func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) {
endpoint := getOIDCEndpoint(region)
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"refreshToken": refreshToken,
"grantType": "refresh_token",
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
// Set headers matching kiro2api's IDC token refresh
// These headers are required for successful IDC token refresh
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region))
req.Header.Set("Connection", "keep-alive")
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
req.Header.Set("Accept", "*/*")
req.Header.Set("Accept-Language", "*")
req.Header.Set("sec-fetch-mode", "cors")
req.Header.Set("User-Agent", "node")
req.Header.Set("Accept-Encoding", "br, gzip, deflate")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
log.Warnf("IDC token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
}
var result CreateTokenResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second)
return &KiroTokenData{
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "idc",
Provider: "AWS",
ClientID: clientID,
ClientSecret: clientSecret,
StartURL: startURL,
Region: region,
}, nil
}
// LoginWithIDC performs the full device code flow for AWS Identity Center (IDC).
func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
fmt.Println("║ Kiro Authentication (AWS Identity Center) ║")
fmt.Println("╚══════════════════════════════════════════════════════════╝")
// Step 1: Register client with the specified region
fmt.Println("\nRegistering client...")
regResp, err := c.RegisterClientWithRegion(ctx, region)
if err != nil {
return nil, fmt.Errorf("failed to register client: %w", err)
}
log.Debugf("Client registered: %s", regResp.ClientID)
// Step 2: Start device authorization with IDC start URL
fmt.Println("Starting device authorization...")
authResp, err := c.StartDeviceAuthorizationWithIDC(ctx, regResp.ClientID, regResp.ClientSecret, startURL, region)
if err != nil {
return nil, fmt.Errorf("failed to start device auth: %w", err)
}
// Step 3: Show user the verification URL
fmt.Printf("\n")
fmt.Println("════════════════════════════════════════════════════════════")
fmt.Printf(" Confirm the following code in the browser:\n")
fmt.Printf(" Code: %s\n", authResp.UserCode)
fmt.Println("════════════════════════════════════════════════════════════")
fmt.Printf("\n Open this URL: %s\n\n", authResp.VerificationURIComplete)
// Set incognito mode based on config
if c.cfg != nil {
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
if !c.cfg.IncognitoBrowser {
log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
} else {
log.Debug("kiro: using incognito mode for multi-account support")
}
} else {
browser.SetIncognitoMode(true)
log.Debug("kiro: using incognito mode for multi-account support (default)")
}
// Open browser
if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil {
log.Warnf("Could not open browser automatically: %v", err)
fmt.Println(" Please open the URL manually in your browser.")
} else {
fmt.Println(" (Browser opened automatically)")
}
// Step 4: Poll for token
fmt.Println("Waiting for authorization...")
interval := pollInterval
if authResp.Interval > 0 {
interval = time.Duration(authResp.Interval) * time.Second
}
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
for time.Now().Before(deadline) {
select {
case <-ctx.Done():
browser.CloseBrowser()
return nil, ctx.Err()
case <-time.After(interval):
tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region)
if err != nil {
if errors.Is(err, ErrAuthorizationPending) {
fmt.Print(".")
continue
}
if errors.Is(err, ErrSlowDown) {
interval += 5 * time.Second
continue
}
browser.CloseBrowser()
return nil, fmt.Errorf("token creation failed: %w", err)
}
fmt.Println("\n\n✓ Authorization successful!")
// Close the browser window
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser: %v", err)
}
// Step 5: Get profile ARN from CodeWhisperer API
fmt.Println("Fetching profile information...")
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
// Fetch user email
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
if email != "" {
fmt.Printf(" Logged in as: %s\n", email)
}
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: profileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "idc",
Provider: "AWS",
ClientID: regResp.ClientID,
ClientSecret: regResp.ClientSecret,
Email: email,
StartURL: startURL,
Region: region,
}, nil
}
}
// Close browser on timeout
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser on timeout: %v", err)
}
return nil, fmt.Errorf("authorization timed out")
}
// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login.
func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) {
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
fmt.Println("║ Kiro Authentication (AWS) ║")
fmt.Println("╚══════════════════════════════════════════════════════════╝")
// Prompt for login method
options := []string{
"Use with Builder ID (personal AWS account)",
"Use with IDC Account (organization SSO)",
}
selection := promptSelect("\n? Select login method:", options)
if selection == 0 {
// Builder ID flow - use existing implementation
return c.LoginWithBuilderID(ctx)
}
// IDC flow - prompt for start URL and region
fmt.Println()
startURL := promptInput("? Enter Start URL", "")
if startURL == "" {
return nil, fmt.Errorf("start URL is required for IDC login")
}
region := promptInput("? Enter Region", defaultIDCRegion)
return c.LoginWithIDC(ctx, startURL, region)
}
// RegisterClient registers a new OIDC client with AWS.
func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) {
payload := map[string]interface{}{
@@ -211,10 +660,10 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
}
if json.Unmarshal(respBody, &errResp) == nil {
if errResp.Error == "authorization_pending" {
return nil, fmt.Errorf("authorization_pending")
return nil, ErrAuthorizationPending
}
if errResp.Error == "slow_down" {
return nil, fmt.Errorf("slow_down")
return nil, ErrSlowDown
}
}
log.Debugf("create token failed: %s", string(respBody))
@@ -235,6 +684,7 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
}
// RefreshToken refreshes an access token using the refresh token.
// Includes retry logic and improved error handling for better reliability.
func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (*KiroTokenData, error) {
payload := map[string]string{
"clientId": clientID,
@@ -252,8 +702,13 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
if err != nil {
return nil, err
}
// Set headers matching Kiro IDE behavior for better compatibility
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", kiroUserAgent)
req.Header.Set("Host", "oidc.us-east-1.amazonaws.com")
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
req.Header.Set("User-Agent", "node")
req.Header.Set("Accept", "*/*")
resp, err := c.httpClient.Do(req)
if err != nil {
@@ -267,8 +722,8 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
}
if resp.StatusCode != http.StatusOK {
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
log.Warnf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
}
var result CreateTokenResponse
@@ -286,6 +741,7 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
Provider: "AWS",
ClientID: clientID,
ClientSecret: clientSecret,
Region: defaultIDCRegion,
}, nil
}
@@ -359,12 +815,11 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
case <-time.After(interval):
tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
if err != nil {
errStr := err.Error()
if strings.Contains(errStr, "authorization_pending") {
if errors.Is(err, ErrAuthorizationPending) {
fmt.Print(".")
continue
}
if strings.Contains(errStr, "slow_down") {
if errors.Is(err, ErrSlowDown) {
interval += 5 * time.Second
continue
}
@@ -402,16 +857,17 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
ClientID: regResp.ClientID,
ClientSecret: regResp.ClientSecret,
Email: email,
Region: defaultIDCRegion,
}, nil
}
}
}
}
// Close browser on timeout for better UX
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser on timeout: %v", err)
}
return nil, fmt.Errorf("authorization timed out")
}
// Close browser on timeout for better UX
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser on timeout: %v", err)
}
return nil, fmt.Errorf("authorization timed out")
}
// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
// Falls back to JWT parsing if userinfo fails.
@@ -918,6 +1374,7 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo
ClientID: regResp.ClientID,
ClientSecret: regResp.ClientSecret,
Email: email,
Region: defaultIDCRegion,
}, nil
}
}

View File

@@ -9,6 +9,8 @@ import (
// KiroTokenStorage holds the persistent token data for Kiro authentication.
type KiroTokenStorage struct {
// Type is the provider type for management UI recognition (must be "kiro")
Type string `json:"type"`
// AccessToken is the OAuth2 access token for API access
AccessToken string `json:"access_token"`
// RefreshToken is used to obtain new access tokens
@@ -23,6 +25,16 @@ type KiroTokenStorage struct {
Provider string `json:"provider"`
// LastRefresh is the timestamp of the last token refresh
LastRefresh string `json:"last_refresh"`
// ClientID is the OAuth client ID (required for token refresh)
ClientID string `json:"client_id,omitempty"`
// ClientSecret is the OAuth client secret (required for token refresh)
ClientSecret string `json:"client_secret,omitempty"`
// Region is the AWS region
Region string `json:"region,omitempty"`
// StartURL is the AWS Identity Center start URL (for IDC auth)
StartURL string `json:"start_url,omitempty"`
// Email is the user's email address
Email string `json:"email,omitempty"`
}
// SaveTokenToFile persists the token storage to the specified file path.
@@ -68,5 +80,10 @@ func (s *KiroTokenStorage) ToTokenData() *KiroTokenData {
ExpiresAt: s.ExpiresAt,
AuthMethod: s.AuthMethod,
Provider: s.Provider,
ClientID: s.ClientID,
ClientSecret: s.ClientSecret,
Region: s.Region,
StartURL: s.StartURL,
Email: s.Email,
}
}

View File

@@ -0,0 +1,274 @@
package kiro
import (
"context"
"encoding/json"
"fmt"
"io/fs"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// FileTokenRepository 实现 TokenRepository 接口,基于文件系统存储
type FileTokenRepository struct {
mu sync.RWMutex
baseDir string
}
// NewFileTokenRepository 创建一个新的文件 token 存储库
func NewFileTokenRepository(baseDir string) *FileTokenRepository {
return &FileTokenRepository{
baseDir: baseDir,
}
}
// SetBaseDir 设置基础目录
func (r *FileTokenRepository) SetBaseDir(dir string) {
r.mu.Lock()
r.baseDir = strings.TrimSpace(dir)
r.mu.Unlock()
}
// FindOldestUnverified 查找需要刷新的 token按最后验证时间排序
func (r *FileTokenRepository) FindOldestUnverified(limit int) []*Token {
r.mu.RLock()
baseDir := r.baseDir
r.mu.RUnlock()
if baseDir == "" {
log.Debug("token repository: base directory not configured")
return nil
}
var tokens []*Token
err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
return nil // 忽略错误,继续遍历
}
if d.IsDir() {
return nil
}
if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") {
return nil
}
// 只处理 kiro 相关的 token 文件
if !strings.HasPrefix(d.Name(), "kiro-") {
return nil
}
token, err := r.readTokenFile(path)
if err != nil {
log.Debugf("token repository: failed to read token file %s: %v", path, err)
return nil
}
if token != nil && token.RefreshToken != "" {
// 检查 token 是否需要刷新(过期前 5 分钟)
if token.ExpiresAt.IsZero() || time.Until(token.ExpiresAt) < 5*time.Minute {
tokens = append(tokens, token)
}
}
return nil
})
if err != nil {
log.Warnf("token repository: error walking directory: %v", err)
}
// 按最后验证时间排序(最旧的优先)
sort.Slice(tokens, func(i, j int) bool {
return tokens[i].LastVerified.Before(tokens[j].LastVerified)
})
// 限制返回数量
if limit > 0 && len(tokens) > limit {
tokens = tokens[:limit]
}
return tokens
}
// UpdateToken 更新 token 并持久化到文件
func (r *FileTokenRepository) UpdateToken(token *Token) error {
if token == nil {
return fmt.Errorf("token repository: token is nil")
}
r.mu.RLock()
baseDir := r.baseDir
r.mu.RUnlock()
if baseDir == "" {
return fmt.Errorf("token repository: base directory not configured")
}
// 构建文件路径
filePath := filepath.Join(baseDir, token.ID)
if !strings.HasSuffix(filePath, ".json") {
filePath += ".json"
}
// 读取现有文件内容
existingData := make(map[string]any)
if data, err := os.ReadFile(filePath); err == nil {
_ = json.Unmarshal(data, &existingData)
}
// 更新字段
existingData["access_token"] = token.AccessToken
existingData["refresh_token"] = token.RefreshToken
existingData["last_refresh"] = time.Now().Format(time.RFC3339)
if !token.ExpiresAt.IsZero() {
existingData["expires_at"] = token.ExpiresAt.Format(time.RFC3339)
}
// 保持原有的关键字段
if token.ClientID != "" {
existingData["client_id"] = token.ClientID
}
if token.ClientSecret != "" {
existingData["client_secret"] = token.ClientSecret
}
if token.AuthMethod != "" {
existingData["auth_method"] = token.AuthMethod
}
if token.Region != "" {
existingData["region"] = token.Region
}
if token.StartURL != "" {
existingData["start_url"] = token.StartURL
}
// 序列化并写入文件
raw, err := json.MarshalIndent(existingData, "", " ")
if err != nil {
return fmt.Errorf("token repository: marshal failed: %w", err)
}
// 原子写入:先写入临时文件,再重命名
tmpPath := filePath + ".tmp"
if err := os.WriteFile(tmpPath, raw, 0o600); err != nil {
return fmt.Errorf("token repository: write temp file failed: %w", err)
}
if err := os.Rename(tmpPath, filePath); err != nil {
_ = os.Remove(tmpPath)
return fmt.Errorf("token repository: rename failed: %w", err)
}
log.Debugf("token repository: updated token %s", token.ID)
return nil
}
// readTokenFile 从文件读取 token
func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var metadata map[string]any
if err := json.Unmarshal(data, &metadata); err != nil {
return nil, err
}
// 检查是否是 kiro token
tokenType, _ := metadata["type"].(string)
if tokenType != "kiro" {
return nil, nil
}
// 检查 auth_method (case-insensitive comparison to handle "IdC", "IDC", "idc", etc.)
authMethod, _ := metadata["auth_method"].(string)
authMethod = strings.ToLower(authMethod)
if authMethod != "idc" && authMethod != "builder-id" {
return nil, nil // 只处理 IDC 和 Builder ID token
}
token := &Token{
ID: filepath.Base(path),
AuthMethod: authMethod,
}
// 解析各字段
if v, ok := metadata["access_token"].(string); ok {
token.AccessToken = v
}
if v, ok := metadata["refresh_token"].(string); ok {
token.RefreshToken = v
}
if v, ok := metadata["client_id"].(string); ok {
token.ClientID = v
}
if v, ok := metadata["client_secret"].(string); ok {
token.ClientSecret = v
}
if v, ok := metadata["region"].(string); ok {
token.Region = v
}
if v, ok := metadata["start_url"].(string); ok {
token.StartURL = v
}
if v, ok := metadata["provider"].(string); ok {
token.Provider = v
}
// 解析时间字段
if v, ok := metadata["expires_at"].(string); ok {
if t, err := time.Parse(time.RFC3339, v); err == nil {
token.ExpiresAt = t
}
}
if v, ok := metadata["last_refresh"].(string); ok {
if t, err := time.Parse(time.RFC3339, v); err == nil {
token.LastVerified = t
}
}
return token, nil
}
// ListKiroTokens 列出所有 Kiro token用于调试
func (r *FileTokenRepository) ListKiroTokens(ctx context.Context) ([]*Token, error) {
r.mu.RLock()
baseDir := r.baseDir
r.mu.RUnlock()
if baseDir == "" {
return nil, fmt.Errorf("token repository: base directory not configured")
}
var tokens []*Token
err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
return nil
}
if d.IsDir() {
return nil
}
if !strings.HasPrefix(d.Name(), "kiro-") || !strings.HasSuffix(d.Name(), ".json") {
return nil
}
token, err := r.readTokenFile(path)
if err != nil {
return nil
}
if token != nil {
tokens = append(tokens, token)
}
return nil
})
return tokens, err
}

View File

@@ -0,0 +1,243 @@
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
// This file implements usage quota checking and monitoring.
package kiro
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
// UsageQuotaResponse represents the API response structure for usage quota checking.
type UsageQuotaResponse struct {
UsageBreakdownList []UsageBreakdownExtended `json:"usageBreakdownList"`
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
NextDateReset float64 `json:"nextDateReset,omitempty"`
}
// UsageBreakdownExtended represents detailed usage information for quota checking.
// Note: UsageBreakdown is already defined in codewhisperer_client.go
type UsageBreakdownExtended struct {
ResourceType string `json:"resourceType"`
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
FreeTrialInfo *FreeTrialInfoExtended `json:"freeTrialInfo,omitempty"`
}
// FreeTrialInfoExtended represents free trial usage information.
type FreeTrialInfoExtended struct {
FreeTrialStatus string `json:"freeTrialStatus"`
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
}
// QuotaStatus represents the quota status for a token.
type QuotaStatus struct {
TotalLimit float64
CurrentUsage float64
RemainingQuota float64
IsExhausted bool
ResourceType string
NextReset time.Time
}
// UsageChecker provides methods for checking token quota usage.
type UsageChecker struct {
httpClient *http.Client
endpoint string
}
// NewUsageChecker creates a new UsageChecker instance.
func NewUsageChecker(cfg *config.Config) *UsageChecker {
return &UsageChecker{
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
endpoint: awsKiroEndpoint,
}
}
// NewUsageCheckerWithClient creates a UsageChecker with a custom HTTP client.
func NewUsageCheckerWithClient(client *http.Client) *UsageChecker {
return &UsageChecker{
httpClient: client,
endpoint: awsKiroEndpoint,
}
}
// CheckUsage retrieves usage limits for the given token.
func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData) (*UsageQuotaResponse, error) {
if tokenData == nil {
return nil, fmt.Errorf("token data is nil")
}
if tokenData.AccessToken == "" {
return nil, fmt.Errorf("access token is empty")
}
payload := map[string]interface{}{
"origin": "AI_EDITOR",
"profileArn": tokenData.ProfileArn,
"resourceType": "AGENTIC_REQUEST",
}
jsonBody, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(string(jsonBody)))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
req.Header.Set("x-amz-target", targetGetUsage)
req.Header.Set("Authorization", "Bearer "+tokenData.AccessToken)
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result UsageQuotaResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse usage response: %w", err)
}
return &result, nil
}
// CheckUsageByAccessToken retrieves usage limits using an access token and profile ARN directly.
func (c *UsageChecker) CheckUsageByAccessToken(ctx context.Context, accessToken, profileArn string) (*UsageQuotaResponse, error) {
tokenData := &KiroTokenData{
AccessToken: accessToken,
ProfileArn: profileArn,
}
return c.CheckUsage(ctx, tokenData)
}
// GetRemainingQuota calculates the remaining quota from usage limits.
func GetRemainingQuota(usage *UsageQuotaResponse) float64 {
if usage == nil || len(usage.UsageBreakdownList) == 0 {
return 0
}
var totalRemaining float64
for _, breakdown := range usage.UsageBreakdownList {
remaining := breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
if remaining > 0 {
totalRemaining += remaining
}
if breakdown.FreeTrialInfo != nil {
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
if freeRemaining > 0 {
totalRemaining += freeRemaining
}
}
}
return totalRemaining
}
// IsQuotaExhausted checks if the quota is exhausted based on usage limits.
func IsQuotaExhausted(usage *UsageQuotaResponse) bool {
if usage == nil || len(usage.UsageBreakdownList) == 0 {
return true
}
for _, breakdown := range usage.UsageBreakdownList {
if breakdown.CurrentUsageWithPrecision < breakdown.UsageLimitWithPrecision {
return false
}
if breakdown.FreeTrialInfo != nil {
if breakdown.FreeTrialInfo.CurrentUsageWithPrecision < breakdown.FreeTrialInfo.UsageLimitWithPrecision {
return false
}
}
}
return true
}
// GetQuotaStatus retrieves a comprehensive quota status for a token.
func (c *UsageChecker) GetQuotaStatus(ctx context.Context, tokenData *KiroTokenData) (*QuotaStatus, error) {
usage, err := c.CheckUsage(ctx, tokenData)
if err != nil {
return nil, err
}
status := &QuotaStatus{
IsExhausted: IsQuotaExhausted(usage),
}
if len(usage.UsageBreakdownList) > 0 {
breakdown := usage.UsageBreakdownList[0]
status.TotalLimit = breakdown.UsageLimitWithPrecision
status.CurrentUsage = breakdown.CurrentUsageWithPrecision
status.RemainingQuota = breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
status.ResourceType = breakdown.ResourceType
if breakdown.FreeTrialInfo != nil {
status.TotalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
status.CurrentUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
if freeRemaining > 0 {
status.RemainingQuota += freeRemaining
}
}
}
if usage.NextDateReset > 0 {
status.NextReset = time.Unix(int64(usage.NextDateReset/1000), 0)
}
return status, nil
}
// CalculateAvailableCount calculates the available request count based on usage limits.
func CalculateAvailableCount(usage *UsageQuotaResponse) float64 {
return GetRemainingQuota(usage)
}
// GetUsagePercentage calculates the usage percentage.
func GetUsagePercentage(usage *UsageQuotaResponse) float64 {
if usage == nil || len(usage.UsageBreakdownList) == 0 {
return 100.0
}
var totalLimit, totalUsage float64
for _, breakdown := range usage.UsageBreakdownList {
totalLimit += breakdown.UsageLimitWithPrecision
totalUsage += breakdown.CurrentUsageWithPrecision
if breakdown.FreeTrialInfo != nil {
totalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
totalUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
}
}
if totalLimit == 0 {
return 100.0
}
return (totalUsage / totalLimit) * 100
}

View File

@@ -3,7 +3,7 @@ package cache
import (
"crypto/sha256"
"encoding/hex"
"sort"
"strings"
"sync"
"time"
)
@@ -16,23 +16,26 @@ type SignatureEntry struct {
const (
// SignatureCacheTTL is how long signatures are valid
SignatureCacheTTL = 1 * time.Hour
// MaxEntriesPerSession limits memory usage per session
MaxEntriesPerSession = 100
SignatureCacheTTL = 3 * time.Hour
// SignatureTextHashLen is the length of the hash key (16 hex chars = 64-bit key space)
SignatureTextHashLen = 16
// MinValidSignatureLen is the minimum length for a signature to be considered valid
MinValidSignatureLen = 50
// CacheCleanupInterval controls how often stale entries are purged
CacheCleanupInterval = 10 * time.Minute
)
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry
// signatureCache stores signatures by model group -> textHash -> SignatureEntry
var signatureCache sync.Map
// sessionCache is the inner map type
type sessionCache struct {
// cacheCleanupOnce ensures the background cleanup goroutine starts only once
var cacheCleanupOnce sync.Once
// groupCache is the inner map type
type groupCache struct {
mu sync.RWMutex
entries map[string]SignatureEntry
}
@@ -43,122 +46,150 @@ func hashText(text string) string {
return hex.EncodeToString(h[:])[:SignatureTextHashLen]
}
// getOrCreateSession gets or creates a session cache
func getOrCreateSession(sessionID string) *sessionCache {
if val, ok := signatureCache.Load(sessionID); ok {
return val.(*sessionCache)
// getOrCreateGroupCache gets or creates a cache bucket for a model group
func getOrCreateGroupCache(groupKey string) *groupCache {
// Start background cleanup on first access
cacheCleanupOnce.Do(startCacheCleanup)
if val, ok := signatureCache.Load(groupKey); ok {
return val.(*groupCache)
}
sc := &sessionCache{entries: make(map[string]SignatureEntry)}
actual, _ := signatureCache.LoadOrStore(sessionID, sc)
return actual.(*sessionCache)
sc := &groupCache{entries: make(map[string]SignatureEntry)}
actual, _ := signatureCache.LoadOrStore(groupKey, sc)
return actual.(*groupCache)
}
// CacheSignature stores a thinking signature for a given session and text.
// startCacheCleanup launches a background goroutine that periodically
// removes caches where all entries have expired.
func startCacheCleanup() {
go func() {
ticker := time.NewTicker(CacheCleanupInterval)
defer ticker.Stop()
for range ticker.C {
purgeExpiredCaches()
}
}()
}
// purgeExpiredCaches removes caches with no valid (non-expired) entries.
func purgeExpiredCaches() {
now := time.Now()
signatureCache.Range(func(key, value any) bool {
sc := value.(*groupCache)
sc.mu.Lock()
// Remove expired entries
for k, entry := range sc.entries {
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
delete(sc.entries, k)
}
}
isEmpty := len(sc.entries) == 0
sc.mu.Unlock()
// Remove cache bucket if empty
if isEmpty {
signatureCache.Delete(key)
}
return true
})
}
// CacheSignature stores a thinking signature for a given model group and text.
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
func CacheSignature(sessionID, text, signature string) {
if sessionID == "" || text == "" || signature == "" {
func CacheSignature(modelName, text, signature string) {
if text == "" || signature == "" {
return
}
if len(signature) < MinValidSignatureLen {
return
}
sc := getOrCreateSession(sessionID)
groupKey := GetModelGroup(modelName)
textHash := hashText(text)
sc := getOrCreateGroupCache(groupKey)
sc.mu.Lock()
defer sc.mu.Unlock()
// Evict expired entries if at capacity
if len(sc.entries) >= MaxEntriesPerSession {
now := time.Now()
for key, entry := range sc.entries {
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
delete(sc.entries, key)
}
}
// If still at capacity, remove oldest entries
if len(sc.entries) >= MaxEntriesPerSession {
// Find and remove oldest quarter
oldest := make([]struct {
key string
ts time.Time
}, 0, len(sc.entries))
for key, entry := range sc.entries {
oldest = append(oldest, struct {
key string
ts time.Time
}{key, entry.Timestamp})
}
// Sort by timestamp (oldest first) using sort.Slice
sort.Slice(oldest, func(i, j int) bool {
return oldest[i].ts.Before(oldest[j].ts)
})
toRemove := len(oldest) / 4
if toRemove < 1 {
toRemove = 1
}
for i := 0; i < toRemove; i++ {
delete(sc.entries, oldest[i].key)
}
}
}
sc.entries[textHash] = SignatureEntry{
Signature: signature,
Timestamp: time.Now(),
}
}
// GetCachedSignature retrieves a cached signature for a given session and text.
// GetCachedSignature retrieves a cached signature for a given model group and text.
// Returns empty string if not found or expired.
func GetCachedSignature(sessionID, text string) string {
if sessionID == "" || text == "" {
return ""
}
func GetCachedSignature(modelName, text string) string {
groupKey := GetModelGroup(modelName)
val, ok := signatureCache.Load(sessionID)
if !ok {
if text == "" {
if groupKey == "gemini" {
return "skip_thought_signature_validator"
}
return ""
}
sc := val.(*sessionCache)
val, ok := signatureCache.Load(groupKey)
if !ok {
if groupKey == "gemini" {
return "skip_thought_signature_validator"
}
return ""
}
sc := val.(*groupCache)
textHash := hashText(text)
sc.mu.RLock()
entry, exists := sc.entries[textHash]
sc.mu.RUnlock()
now := time.Now()
sc.mu.Lock()
entry, exists := sc.entries[textHash]
if !exists {
sc.mu.Unlock()
if groupKey == "gemini" {
return "skip_thought_signature_validator"
}
return ""
}
// Check if expired
if time.Since(entry.Timestamp) > SignatureCacheTTL {
sc.mu.Lock()
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
delete(sc.entries, textHash)
sc.mu.Unlock()
if groupKey == "gemini" {
return "skip_thought_signature_validator"
}
return ""
}
// Refresh TTL on access (sliding expiration).
entry.Timestamp = now
sc.entries[textHash] = entry
sc.mu.Unlock()
return entry.Signature
}
// ClearSignatureCache clears signature cache for a specific session or all sessions.
func ClearSignatureCache(sessionID string) {
if sessionID != "" {
signatureCache.Delete(sessionID)
} else {
// ClearSignatureCache clears signature cache for a specific model group or all groups.
func ClearSignatureCache(modelName string) {
if modelName == "" {
signatureCache.Range(func(key, _ any) bool {
signatureCache.Delete(key)
return true
})
return
}
groupKey := GetModelGroup(modelName)
signatureCache.Delete(groupKey)
}
// HasValidSignature checks if a signature is valid (non-empty and long enough)
func HasValidSignature(signature string) bool {
return signature != "" && len(signature) >= MinValidSignatureLen
func HasValidSignature(modelName, signature string) bool {
return (signature != "" && len(signature) >= MinValidSignatureLen) || (signature == "skip_thought_signature_validator" && GetModelGroup(modelName) == "gemini")
}
func GetModelGroup(modelName string) string {
if strings.Contains(modelName, "gpt") {
return "gpt"
} else if strings.Contains(modelName, "claude") {
return "claude"
} else if strings.Contains(modelName, "gemini") {
return "gemini"
}
return modelName
}

View File

@@ -5,38 +5,40 @@ import (
"time"
)
const testModelName = "claude-sonnet-4-5"
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
ClearSignatureCache("")
sessionID := "test-session-1"
text := "This is some thinking text content"
signature := "abc123validSignature1234567890123456789012345678901234567890"
// Store signature
CacheSignature(sessionID, text, signature)
CacheSignature(testModelName, text, signature)
// Retrieve signature
retrieved := GetCachedSignature(sessionID, text)
retrieved := GetCachedSignature(testModelName, text)
if retrieved != signature {
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
}
}
func TestCacheSignature_DifferentSessions(t *testing.T) {
func TestCacheSignature_DifferentModelGroups(t *testing.T) {
ClearSignatureCache("")
text := "Same text in different sessions"
text := "Same text across models"
sig1 := "signature1_1234567890123456789012345678901234567890123456"
sig2 := "signature2_1234567890123456789012345678901234567890123456"
CacheSignature("session-a", text, sig1)
CacheSignature("session-b", text, sig2)
geminiModel := "gemini-3-pro-preview"
CacheSignature(testModelName, text, sig1)
CacheSignature(geminiModel, text, sig2)
if GetCachedSignature("session-a", text) != sig1 {
t.Error("Session-a signature mismatch")
if GetCachedSignature(testModelName, text) != sig1 {
t.Error("Claude signature mismatch")
}
if GetCachedSignature("session-b", text) != sig2 {
t.Error("Session-b signature mismatch")
if GetCachedSignature(geminiModel, text) != sig2 {
t.Error("Gemini signature mismatch")
}
}
@@ -44,13 +46,13 @@ func TestCacheSignature_NotFound(t *testing.T) {
ClearSignatureCache("")
// Non-existent session
if got := GetCachedSignature("nonexistent", "some text"); got != "" {
if got := GetCachedSignature(testModelName, "some text"); got != "" {
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
}
// Existing session but different text
CacheSignature("session-x", "text-a", "sigA12345678901234567890123456789012345678901234567890")
if got := GetCachedSignature("session-x", "text-b"); got != "" {
CacheSignature(testModelName, "text-a", "sigA12345678901234567890123456789012345678901234567890")
if got := GetCachedSignature(testModelName, "text-b"); got != "" {
t.Errorf("Expected empty string for different text, got '%s'", got)
}
}
@@ -59,12 +61,11 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
ClearSignatureCache("")
// All empty/invalid inputs should be no-ops
CacheSignature("", "text", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("session", "", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("session", "text", "")
CacheSignature("session", "text", "short") // Too short
CacheSignature(testModelName, "", "sig12345678901234567890123456789012345678901234567890")
CacheSignature(testModelName, "text", "")
CacheSignature(testModelName, "text", "short") // Too short
if got := GetCachedSignature("session", "text"); got != "" {
if got := GetCachedSignature(testModelName, "text"); got != "" {
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
}
}
@@ -72,31 +73,27 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
ClearSignatureCache("")
sessionID := "test-short-sig"
text := "Some text"
shortSig := "abc123" // Less than 50 chars
CacheSignature(sessionID, text, shortSig)
CacheSignature(testModelName, text, shortSig)
if got := GetCachedSignature(sessionID, text); got != "" {
if got := GetCachedSignature(testModelName, text); got != "" {
t.Errorf("Short signature should be rejected, got '%s'", got)
}
}
func TestClearSignatureCache_SpecificSession(t *testing.T) {
func TestClearSignatureCache_ModelGroup(t *testing.T) {
ClearSignatureCache("")
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature("session-1", "text", sig)
CacheSignature("session-2", "text", sig)
CacheSignature(testModelName, "text", sig)
CacheSignature(testModelName, "text-2", sig)
ClearSignatureCache("session-1")
if got := GetCachedSignature("session-1", "text"); got != "" {
t.Error("session-1 should be cleared")
}
if got := GetCachedSignature("session-2", "text"); got != sig {
t.Error("session-2 should still exist")
if got := GetCachedSignature(testModelName, "text"); got != sig {
t.Error("signature should remain when clearing unknown session")
}
}
@@ -104,35 +101,37 @@ func TestClearSignatureCache_AllSessions(t *testing.T) {
ClearSignatureCache("")
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature("session-1", "text", sig)
CacheSignature("session-2", "text", sig)
CacheSignature(testModelName, "text", sig)
CacheSignature(testModelName, "text-2", sig)
ClearSignatureCache("")
if got := GetCachedSignature("session-1", "text"); got != "" {
t.Error("session-1 should be cleared")
if got := GetCachedSignature(testModelName, "text"); got != "" {
t.Error("text should be cleared")
}
if got := GetCachedSignature("session-2", "text"); got != "" {
t.Error("session-2 should be cleared")
if got := GetCachedSignature(testModelName, "text-2"); got != "" {
t.Error("text-2 should be cleared")
}
}
func TestHasValidSignature(t *testing.T) {
tests := []struct {
name string
modelName string
signature string
expected bool
}{
{"valid long signature", "abc123validSignature1234567890123456789012345678901234567890", true},
{"exactly 50 chars", "12345678901234567890123456789012345678901234567890", true},
{"49 chars - invalid", "1234567890123456789012345678901234567890123456789", false},
{"empty string", "", false},
{"short signature", "abc", false},
{"valid long signature", testModelName, "abc123validSignature1234567890123456789012345678901234567890", true},
{"exactly 50 chars", testModelName, "12345678901234567890123456789012345678901234567890", true},
{"49 chars - invalid", testModelName, "1234567890123456789012345678901234567890123456789", false},
{"empty string", testModelName, "", false},
{"short signature", testModelName, "abc", false},
{"gemini sentinel", "gemini-3-pro-preview", "skip_thought_signature_validator", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := HasValidSignature(tt.signature)
result := HasValidSignature(tt.modelName, tt.signature)
if result != tt.expected {
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
}
@@ -143,21 +142,19 @@ func TestHasValidSignature(t *testing.T) {
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
ClearSignatureCache("")
sessionID := "hash-test-session"
// Different texts should produce different hashes
text1 := "First thinking text"
text2 := "Second thinking text"
sig1 := "signature1_1234567890123456789012345678901234567890123456"
sig2 := "signature2_1234567890123456789012345678901234567890123456"
CacheSignature(sessionID, text1, sig1)
CacheSignature(sessionID, text2, sig2)
CacheSignature(testModelName, text1, sig1)
CacheSignature(testModelName, text2, sig2)
if GetCachedSignature(sessionID, text1) != sig1 {
if GetCachedSignature(testModelName, text1) != sig1 {
t.Error("text1 signature mismatch")
}
if GetCachedSignature(sessionID, text2) != sig2 {
if GetCachedSignature(testModelName, text2) != sig2 {
t.Error("text2 signature mismatch")
}
}
@@ -165,13 +162,12 @@ func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
func TestCacheSignature_UnicodeText(t *testing.T) {
ClearSignatureCache("")
sessionID := "unicode-session"
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
sig := "unicodeSig123456789012345678901234567890123456789012345"
CacheSignature(sessionID, text, sig)
CacheSignature(testModelName, text, sig)
if got := GetCachedSignature(sessionID, text); got != sig {
if got := GetCachedSignature(testModelName, text); got != sig {
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
}
}
@@ -179,15 +175,14 @@ func TestCacheSignature_UnicodeText(t *testing.T) {
func TestCacheSignature_Overwrite(t *testing.T) {
ClearSignatureCache("")
sessionID := "overwrite-session"
text := "Same text"
sig1 := "firstSignature12345678901234567890123456789012345678901"
sig2 := "secondSignature1234567890123456789012345678901234567890"
CacheSignature(sessionID, text, sig1)
CacheSignature(sessionID, text, sig2) // Overwrite
CacheSignature(testModelName, text, sig1)
CacheSignature(testModelName, text, sig2) // Overwrite
if got := GetCachedSignature(sessionID, text); got != sig2 {
if got := GetCachedSignature(testModelName, text); got != sig2 {
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
}
}
@@ -199,14 +194,13 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) {
// This test verifies the expiration check exists
// In a real scenario, we'd mock time.Now()
sessionID := "expiration-test"
text := "text"
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature(sessionID, text, sig)
CacheSignature(testModelName, text, sig)
// Fresh entry should be retrievable
if got := GetCachedSignature(sessionID, text); got != sig {
if got := GetCachedSignature(testModelName, text); got != sig {
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
}

View File

@@ -32,9 +32,10 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: promptFn,
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)

View File

@@ -22,9 +22,10 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: promptFn,
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)

View File

@@ -24,9 +24,10 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
}
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: promptFn,
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)

View File

@@ -67,10 +67,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
}
loginOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
ProjectID: trimmedProjectID,
Metadata: map[string]string{},
Prompt: callbackPrompt,
NoBrowser: options.NoBrowser,
ProjectID: trimmedProjectID,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: callbackPrompt,
}
authenticator := sdkAuth.NewGeminiAuthenticator()
@@ -88,8 +89,9 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
geminiAuth := gemini.NewGeminiAuth()
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
NoBrowser: options.NoBrowser,
Prompt: callbackPrompt,
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Prompt: callbackPrompt,
})
if errClient != nil {
log.Errorf("Gemini authentication failed: %v", errClient)
@@ -116,6 +118,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
}
activatedProjects := make([]string, 0, len(projectSelections))
seenProjects := make(map[string]bool)
for _, candidateID := range projectSelections {
log.Infof("Activating project %s", candidateID)
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
@@ -132,6 +135,13 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
if finalID == "" {
finalID = candidateID
}
// Skip duplicates
if seenProjects[finalID] {
log.Infof("Project %s already activated, skipping", finalID)
continue
}
seenProjects[finalID] = true
activatedProjects = append(activatedProjects, finalID)
}
@@ -259,7 +269,39 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
finalProjectID := projectID
if responseProjectID != "" {
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
strings.EqualFold(tierID, "FREE") ||
strings.EqualFold(tierID, "LEGACY")
if isFreeUser {
// Interactive prompt for free users
fmt.Printf("\nGoogle returned a different project ID:\n")
fmt.Printf(" Requested (frontend): %s\n", projectID)
fmt.Printf(" Returned (backend): %s\n\n", responseProjectID)
fmt.Printf(" Backend project IDs have access to preview models (gemini-3-*).\n")
fmt.Printf(" This is normal for free tier users.\n\n")
fmt.Printf("Which project ID would you like to use?\n")
fmt.Printf(" [1] Backend (recommended): %s\n", responseProjectID)
fmt.Printf(" [2] Frontend: %s\n\n", projectID)
fmt.Printf("Enter choice [1]: ")
reader := bufio.NewReader(os.Stdin)
choice, _ := reader.ReadString('\n')
choice = strings.TrimSpace(choice)
if choice == "2" {
log.Infof("Using frontend project ID: %s", projectID)
fmt.Println(". Warning: Frontend project IDs may not have access to preview models.")
finalProjectID = projectID
} else {
log.Infof("Using backend project ID: %s (recommended)", responseProjectID)
finalProjectID = responseProjectID
}
} else {
// Pro users: keep requested project ID (original behavior)
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
}
} else {
finalProjectID = responseProjectID
}

View File

@@ -19,6 +19,9 @@ type LoginOptions struct {
// NoBrowser indicates whether to skip opening the browser automatically.
NoBrowser bool
// CallbackPort overrides the local OAuth callback port when set (>0).
CallbackPort int
// Prompt allows the caller to provide interactive input when needed.
Prompt func(prompt string) (string, error)
}
@@ -43,9 +46,10 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: promptFn,
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)

View File

@@ -36,9 +36,10 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
}
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: promptFn,
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)

View File

@@ -6,12 +6,14 @@ package config
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"gopkg.in/yaml.v3"
)
@@ -39,6 +41,9 @@ type Config struct {
// Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug" json:"debug"`
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
// LoggingToFile controls whether application logs are written to rotating files or stdout.
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
@@ -60,9 +65,17 @@ type Config struct {
// QuotaExceeded defines the behavior when a quota is exceeded.
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
// Routing controls credential selection behavior.
Routing RoutingConfig `yaml:"routing" json:"routing"`
// WebsocketAuth enables or disables authentication for the WebSocket API.
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
// CodexInstructionsEnabled controls whether official Codex instructions are injected.
// When false (default), CodexInstructionsForModel returns immediately without modification.
// When true, the original instruction injection logic is used.
CodexInstructionsEnabled bool `yaml:"codex-instructions-enabled" json:"codex-instructions-enabled"`
// GeminiKey defines Gemini API key configurations with optional routing overrides.
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
@@ -90,8 +103,17 @@ type Config struct {
AmpCode AmpCode `yaml:"ampcode" json:"ampcode"`
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
// These aliases affect both model listing and model routing for supported channels:
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
//
// NOTE: This does not apply to existing per-credential model alias features under:
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
OAuthModelAlias map[string][]OAuthModelAlias `yaml:"oauth-model-alias,omitempty" json:"oauth-model-alias,omitempty"`
// Payload defines default and override rules for provider payload parameters.
Payload PayloadConfig `yaml:"payload" json:"payload"`
@@ -136,6 +158,23 @@ type QuotaExceeded struct {
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
}
// RoutingConfig configures how credentials are selected for requests.
type RoutingConfig struct {
// Strategy selects the credential selection strategy.
// Supported values: "round-robin" (default), "fill-first".
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
}
// OAuthModelAlias defines a model ID alias for a specific channel.
// It maps the upstream model name (Name) to the client-visible alias (Alias).
// When Fork is true, the alias is added as an additional model in listings while
// keeping the original model ID available.
type OAuthModelAlias struct {
Name string `yaml:"name" json:"name"`
Alias string `yaml:"alias" json:"alias"`
Fork bool `yaml:"fork,omitempty" json:"fork,omitempty"`
}
// AmpModelMapping defines a model name mapping for Amp CLI requests.
// When Amp requests a model that isn't available locally, this mapping
// allows routing to an alternative model that IS available.
@@ -146,6 +185,11 @@ type AmpModelMapping struct {
// To is the target model name to route to (e.g., "claude-sonnet-4").
// The target model must have available providers in the registry.
To string `yaml:"to" json:"to"`
// Regex indicates whether the 'from' field should be interpreted as a regular
// expression for matching model names. When true, this mapping is evaluated
// after exact matches and in the order provided. Defaults to false (exact match).
Regex bool `yaml:"regex,omitempty" json:"regex,omitempty"`
}
// AmpCode groups Amp CLI integration settings including upstream routing,
@@ -157,6 +201,11 @@ type AmpCode struct {
// UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls.
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
// When a client authenticates with a key that matches an entry, that upstream key is used.
// If no match is found, falls back to UpstreamAPIKey (default behavior).
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
@@ -172,12 +221,37 @@ type AmpCode struct {
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
}
// AmpUpstreamAPIKeyEntry maps a set of client API keys to a specific upstream API key.
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
// is used for the upstream Amp request.
type AmpUpstreamAPIKeyEntry struct {
// UpstreamAPIKey is the API key to use when proxying to the Amp upstream.
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// APIKeys are the client API keys (from top-level api-keys) that map to this upstream key.
APIKeys []string `yaml:"api-keys" json:"api-keys"`
}
// PayloadConfig defines default and override parameter rules applied to provider payloads.
type PayloadConfig struct {
// Default defines rules that only set parameters when they are missing in the payload.
Default []PayloadRule `yaml:"default" json:"default"`
// DefaultRaw defines rules that set raw JSON values only when they are missing.
DefaultRaw []PayloadRule `yaml:"default-raw" json:"default-raw"`
// Override defines rules that always set parameters, overwriting any existing values.
Override []PayloadRule `yaml:"override" json:"override"`
// OverrideRaw defines rules that always set raw JSON values, overwriting any existing values.
OverrideRaw []PayloadRule `yaml:"override-raw" json:"override-raw"`
// Filter defines rules that remove parameters from the payload by JSON path.
Filter []PayloadFilterRule `yaml:"filter" json:"filter"`
}
// PayloadFilterRule describes a rule to remove specific JSON paths from matching model payloads.
type PayloadFilterRule struct {
// Models lists model entries with name pattern and protocol constraint.
Models []PayloadModelRule `yaml:"models" json:"models"`
// Params lists JSON paths (gjson/sjson syntax) to remove from the payload.
Params []string `yaml:"params" json:"params"`
}
// PayloadRule describes a single rule targeting a list of models with parameter updates.
@@ -185,6 +259,7 @@ type PayloadRule struct {
// Models lists model entries with name pattern and protocol constraint.
Models []PayloadModelRule `yaml:"models" json:"models"`
// Params maps JSON paths (gjson/sjson syntax) to values written into the payload.
// For *-raw rules, values are treated as raw JSON fragments (strings are used as-is).
Params map[string]any `yaml:"params" json:"params"`
}
@@ -196,12 +271,35 @@ type PayloadModelRule struct {
Protocol string `yaml:"protocol" json:"protocol"`
}
// CloakConfig configures request cloaking for non-Claude-Code clients.
// Cloaking disguises API requests to appear as originating from the official Claude Code CLI.
type CloakConfig struct {
// Mode controls cloaking behavior: "auto" (default), "always", or "never".
// - "auto": cloak only when client is not Claude Code (based on User-Agent)
// - "always": always apply cloaking regardless of client
// - "never": never apply cloaking
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
// StrictMode controls how system prompts are handled when cloaking.
// - false (default): prepend Claude Code prompt to user system messages
// - true: strip all user system messages, keep only Claude Code prompt
StrictMode bool `yaml:"strict-mode,omitempty" json:"strict-mode,omitempty"`
// SensitiveWords is a list of words to obfuscate with zero-width characters.
// This can help bypass certain content filters.
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
}
// ClaudeKey represents the configuration for a Claude API key,
// including the API key itself and an optional base URL for the API endpoint.
type ClaudeKey struct {
// APIKey is the authentication key for accessing Claude API services.
APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -220,8 +318,14 @@ type ClaudeKey struct {
// ExcludedModels lists model IDs that should be excluded for this provider.
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
// Cloak configures request cloaking for non-Claude-Code clients.
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
}
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
func (k ClaudeKey) GetBaseURL() string { return k.BaseURL }
// ClaudeModel describes a mapping between an alias and the actual upstream model name.
type ClaudeModel struct {
// Name is the upstream model identifier used when issuing requests.
@@ -231,12 +335,19 @@ type ClaudeModel struct {
Alias string `yaml:"alias" json:"alias"`
}
func (m ClaudeModel) GetName() string { return m.Name }
func (m ClaudeModel) GetAlias() string { return m.Alias }
// CodexKey represents the configuration for a Codex API key,
// including the API key itself and an optional base URL for the API endpoint.
type CodexKey struct {
// APIKey is the authentication key for accessing Codex API services.
APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -247,6 +358,9 @@ type CodexKey struct {
// ProxyURL overrides the global proxy setting for this API key if provided.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
// Models defines upstream model names and aliases for request routing.
Models []CodexModel `yaml:"models" json:"models"`
// Headers optionally adds extra HTTP headers for requests sent with this key.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
@@ -254,12 +368,31 @@ type CodexKey struct {
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}
func (k CodexKey) GetAPIKey() string { return k.APIKey }
func (k CodexKey) GetBaseURL() string { return k.BaseURL }
// CodexModel describes a mapping between an alias and the actual upstream model name.
type CodexModel struct {
// Name is the upstream model identifier used when issuing requests.
Name string `yaml:"name" json:"name"`
// Alias is the client-facing model name that maps to Name.
Alias string `yaml:"alias" json:"alias"`
}
func (m CodexModel) GetName() string { return m.Name }
func (m CodexModel) GetAlias() string { return m.Alias }
// GeminiKey represents the configuration for a Gemini API key,
// including optional overrides for upstream base URL, proxy routing, and headers.
type GeminiKey struct {
// APIKey is the authentication key for accessing Gemini API services.
APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -269,6 +402,9 @@ type GeminiKey struct {
// ProxyURL optionally overrides the global proxy for this API key.
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
// Models defines upstream model names and aliases for request routing.
Models []GeminiModel `yaml:"models,omitempty" json:"models,omitempty"`
// Headers optionally adds extra HTTP headers for requests sent with this key.
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
@@ -276,6 +412,21 @@ type GeminiKey struct {
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}
func (k GeminiKey) GetAPIKey() string { return k.APIKey }
func (k GeminiKey) GetBaseURL() string { return k.BaseURL }
// GeminiModel describes a mapping between an alias and the actual upstream model name.
type GeminiModel struct {
// Name is the upstream model identifier used when issuing requests.
Name string `yaml:"name" json:"name"`
// Alias is the client-facing model name that maps to Name.
Alias string `yaml:"alias" json:"alias"`
}
func (m GeminiModel) GetName() string { return m.Name }
func (m GeminiModel) GetAlias() string { return m.Alias }
// KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication.
type KiroKey struct {
// TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json)
@@ -311,6 +462,10 @@ type OpenAICompatibility struct {
// Name is the identifier for this OpenAI compatibility configuration.
Name string `yaml:"name" json:"name"`
// Priority controls selection preference when multiple providers or credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -346,6 +501,9 @@ type OpenAICompatibilityModel struct {
Alias string `yaml:"alias" json:"alias"`
}
func (m OpenAICompatibilityModel) GetName() string { return m.Name }
func (m OpenAICompatibilityModel) GetAlias() string { return m.Alias }
// LoadConfig reads a YAML configuration file from the given path,
// unmarshals it into a Config struct, applies environment variable overrides,
// and returns it.
@@ -364,6 +522,15 @@ func LoadConfig(configFile string) (*Config, error) {
// If optional is true and the file is missing, it returns an empty Config.
// If optional is true and the file is empty or invalid, it returns an empty Config.
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Perform oauth-model-alias migration before loading config.
// This migrates oauth-model-mappings to oauth-model-alias if needed.
if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
// Log warning but don't fail - config loading should still work
fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
} else if migrated {
fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
}
// Read the entire configuration file into memory.
data, err := os.ReadFile(configFile)
if err != nil {
@@ -391,7 +558,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.DisableCooling = false
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
if err = yaml.Unmarshal(data, &cfg); err != nil {
if optional {
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
@@ -460,6 +627,12 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Normalize OAuth provider model exclusion map.
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
// Normalize global OAuth model name aliases.
cfg.SanitizeOAuthModelAlias()
// Validate raw payload rules and drop invalid entries.
cfg.SanitizePayloadRules()
if cfg.legacyMigrationPending {
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
if !optional && configFile != "" {
@@ -476,6 +649,99 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
return &cfg, nil
}
// SanitizePayloadRules validates raw JSON payload rule params and drops invalid rules.
func (cfg *Config) SanitizePayloadRules() {
if cfg == nil {
return
}
cfg.Payload.DefaultRaw = sanitizePayloadRawRules(cfg.Payload.DefaultRaw, "default-raw")
cfg.Payload.OverrideRaw = sanitizePayloadRawRules(cfg.Payload.OverrideRaw, "override-raw")
}
func sanitizePayloadRawRules(rules []PayloadRule, section string) []PayloadRule {
if len(rules) == 0 {
return rules
}
out := make([]PayloadRule, 0, len(rules))
for i := range rules {
rule := rules[i]
if len(rule.Params) == 0 {
continue
}
invalid := false
for path, value := range rule.Params {
raw, ok := payloadRawString(value)
if !ok {
continue
}
trimmed := bytes.TrimSpace(raw)
if len(trimmed) == 0 || !json.Valid(trimmed) {
log.WithFields(log.Fields{
"section": section,
"rule_index": i + 1,
"param": path,
}).Warn("payload rule dropped: invalid raw JSON")
invalid = true
break
}
}
if invalid {
continue
}
out = append(out, rule)
}
return out
}
func payloadRawString(value any) ([]byte, bool) {
switch typed := value.(type) {
case string:
return []byte(typed), true
case []byte:
return typed, true
default:
return nil, false
}
}
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
func (cfg *Config) SanitizeOAuthModelAlias() {
if cfg == nil || len(cfg.OAuthModelAlias) == 0 {
return
}
out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias))
for rawChannel, aliases := range cfg.OAuthModelAlias {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(aliases) == 0 {
continue
}
seenAlias := make(map[string]struct{}, len(aliases))
clean := make([]OAuthModelAlias, 0, len(aliases))
for _, entry := range aliases {
name := strings.TrimSpace(entry.Name)
alias := strings.TrimSpace(entry.Alias)
if name == "" || alias == "" {
continue
}
if strings.EqualFold(name, alias) {
continue
}
aliasKey := strings.ToLower(alias)
if _, ok := seenAlias[aliasKey]; ok {
continue
}
seenAlias[aliasKey] = struct{}{}
clean = append(clean, OAuthModelAlias{Name: name, Alias: alias, Fork: entry.Fork})
}
if len(clean) > 0 {
out[channel] = clean
}
}
cfg.OAuthModelAlias = out
}
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are
// not actionable, specifically those missing a BaseURL. It trims whitespace before
// evaluation and preserves the relative order of remaining entries.
@@ -730,6 +996,7 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error {
removeLegacyGenerativeLanguageKeys(original.Content[0])
pruneMappingToGeneratedKeys(original.Content[0], generated.Content[0], "oauth-excluded-models")
pruneMappingToGeneratedKeys(original.Content[0], generated.Content[0], "oauth-model-alias")
// Merge generated into original in-place, preserving comments/order of existing nodes.
mergeMappingPreserve(original.Content[0], generated.Content[0])
@@ -861,8 +1128,8 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
}
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
// key order and comments of existing keys in dst. Unknown keys from src are appended
// to dst at the end, copying their node structure from src.
// key order and comments of existing keys in dst. New keys are only added if their
// value is non-zero to avoid polluting the config with defaults.
func mergeMappingPreserve(dst, src *yaml.Node) {
if dst == nil || src == nil {
return
@@ -873,20 +1140,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
copyNodeShallow(dst, src)
return
}
// Build a lookup of existing keys in dst
for i := 0; i+1 < len(src.Content); i += 2 {
sk := src.Content[i]
sv := src.Content[i+1]
idx := findMapKeyIndex(dst, sk.Value)
if idx >= 0 {
// Merge into existing value node
// Merge into existing value node (always update, even to zero values)
dv := dst.Content[idx+1]
mergeNodePreserve(dv, sv)
} else {
if shouldSkipEmptyCollectionOnPersist(sk.Value, sv) {
// New key: only add if value is non-zero to avoid polluting config with defaults
if isZeroValueNode(sv) {
continue
}
// Append new key/value pair by deep-copying from src
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
}
}
@@ -969,32 +1235,49 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
return -1
}
func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool {
switch key {
case "generative-language-api-key",
"gemini-api-key",
"vertex-api-key",
"claude-api-key",
"codex-api-key",
"openai-compatibility":
return isEmptyCollectionNode(node)
default:
return false
}
}
func isEmptyCollectionNode(node *yaml.Node) bool {
// isZeroValueNode returns true if the YAML node represents a zero/default value
// that should not be written as a new key to preserve config cleanliness.
// For mappings and sequences, recursively checks if all children are zero values.
func isZeroValueNode(node *yaml.Node) bool {
if node == nil {
return true
}
switch node.Kind {
case yaml.SequenceNode:
return len(node.Content) == 0
case yaml.ScalarNode:
return node.Tag == "!!null"
default:
return false
switch node.Tag {
case "!!bool":
return node.Value == "false"
case "!!int", "!!float":
return node.Value == "0" || node.Value == "0.0"
case "!!str":
return node.Value == ""
case "!!null":
return true
}
case yaml.SequenceNode:
if len(node.Content) == 0 {
return true
}
// Check if all elements are zero values
for _, child := range node.Content {
if !isZeroValueNode(child) {
return false
}
}
return true
case yaml.MappingNode:
if len(node.Content) == 0 {
return true
}
// Check if all values are zero values (values are at odd indices)
for i := 1; i < len(node.Content); i += 2 {
if !isZeroValueNode(node.Content[i]) {
return false
}
}
return true
}
return false
}
// deepCopyNode creates a deep copy of a yaml.Node graph.
@@ -1204,6 +1487,16 @@ func pruneMappingToGeneratedKeys(dstRoot, srcRoot *yaml.Node, key string) {
}
srcIdx := findMapKeyIndex(srcRoot, key)
if srcIdx < 0 {
// Keep an explicit empty mapping for oauth-model-alias when it was previously present.
//
// Rationale: LoadConfig runs MigrateOAuthModelAlias before unmarshalling. If the
// oauth-model-alias key is missing, migration will add the default antigravity aliases.
// When users delete the last channel from oauth-model-alias via the management API,
// we want that deletion to persist across hot reloads and restarts.
if key == "oauth-model-alias" {
dstRoot.Content[dstIdx+1] = &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
return
}
removeMapKey(dstRoot, key)
return
}

View File

@@ -0,0 +1,275 @@
package config
import (
"os"
"strings"
"gopkg.in/yaml.v3"
)
// antigravityModelConversionTable maps old built-in aliases to actual model names
// for the antigravity channel during migration.
var antigravityModelConversionTable = map[string]string{
"gemini-2.5-computer-use-preview-10-2025": "rev19-uic3-1p",
"gemini-3-pro-image-preview": "gemini-3-pro-image",
"gemini-3-pro-preview": "gemini-3-pro-high",
"gemini-3-flash-preview": "gemini-3-flash",
"gemini-claude-sonnet-4-5": "claude-sonnet-4-5",
"gemini-claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
"gemini-claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
}
// defaultAntigravityAliases returns the default oauth-model-alias configuration
// for the antigravity channel when neither field exists.
func defaultAntigravityAliases() []OAuthModelAlias {
return []OAuthModelAlias{
{Name: "rev19-uic3-1p", Alias: "gemini-2.5-computer-use-preview-10-2025"},
{Name: "gemini-3-pro-image", Alias: "gemini-3-pro-image-preview"},
{Name: "gemini-3-pro-high", Alias: "gemini-3-pro-preview"},
{Name: "gemini-3-flash", Alias: "gemini-3-flash-preview"},
{Name: "claude-sonnet-4-5", Alias: "gemini-claude-sonnet-4-5"},
{Name: "claude-sonnet-4-5-thinking", Alias: "gemini-claude-sonnet-4-5-thinking"},
{Name: "claude-opus-4-5-thinking", Alias: "gemini-claude-opus-4-5-thinking"},
}
}
// MigrateOAuthModelAlias checks for and performs migration from oauth-model-mappings
// to oauth-model-alias at startup. Returns true if migration was performed.
//
// Migration flow:
// 1. Check if oauth-model-alias exists -> skip migration
// 2. Check if oauth-model-mappings exists -> convert and migrate
// - For antigravity channel, convert old built-in aliases to actual model names
//
// 3. Neither exists -> add default antigravity config
func MigrateOAuthModelAlias(configFile string) (bool, error) {
data, err := os.ReadFile(configFile)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
if len(data) == 0 {
return false, nil
}
// Parse YAML into node tree to preserve structure
var root yaml.Node
if err := yaml.Unmarshal(data, &root); err != nil {
return false, nil
}
if root.Kind != yaml.DocumentNode || len(root.Content) == 0 {
return false, nil
}
rootMap := root.Content[0]
if rootMap == nil || rootMap.Kind != yaml.MappingNode {
return false, nil
}
// Check if oauth-model-alias already exists
if findMapKeyIndex(rootMap, "oauth-model-alias") >= 0 {
return false, nil
}
// Check if oauth-model-mappings exists
oldIdx := findMapKeyIndex(rootMap, "oauth-model-mappings")
if oldIdx >= 0 {
// Migrate from old field
return migrateFromOldField(configFile, &root, rootMap, oldIdx)
}
// Neither field exists - add default antigravity config
return addDefaultAntigravityConfig(configFile, &root, rootMap)
}
// migrateFromOldField converts oauth-model-mappings to oauth-model-alias
func migrateFromOldField(configFile string, root *yaml.Node, rootMap *yaml.Node, oldIdx int) (bool, error) {
if oldIdx+1 >= len(rootMap.Content) {
return false, nil
}
oldValue := rootMap.Content[oldIdx+1]
if oldValue == nil || oldValue.Kind != yaml.MappingNode {
return false, nil
}
// Parse the old aliases
oldAliases := parseOldAliasNode(oldValue)
if len(oldAliases) == 0 {
// Remove the old field and write
removeMapKeyByIndex(rootMap, oldIdx)
return writeYAMLNode(configFile, root)
}
// Convert model names for antigravity channel
newAliases := make(map[string][]OAuthModelAlias, len(oldAliases))
for channel, entries := range oldAliases {
converted := make([]OAuthModelAlias, 0, len(entries))
for _, entry := range entries {
newEntry := OAuthModelAlias{
Name: entry.Name,
Alias: entry.Alias,
Fork: entry.Fork,
}
// Convert model names for antigravity channel
if strings.EqualFold(channel, "antigravity") {
if actual, ok := antigravityModelConversionTable[entry.Name]; ok {
newEntry.Name = actual
}
}
converted = append(converted, newEntry)
}
newAliases[channel] = converted
}
// For antigravity channel, supplement missing default aliases
if antigravityEntries, exists := newAliases["antigravity"]; exists {
// Build a set of already configured model names (upstream names)
configuredModels := make(map[string]bool, len(antigravityEntries))
for _, entry := range antigravityEntries {
configuredModels[entry.Name] = true
}
// Add missing default aliases
for _, defaultAlias := range defaultAntigravityAliases() {
if !configuredModels[defaultAlias.Name] {
antigravityEntries = append(antigravityEntries, defaultAlias)
}
}
newAliases["antigravity"] = antigravityEntries
}
// Build new node
newNode := buildOAuthModelAliasNode(newAliases)
// Replace old key with new key and value
rootMap.Content[oldIdx].Value = "oauth-model-alias"
rootMap.Content[oldIdx+1] = newNode
return writeYAMLNode(configFile, root)
}
// addDefaultAntigravityConfig adds the default antigravity configuration
func addDefaultAntigravityConfig(configFile string, root *yaml.Node, rootMap *yaml.Node) (bool, error) {
defaults := map[string][]OAuthModelAlias{
"antigravity": defaultAntigravityAliases(),
}
newNode := buildOAuthModelAliasNode(defaults)
// Add new key-value pair
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "oauth-model-alias"}
rootMap.Content = append(rootMap.Content, keyNode, newNode)
return writeYAMLNode(configFile, root)
}
// parseOldAliasNode parses the old oauth-model-mappings node structure
func parseOldAliasNode(node *yaml.Node) map[string][]OAuthModelAlias {
if node == nil || node.Kind != yaml.MappingNode {
return nil
}
result := make(map[string][]OAuthModelAlias)
for i := 0; i+1 < len(node.Content); i += 2 {
channelNode := node.Content[i]
entriesNode := node.Content[i+1]
if channelNode == nil || entriesNode == nil {
continue
}
channel := strings.ToLower(strings.TrimSpace(channelNode.Value))
if channel == "" || entriesNode.Kind != yaml.SequenceNode {
continue
}
entries := make([]OAuthModelAlias, 0, len(entriesNode.Content))
for _, entryNode := range entriesNode.Content {
if entryNode == nil || entryNode.Kind != yaml.MappingNode {
continue
}
entry := parseAliasEntry(entryNode)
if entry.Name != "" && entry.Alias != "" {
entries = append(entries, entry)
}
}
if len(entries) > 0 {
result[channel] = entries
}
}
return result
}
// parseAliasEntry parses a single alias entry node
func parseAliasEntry(node *yaml.Node) OAuthModelAlias {
var entry OAuthModelAlias
for i := 0; i+1 < len(node.Content); i += 2 {
keyNode := node.Content[i]
valNode := node.Content[i+1]
if keyNode == nil || valNode == nil {
continue
}
switch strings.ToLower(strings.TrimSpace(keyNode.Value)) {
case "name":
entry.Name = strings.TrimSpace(valNode.Value)
case "alias":
entry.Alias = strings.TrimSpace(valNode.Value)
case "fork":
entry.Fork = strings.ToLower(strings.TrimSpace(valNode.Value)) == "true"
}
}
return entry
}
// buildOAuthModelAliasNode creates a YAML node for oauth-model-alias
func buildOAuthModelAliasNode(aliases map[string][]OAuthModelAlias) *yaml.Node {
node := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
for channel, entries := range aliases {
channelNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: channel}
entriesNode := &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"}
for _, entry := range entries {
entryNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
entryNode.Content = append(entryNode.Content,
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "name"},
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Name},
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "alias"},
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Alias},
)
if entry.Fork {
entryNode.Content = append(entryNode.Content,
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "fork"},
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: "true"},
)
}
entriesNode.Content = append(entriesNode.Content, entryNode)
}
node.Content = append(node.Content, channelNode, entriesNode)
}
return node
}
// removeMapKeyByIndex removes a key-value pair from a mapping node by index
func removeMapKeyByIndex(mapNode *yaml.Node, keyIdx int) {
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
return
}
if keyIdx < 0 || keyIdx+1 >= len(mapNode.Content) {
return
}
mapNode.Content = append(mapNode.Content[:keyIdx], mapNode.Content[keyIdx+2:]...)
}
// writeYAMLNode writes the YAML node tree back to file
func writeYAMLNode(configFile string, root *yaml.Node) (bool, error) {
f, err := os.Create(configFile)
if err != nil {
return false, err
}
defer f.Close()
enc := yaml.NewEncoder(f)
enc.SetIndent(2)
if err := enc.Encode(root); err != nil {
return false, err
}
if err := enc.Close(); err != nil {
return false, err
}
return true, nil
}

View File

@@ -0,0 +1,242 @@
package config
import (
"os"
"path/filepath"
"strings"
"testing"
"gopkg.in/yaml.v3"
)
func TestMigrateOAuthModelAlias_SkipsIfNewFieldExists(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
content := `oauth-model-alias:
gemini-cli:
- name: "gemini-2.5-pro"
alias: "g2.5p"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if migrated {
t.Fatal("expected no migration when oauth-model-alias already exists")
}
// Verify file unchanged
data, _ := os.ReadFile(configFile)
if !strings.Contains(string(data), "oauth-model-alias:") {
t.Fatal("file should still contain oauth-model-alias")
}
}
func TestMigrateOAuthModelAlias_MigratesOldField(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
content := `oauth-model-mappings:
gemini-cli:
- name: "gemini-2.5-pro"
alias: "g2.5p"
fork: true
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !migrated {
t.Fatal("expected migration to occur")
}
// Verify new field exists and old field removed
data, _ := os.ReadFile(configFile)
if strings.Contains(string(data), "oauth-model-mappings:") {
t.Fatal("old field should be removed")
}
if !strings.Contains(string(data), "oauth-model-alias:") {
t.Fatal("new field should exist")
}
// Parse and verify structure
var root yaml.Node
if err := yaml.Unmarshal(data, &root); err != nil {
t.Fatal(err)
}
}
func TestMigrateOAuthModelAlias_ConvertsAntigravityModels(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
// Use old model names that should be converted
content := `oauth-model-mappings:
antigravity:
- name: "gemini-2.5-computer-use-preview-10-2025"
alias: "computer-use"
- name: "gemini-3-pro-preview"
alias: "g3p"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !migrated {
t.Fatal("expected migration to occur")
}
// Verify model names were converted
data, _ := os.ReadFile(configFile)
content = string(data)
if !strings.Contains(content, "rev19-uic3-1p") {
t.Fatal("expected gemini-2.5-computer-use-preview-10-2025 to be converted to rev19-uic3-1p")
}
if !strings.Contains(content, "gemini-3-pro-high") {
t.Fatal("expected gemini-3-pro-preview to be converted to gemini-3-pro-high")
}
// Verify missing default aliases were supplemented
if !strings.Contains(content, "gemini-3-pro-image") {
t.Fatal("expected missing default alias gemini-3-pro-image to be added")
}
if !strings.Contains(content, "gemini-3-flash") {
t.Fatal("expected missing default alias gemini-3-flash to be added")
}
if !strings.Contains(content, "claude-sonnet-4-5") {
t.Fatal("expected missing default alias claude-sonnet-4-5 to be added")
}
if !strings.Contains(content, "claude-sonnet-4-5-thinking") {
t.Fatal("expected missing default alias claude-sonnet-4-5-thinking to be added")
}
if !strings.Contains(content, "claude-opus-4-5-thinking") {
t.Fatal("expected missing default alias claude-opus-4-5-thinking to be added")
}
}
func TestMigrateOAuthModelAlias_AddsDefaultIfNeitherExists(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
content := `debug: true
port: 8080
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !migrated {
t.Fatal("expected migration to add default config")
}
// Verify default antigravity config was added
data, _ := os.ReadFile(configFile)
content = string(data)
if !strings.Contains(content, "oauth-model-alias:") {
t.Fatal("expected oauth-model-alias to be added")
}
if !strings.Contains(content, "antigravity:") {
t.Fatal("expected antigravity channel to be added")
}
if !strings.Contains(content, "rev19-uic3-1p") {
t.Fatal("expected default antigravity aliases to include rev19-uic3-1p")
}
}
func TestMigrateOAuthModelAlias_PreservesOtherConfig(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
content := `debug: true
port: 8080
oauth-model-mappings:
gemini-cli:
- name: "test"
alias: "t"
api-keys:
- "key1"
- "key2"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !migrated {
t.Fatal("expected migration to occur")
}
// Verify other config preserved
data, _ := os.ReadFile(configFile)
content = string(data)
if !strings.Contains(content, "debug: true") {
t.Fatal("expected debug field to be preserved")
}
if !strings.Contains(content, "port: 8080") {
t.Fatal("expected port field to be preserved")
}
if !strings.Contains(content, "api-keys:") {
t.Fatal("expected api-keys field to be preserved")
}
}
func TestMigrateOAuthModelAlias_NonexistentFile(t *testing.T) {
t.Parallel()
migrated, err := MigrateOAuthModelAlias("/nonexistent/path/config.yaml")
if err != nil {
t.Fatalf("unexpected error for nonexistent file: %v", err)
}
if migrated {
t.Fatal("expected no migration for nonexistent file")
}
}
func TestMigrateOAuthModelAlias_EmptyFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
configFile := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(configFile, []byte(""), 0644); err != nil {
t.Fatal(err)
}
migrated, err := MigrateOAuthModelAlias(configFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if migrated {
t.Fatal("expected no migration for empty file")
}
}

View File

@@ -0,0 +1,56 @@
package config
import "testing"
func TestSanitizeOAuthModelAlias_PreservesForkFlag(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
" CoDeX ": {
{Name: " gpt-5 ", Alias: " g5 ", Fork: true},
{Name: "gpt-6", Alias: "g6"},
},
},
}
cfg.SanitizeOAuthModelAlias()
aliases := cfg.OAuthModelAlias["codex"]
if len(aliases) != 2 {
t.Fatalf("expected 2 sanitized aliases, got %d", len(aliases))
}
if aliases[0].Name != "gpt-5" || aliases[0].Alias != "g5" || !aliases[0].Fork {
t.Fatalf("expected first alias to be gpt-5->g5 fork=true, got name=%q alias=%q fork=%v", aliases[0].Name, aliases[0].Alias, aliases[0].Fork)
}
if aliases[1].Name != "gpt-6" || aliases[1].Alias != "g6" || aliases[1].Fork {
t.Fatalf("expected second alias to be gpt-6->g6 fork=false, got name=%q alias=%q fork=%v", aliases[1].Name, aliases[1].Alias, aliases[1].Fork)
}
}
func TestSanitizeOAuthModelAlias_AllowsMultipleAliasesForSameName(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"antigravity": {
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
},
},
}
cfg.SanitizeOAuthModelAlias()
aliases := cfg.OAuthModelAlias["antigravity"]
expected := []OAuthModelAlias{
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
}
if len(aliases) != len(expected) {
t.Fatalf("expected %d sanitized aliases, got %d", len(expected), len(aliases))
}
for i, exp := range expected {
if aliases[i].Name != exp.Name || aliases[i].Alias != exp.Alias || aliases[i].Fork != exp.Fork {
t.Fatalf("expected alias %d to be name=%q alias=%q fork=%v, got name=%q alias=%q fork=%v", i, exp.Name, exp.Alias, exp.Fork, aliases[i].Name, aliases[i].Alias, aliases[i].Fork)
}
}
}

View File

@@ -22,6 +22,25 @@ type SDKConfig struct {
// Access holds request authentication provider configuration.
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
// NonStreamKeepAliveInterval controls how often blank lines are emitted for non-streaming responses.
// <= 0 disables keep-alives. Value is in seconds.
NonStreamKeepAliveInterval int `yaml:"nonstream-keepalive-interval,omitempty" json:"nonstream-keepalive-interval,omitempty"`
}
// StreamingConfig holds server streaming behavior configuration.
type StreamingConfig struct {
// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
// <= 0 disables keep-alives. Default is 0.
KeepAliveSeconds int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
// to allow auth rotation / transient recovery.
// <= 0 disables bootstrap retries. Default is 0.
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
}
// AccessConfig groups request authentication providers.

View File

@@ -13,6 +13,10 @@ type VertexCompatKey struct {
// Maps to the x-goog-api-key header.
APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -32,6 +36,9 @@ type VertexCompatKey struct {
Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"`
}
func (k VertexCompatKey) GetAPIKey() string { return k.APIKey }
func (k VertexCompatKey) GetBaseURL() string { return k.BaseURL }
// VertexCompatModel represents a model configuration for Vertex compatibility,
// including the actual model name and its alias for API routing.
type VertexCompatModel struct {
@@ -42,6 +49,9 @@ type VertexCompatModel struct {
Alias string `yaml:"alias" json:"alias"`
}
func (m VertexCompatModel) GetName() string { return m.Name }
func (m VertexCompatModel) GetAlias() string { return m.Alias }
// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials.
func (cfg *Config) SanitizeVertexCompatKeys() {
if cfg == nil {

View File

@@ -4,9 +4,11 @@
package logging
import (
"errors"
"fmt"
"net/http"
"runtime/debug"
"strings"
"time"
"github.com/gin-gonic/gin"
@@ -14,11 +16,24 @@ import (
log "github.com/sirupsen/logrus"
)
// aiAPIPrefixes defines path prefixes for AI API requests that should have request ID tracking.
var aiAPIPrefixes = []string{
"/v1/chat/completions",
"/v1/completions",
"/v1/messages",
"/v1/responses",
"/v1beta/models/",
"/api/provider/",
}
const skipGinLogKey = "__gin_skip_request_logging__"
// GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses
// using logrus. It captures request details including method, path, status code, latency,
// client IP, and any error messages, formatting them in a Gin-style log format.
// client IP, and any error messages. Request ID is only added for AI API requests.
//
// Output format (AI API): [2025-12-23 20:14:10] [info ] | a1b2c3d4 | 200 | 23.559s | ...
// Output format (others): [2025-12-23 20:14:10] [info ] | -------- | 200 | 23.559s | ...
//
// Returns:
// - gin.HandlerFunc: A middleware handler for request logging
@@ -28,6 +43,15 @@ func GinLogrusLogger() gin.HandlerFunc {
path := c.Request.URL.Path
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
// Only generate request ID for AI API paths
var requestID string
if isAIAPIPath(path) {
requestID = GenerateRequestID()
SetGinRequestID(c, requestID)
ctx := WithRequestID(c.Request.Context(), requestID)
c.Request = c.Request.WithContext(ctx)
}
c.Next()
if shouldSkipGinRequestLogging(c) {
@@ -49,23 +73,38 @@ func GinLogrusLogger() gin.HandlerFunc {
clientIP := c.ClientIP()
method := c.Request.Method
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
timestamp := time.Now().Format("2006/01/02 - 15:04:05")
logLine := fmt.Sprintf("[GIN] %s | %3d | %13v | %15s | %-7s \"%s\"", timestamp, statusCode, latency, clientIP, method, path)
if requestID == "" {
requestID = "--------"
}
logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path)
if errorMessage != "" {
logLine = logLine + " | " + errorMessage
}
entry := log.WithField("request_id", requestID)
switch {
case statusCode >= http.StatusInternalServerError:
log.Error(logLine)
entry.Error(logLine)
case statusCode >= http.StatusBadRequest:
log.Warn(logLine)
entry.Warn(logLine)
default:
log.Info(logLine)
entry.Info(logLine)
}
}
}
// isAIAPIPath checks if the given path is an AI API endpoint that should have request ID tracking.
func isAIAPIPath(path string) bool {
for _, prefix := range aiAPIPrefixes {
if strings.HasPrefix(path, prefix) {
return true
}
}
return false
}
// GinLogrusRecovery returns a Gin middleware handler that recovers from panics and logs
// them using logrus. When a panic occurs, it captures the panic value, stack trace,
// and request path, then returns a 500 Internal Server Error response to the client.
@@ -74,6 +113,11 @@ func GinLogrusLogger() gin.HandlerFunc {
// - gin.HandlerFunc: A middleware handler for panic recovery
func GinLogrusRecovery() gin.HandlerFunc {
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
if err, ok := recovered.(error); ok && errors.Is(err, http.ErrAbortHandler) {
// Let net/http handle ErrAbortHandler so the connection is aborted without noisy stack logs.
panic(http.ErrAbortHandler)
}
log.WithFields(log.Fields{
"panic": recovered,
"stack": string(debug.Stack()),

View File

@@ -0,0 +1,60 @@
package logging
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestGinLogrusRecoveryRepanicsErrAbortHandler(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
engine.Use(GinLogrusRecovery())
engine.GET("/abort", func(c *gin.Context) {
panic(http.ErrAbortHandler)
})
req := httptest.NewRequest(http.MethodGet, "/abort", nil)
recorder := httptest.NewRecorder()
defer func() {
recovered := recover()
if recovered == nil {
t.Fatalf("expected panic, got nil")
}
err, ok := recovered.(error)
if !ok {
t.Fatalf("expected error panic, got %T", recovered)
}
if !errors.Is(err, http.ErrAbortHandler) {
t.Fatalf("expected ErrAbortHandler, got %v", err)
}
if err != http.ErrAbortHandler {
t.Fatalf("expected exact ErrAbortHandler sentinel, got %v", err)
}
}()
engine.ServeHTTP(recorder, req)
}
func TestGinLogrusRecoveryHandlesRegularPanic(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
engine.Use(GinLogrusRecovery())
engine.GET("/panic", func(c *gin.Context) {
panic("boom")
})
req := httptest.NewRequest(http.MethodGet, "/panic", nil)
recorder := httptest.NewRecorder()
engine.ServeHTTP(recorder, req)
if recorder.Code != http.StatusInternalServerError {
t.Fatalf("expected 500, got %d", recorder.Code)
}
}

View File

@@ -10,6 +10,7 @@ import (
"sync"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
@@ -24,9 +25,13 @@ var (
)
// LogFormatter defines a custom log format for logrus.
// This formatter adds timestamp, level, and source location to each log entry.
// This formatter adds timestamp, level, request ID, and source location to each log entry.
// Format: [2025-12-23 20:14:04] [debug] [manager.go:524] | a1b2c3d4 | Use API key sk-9...0RHO for model gpt-5.2
type LogFormatter struct{}
// logFieldOrder defines the display order for common log fields.
var logFieldOrder = []string{"provider", "model", "mode", "budget", "level", "original_mode", "original_value", "min", "max", "clamped_to", "error"}
// Format renders a single log entry with custom formatting.
func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
var buffer *bytes.Buffer
@@ -38,16 +43,38 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
timestamp := entry.Time.Format("2006-01-02 15:04:05")
message := strings.TrimRight(entry.Message, "\r\n")
// Handle nil Caller (can happen with some log entries)
callerFile := "unknown"
callerLine := 0
if entry.Caller != nil {
callerFile = filepath.Base(entry.Caller.File)
callerLine = entry.Caller.Line
reqID := "--------"
if id, ok := entry.Data["request_id"].(string); ok && id != "" {
reqID = id
}
level := entry.Level.String()
if level == "warning" {
level = "warn"
}
levelStr := fmt.Sprintf("%-5s", level)
// Build fields string (only print fields in logFieldOrder)
var fieldsStr string
if len(entry.Data) > 0 {
var fields []string
for _, k := range logFieldOrder {
if v, ok := entry.Data[k]; ok {
fields = append(fields, fmt.Sprintf("%s=%v", k, v))
}
}
if len(fields) > 0 {
fieldsStr = " " + strings.Join(fields, " ")
}
}
var formatted string
if entry.Caller != nil {
formatted = fmt.Sprintf("[%s] [%s] [%s] [%s:%d] %s%s\n", timestamp, reqID, levelStr, filepath.Base(entry.Caller.File), entry.Caller.Line, message, fieldsStr)
} else {
formatted = fmt.Sprintf("[%s] [%s] [%s] %s%s\n", timestamp, reqID, levelStr, message, fieldsStr)
}
formatted := fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, callerFile, callerLine, message)
buffer.WriteString(formatted)
return buffer.Bytes(), nil
@@ -75,22 +102,57 @@ func SetupBaseLogger() {
})
}
// isDirWritable checks if the specified directory exists and is writable by attempting to create and remove a test file.
func isDirWritable(dir string) bool {
info, err := os.Stat(dir)
if err != nil || !info.IsDir() {
return false
}
testFile := filepath.Join(dir, ".perm_test")
f, err := os.Create(testFile)
if err != nil {
return false
}
defer func() {
_ = f.Close()
_ = os.Remove(testFile)
}()
return true
}
// ResolveLogDirectory determines the directory used for application logs.
func ResolveLogDirectory(cfg *config.Config) string {
logDir := "logs"
if base := util.WritablePath(); base != "" {
return filepath.Join(base, "logs")
}
if cfg == nil {
return logDir
}
if !isDirWritable(logDir) {
authDir := strings.TrimSpace(cfg.AuthDir)
if authDir != "" {
logDir = filepath.Join(authDir, "logs")
}
}
return logDir
}
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
// until the total size is within the limit.
func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
func ConfigureLogOutput(cfg *config.Config) error {
SetupBaseLogger()
writerMu.Lock()
defer writerMu.Unlock()
logDir := "logs"
if base := util.WritablePath(); base != "" {
logDir = filepath.Join(base, "logs")
}
logDir := ResolveLogDirectory(cfg)
protectedPath := ""
if loggingToFile {
if cfg.LoggingToFile {
if err := os.MkdirAll(logDir, 0o755); err != nil {
return fmt.Errorf("logging: failed to create log directory: %w", err)
}
@@ -114,7 +176,7 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
log.SetOutput(os.Stdout)
}
configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath)
configureLogDirCleanerLocked(logDir, cfg.LogsMaxTotalSizeMB, protectedPath)
return nil
}

View File

@@ -43,10 +43,13 @@ type RequestLogger interface {
// - response: The raw response data
// - apiRequest: The API request data
// - apiResponse: The API response data
// - requestID: Optional request ID for log file naming
// - requestTimestamp: When the request was received
// - apiResponseTimestamp: When the API response was received
//
// Returns:
// - error: An error if logging fails, nil otherwise
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
//
@@ -55,11 +58,12 @@ type RequestLogger interface {
// - method: The HTTP method
// - headers: The request headers
// - body: The request body
// - requestID: Optional request ID for log file naming
//
// Returns:
// - StreamingLogWriter: A writer for streaming response chunks
// - error: An error if logging initialization fails, nil otherwise
LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error)
LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error)
// IsEnabled returns whether request logging is currently enabled.
//
@@ -107,6 +111,12 @@ type StreamingLogWriter interface {
// - error: An error if writing fails, nil otherwise
WriteAPIResponse(apiResponse []byte) error
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
//
// Parameters:
// - timestamp: The time when first response chunk was received
SetFirstChunkTimestamp(timestamp time.Time)
// Close finalizes the log file and cleans up resources.
//
// Returns:
@@ -177,20 +187,23 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
// - response: The raw response data
// - apiRequest: The API request data
// - apiResponse: The API response data
// - requestID: Optional request ID for log file naming
// - requestTimestamp: When the request was received
// - apiResponseTimestamp: When the API response was received
//
// Returns:
// - error: An error if logging fails, nil otherwise
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false)
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
}
// LogRequestWithOptions logs a request with optional forced logging behavior.
// The force flag allows writing error logs even when regular request logging is disabled.
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force)
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
}
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
if !l.enabled && !force {
return nil
}
@@ -200,10 +213,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
return fmt.Errorf("failed to create logs directory: %w", errEnsure)
}
// Generate filename
filename := l.generateFilename(url)
// Generate filename with request ID
filename := l.generateFilename(url, requestID)
if force && !l.enabled {
filename = l.generateErrorFilename(url)
filename = l.generateErrorFilename(url, requestID)
}
filePath := filepath.Join(l.logsDir, filename)
@@ -244,6 +257,8 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
responseHeaders,
responseToWrite,
decompressErr,
requestTimestamp,
apiResponseTimestamp,
)
if errClose := logFile.Close(); errClose != nil {
log.WithError(errClose).Warn("failed to close request log file")
@@ -271,11 +286,12 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
// - method: The HTTP method
// - headers: The request headers
// - body: The request body
// - requestID: Optional request ID for log file naming
//
// Returns:
// - StreamingLogWriter: A writer for streaming response chunks
// - error: An error if logging initialization fails, nil otherwise
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error) {
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error) {
if !l.enabled {
return &NoOpStreamingLogWriter{}, nil
}
@@ -285,8 +301,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
return nil, fmt.Errorf("failed to create logs directory: %w", err)
}
// Generate filename
filename := l.generateFilename(url)
// Generate filename with request ID
filename := l.generateFilename(url, requestID)
filePath := filepath.Join(l.logsDir, filename)
requestHeaders := make(map[string][]string, len(headers))
@@ -330,8 +346,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
}
// generateErrorFilename creates a filename with an error prefix to differentiate forced error logs.
func (l *FileRequestLogger) generateErrorFilename(url string) string {
return fmt.Sprintf("error-%s", l.generateFilename(url))
func (l *FileRequestLogger) generateErrorFilename(url string, requestID ...string) string {
return fmt.Sprintf("error-%s", l.generateFilename(url, requestID...))
}
// ensureLogsDir creates the logs directory if it doesn't exist.
@@ -346,13 +362,15 @@ func (l *FileRequestLogger) ensureLogsDir() error {
}
// generateFilename creates a sanitized filename from the URL path and current timestamp.
// Format: v1-responses-2025-12-23T195811-a1b2c3d4.log
//
// Parameters:
// - url: The request URL
// - requestID: Optional request ID to include in filename
//
// Returns:
// - string: A sanitized filename for the log file
func (l *FileRequestLogger) generateFilename(url string) string {
func (l *FileRequestLogger) generateFilename(url string, requestID ...string) string {
// Extract path from URL
path := url
if strings.Contains(url, "?") {
@@ -368,12 +386,18 @@ func (l *FileRequestLogger) generateFilename(url string) string {
sanitized := l.sanitizeForFilename(path)
// Add timestamp
timestamp := time.Now().Format("2006-01-02T150405-.000000000")
timestamp = strings.Replace(timestamp, ".", "", -1)
timestamp := time.Now().Format("2006-01-02T150405")
id := requestLogID.Add(1)
// Use request ID if provided, otherwise use sequential ID
var idPart string
if len(requestID) > 0 && requestID[0] != "" {
idPart = requestID[0]
} else {
id := requestLogID.Add(1)
idPart = fmt.Sprintf("%d", id)
}
return fmt.Sprintf("%s-%s-%d.log", sanitized, timestamp, id)
return fmt.Sprintf("%s-%s-%s.log", sanitized, timestamp, idPart)
}
// sanitizeForFilename replaces characters that are not safe for filenames.
@@ -487,17 +511,22 @@ func (l *FileRequestLogger) writeNonStreamingLog(
responseHeaders map[string][]string,
response []byte,
decompressErr error,
requestTimestamp time.Time,
apiResponseTimestamp time.Time,
) error {
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, time.Now()); errWrite != nil {
if requestTimestamp.IsZero() {
requestTimestamp = time.Now()
}
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest); errWrite != nil {
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
return errWrite
}
if errWrite := writeAPIErrorResponses(w, apiResponseErrors); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse); errWrite != nil {
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
return errWrite
}
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
@@ -571,7 +600,7 @@ func writeRequestInfoWithBody(
return nil
}
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte) error {
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
if len(payload) == 0 {
return nil
}
@@ -589,6 +618,11 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
return errWrite
}
if !timestamp.IsZero() {
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
return errWrite
}
}
if _, errWrite := w.Write(payload); errWrite != nil {
return errWrite
}
@@ -962,6 +996,9 @@ type FileStreamingLogWriter struct {
// apiResponse stores the upstream API response data.
apiResponse []byte
// apiResponseTimestamp captures when the API response was received.
apiResponseTimestamp time.Time
}
// WriteChunkAsync writes a response chunk asynchronously (non-blocking).
@@ -1041,6 +1078,12 @@ func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
return nil
}
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
if !timestamp.IsZero() {
w.apiResponseTimestamp = timestamp
}
}
// Close finalizes the log file and cleans up resources.
// It writes all buffered data to the file in the correct order:
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
@@ -1128,10 +1171,10 @@ func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest); errWrite != nil {
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(logFile, "=== API RESPONSE ===\n", "=== API RESPONSE", w.apiResponse); errWrite != nil {
if errWrite := writeAPISection(logFile, "=== API RESPONSE ===\n", "=== API RESPONSE", w.apiResponse, w.apiResponseTimestamp); errWrite != nil {
return errWrite
}
@@ -1208,6 +1251,8 @@ func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
return nil
}
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
// Close is a no-op implementation that does nothing and always returns nil.
//
// Returns:

View File

@@ -0,0 +1,61 @@
package logging
import (
"context"
"crypto/rand"
"encoding/hex"
"github.com/gin-gonic/gin"
)
// requestIDKey is the context key for storing/retrieving request IDs.
type requestIDKey struct{}
// ginRequestIDKey is the Gin context key for request IDs.
const ginRequestIDKey = "__request_id__"
// GenerateRequestID creates a new 8-character hex request ID.
func GenerateRequestID() string {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
return "00000000"
}
return hex.EncodeToString(b)
}
// WithRequestID returns a new context with the request ID attached.
func WithRequestID(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDKey{}, requestID)
}
// GetRequestID retrieves the request ID from the context.
// Returns empty string if not found.
func GetRequestID(ctx context.Context) string {
if ctx == nil {
return ""
}
if id, ok := ctx.Value(requestIDKey{}).(string); ok {
return id
}
return ""
}
// SetGinRequestID stores the request ID in the Gin context.
func SetGinRequestID(c *gin.Context, requestID string) {
if c != nil {
c.Set(ginRequestIDKey, requestID)
}
}
// GetGinRequestID retrieves the request ID from the Gin context.
func GetGinRequestID(c *gin.Context) string {
if c == nil {
return ""
}
if id, exists := c.Get(ginRequestIDKey); exists {
if s, ok := id.(string); ok {
return s
}
}
return ""
}

View File

@@ -24,10 +24,11 @@ import (
)
const (
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
managementAssetName = "management.html"
httpUserAgent = "CLIProxyAPI-management-updater"
updateCheckInterval = 3 * time.Hour
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
managementAssetName = "management.html"
httpUserAgent = "CLIProxyAPI-management-updater"
updateCheckInterval = 3 * time.Hour
)
// ManagementFileName exposes the control panel asset filename.
@@ -198,6 +199,16 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
return
}
localPath := filepath.Join(staticDir, managementAssetName)
localFileMissing := false
if _, errStat := os.Stat(localPath); errStat != nil {
if errors.Is(errStat, os.ErrNotExist) {
localFileMissing = true
} else {
log.WithError(errStat).Debug("failed to stat local management asset")
}
}
// Rate limiting: check only once every 3 hours
lastUpdateCheckMu.Lock()
now := time.Now()
@@ -210,15 +221,14 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
lastUpdateCheckTime = now
lastUpdateCheckMu.Unlock()
if err := os.MkdirAll(staticDir, 0o755); err != nil {
log.WithError(err).Warn("failed to prepare static directory for management asset")
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
return
}
releaseURL := resolveReleaseURL(panelRepository)
client := newHTTPClient(proxyURL)
localPath := filepath.Join(staticDir, managementAssetName)
localHash, err := fileSHA256(localPath)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
@@ -229,6 +239,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return
}
return
}
log.WithError(err).Warn("failed to fetch latest management release information")
return
}
@@ -240,6 +257,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to download management asset, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return
}
return
}
log.WithError(err).Warn("failed to download management asset")
return
}
@@ -256,6 +280,22 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
}
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
data, downloadedHash, err := downloadAsset(ctx, client, defaultManagementFallbackURL)
if err != nil {
log.WithError(err).Warn("failed to download fallback management control panel page")
return false
}
if err = atomicWriteFile(localPath, data); err != nil {
log.WithError(err).Warn("failed to persist fallback management control panel page")
return false
}
log.Infof("management asset updated from fallback page successfully (hash=%s)", downloadedHash)
return true
}
func resolveReleaseURL(repo string) string {
repo = strings.TrimSpace(repo)
if repo == "" {

View File

@@ -7,12 +7,93 @@ import (
"embed"
_ "embed"
"strings"
"sync/atomic"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// codexInstructionsEnabled controls whether CodexInstructionsForModel returns official instructions.
// When false (default), CodexInstructionsForModel returns (true, "") immediately.
// Set via SetCodexInstructionsEnabled from config.
var codexInstructionsEnabled atomic.Bool
// SetCodexInstructionsEnabled sets whether codex instructions processing is enabled.
func SetCodexInstructionsEnabled(enabled bool) {
codexInstructionsEnabled.Store(enabled)
}
// GetCodexInstructionsEnabled returns whether codex instructions processing is enabled.
func GetCodexInstructionsEnabled() bool {
return codexInstructionsEnabled.Load()
}
//go:embed codex_instructions
var codexInstructionsDir embed.FS
func CodexInstructionsForModel(modelName, systemInstructions string) (bool, string) {
//go:embed opencode_codex_instructions.txt
var opencodeCodexInstructions string
const (
codexUserAgentKey = "__cpa_user_agent"
userAgentOpenAISDK = "opencode/"
)
func InjectCodexUserAgent(raw []byte, userAgent string) []byte {
if len(raw) == 0 {
return raw
}
trimmed := strings.TrimSpace(userAgent)
if trimmed == "" {
return raw
}
updated, err := sjson.SetBytes(raw, codexUserAgentKey, trimmed)
if err != nil {
return raw
}
return updated
}
func ExtractCodexUserAgent(raw []byte) string {
if len(raw) == 0 {
return ""
}
return strings.TrimSpace(gjson.GetBytes(raw, codexUserAgentKey).String())
}
func StripCodexUserAgent(raw []byte) []byte {
if len(raw) == 0 {
return raw
}
if !gjson.GetBytes(raw, codexUserAgentKey).Exists() {
return raw
}
updated, err := sjson.DeleteBytes(raw, codexUserAgentKey)
if err != nil {
return raw
}
return updated
}
func codexInstructionsForOpenCode(systemInstructions string) (bool, string) {
if opencodeCodexInstructions == "" {
return false, ""
}
if strings.HasPrefix(systemInstructions, opencodeCodexInstructions) {
return true, ""
}
return false, opencodeCodexInstructions
}
func useOpenCodeInstructions(userAgent string) bool {
return strings.Contains(strings.ToLower(userAgent), userAgentOpenAISDK)
}
func IsOpenCodeUserAgent(userAgent string) bool {
return useOpenCodeInstructions(userAgent)
}
func codexInstructionsForCodex(modelName, systemInstructions string) (bool, string) {
entries, _ := codexInstructionsDir.ReadDir("codex_instructions")
lastPrompt := ""
@@ -57,3 +138,13 @@ func CodexInstructionsForModel(modelName, systemInstructions string) (bool, stri
return false, lastPrompt
}
}
func CodexInstructionsForModel(modelName, systemInstructions, userAgent string) (bool, string) {
if !GetCodexInstructionsEnabled() {
return true, ""
}
if IsOpenCodeUserAgent(userAgent) {
return codexInstructionsForOpenCode(systemInstructions)
}
return codexInstructionsForCodex(modelName, systemInstructions)
}

View File

@@ -0,0 +1,79 @@
You are OpenCode, the best coding agent on the planet.
You are an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
## Editing constraints
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
- Only add comments if they are necessary to make a non-obvious block easier to understand.
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
## Tool usage
- Prefer specialized tools over shell for file operations:
- Use Read to view files, Edit to modify files, and Write only when needed.
- Use Glob to find files by name and Grep to search file contents.
- Use Bash for terminal operations (git, bun, builds, tests, running scripts).
- Run tool calls in parallel when neither call needs the others output; otherwise run sequentially.
## Git and workspace hygiene
- You may be in a dirty git worktree.
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
* If the changes are in unrelated files, just ignore them and don't revert them.
- Do not amend commits unless explicitly requested.
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
## Frontend tasks
When doing frontend design tasks, avoid collapsing into bland, generic layouts.
Aim for interfaces that feel intentional and deliberate.
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
- Ensure the page loads properly on both desktop and mobile.
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
## Presenting your work and final message
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
- Default: be very concise; friendly coding teammate tone.
- Default: do the work without asking questions. Treat short tasks as sufficient direction; infer missing details by reading the codebase and following existing conventions.
- Questions: only ask when you are truly blocked after checking relevant context AND you cannot safely pick a reasonable default. This usually means one of:
* The request is ambiguous in a way that materially changes the result and you cannot disambiguate by reading the repo.
* The action is destructive/irreversible, touches production, or changes billing/security posture.
* You need a secret/credential/value that cannot be inferred (API key, account id, etc.).
- If you must ask: do all non-blocked work first, then ask exactly one targeted question, include your recommended default, and state what would change based on the answer.
- Never ask permission questions like "Should I proceed?" or "Do you want me to run tests?"; proceed with the most reasonable option and mention what you did.
- For substantial work, summarize clearly; follow finalanswer formatting.
- Skip heavy formatting for simple confirmations.
- Don't dump large files you've written; reference paths only.
- No "save/copy this file" - User is on the same machine.
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
- For code changes:
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with "summary", just jump right in.
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
## Final answer structure and style guidelines
- Plain text; CLI handles styling. Use structure only when it helps scanability.
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
- Bullets: use - ; merge related points; keep to one line when possible; 46 per list ordered by importance; keep phrasing consistent.
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
- Tone: collaborative, concise, factual; present tense, active voice; selfcontained; no "above/below"; parallel wording.
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
- File References: When referencing files in your response follow the below rules:
* Use inline code to make file paths clickable.
* Each reference should have a stand alone path. Even if it's the same file.
* Accepted: absolute, workspacerelative, a/ or b/ diff prefixes, or bare filename/suffix.
* Optionally include line/column (1based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
* Do not use URIs like file://, vscode://, or https://.
* Do not provide range of lines
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5

View File

@@ -0,0 +1,303 @@
// Package registry provides Kiro model conversion utilities.
// This file handles converting dynamic Kiro API model lists to the internal ModelInfo format,
// and merging with static metadata for thinking support and other capabilities.
package registry
import (
"strings"
"time"
)
// KiroAPIModel represents a model from Kiro API response.
// This is a local copy to avoid import cycles with the kiro package.
// The structure mirrors kiro.KiroModel for easy data conversion.
type KiroAPIModel struct {
// ModelID is the unique identifier for the model (e.g., "claude-sonnet-4.5")
ModelID string
// ModelName is the human-readable name
ModelName string
// Description is the model description
Description string
// RateMultiplier is the credit multiplier for this model
RateMultiplier float64
// RateUnit is the unit for rate calculation (e.g., "credit")
RateUnit string
// MaxInputTokens is the maximum input token limit
MaxInputTokens int
}
// DefaultKiroThinkingSupport defines the default thinking configuration for Kiro models.
// All Kiro models support thinking with the following budget range.
var DefaultKiroThinkingSupport = &ThinkingSupport{
Min: 1024, // Minimum thinking budget tokens
Max: 32000, // Maximum thinking budget tokens
ZeroAllowed: true, // Allow disabling thinking with 0
DynamicAllowed: true, // Allow dynamic thinking budget (-1)
}
// DefaultKiroContextLength is the default context window size for Kiro models.
const DefaultKiroContextLength = 200000
// DefaultKiroMaxCompletionTokens is the default max completion tokens for Kiro models.
const DefaultKiroMaxCompletionTokens = 64000
// ConvertKiroAPIModels converts Kiro API models to internal ModelInfo format.
// It performs the following transformations:
// - Normalizes model ID (e.g., claude-sonnet-4.5 → kiro-claude-sonnet-4-5)
// - Adds default thinking support metadata
// - Sets default context length and max completion tokens if not provided
//
// Parameters:
// - kiroModels: List of models from Kiro API response
//
// Returns:
// - []*ModelInfo: Converted model information list
func ConvertKiroAPIModels(kiroModels []*KiroAPIModel) []*ModelInfo {
if len(kiroModels) == 0 {
return nil
}
now := time.Now().Unix()
result := make([]*ModelInfo, 0, len(kiroModels))
for _, km := range kiroModels {
// Skip nil models
if km == nil {
continue
}
// Skip models without valid ID
if km.ModelID == "" {
continue
}
// Normalize the model ID to kiro-* format
normalizedID := normalizeKiroModelID(km.ModelID)
// Create ModelInfo with converted data
info := &ModelInfo{
ID: normalizedID,
Object: "model",
Created: now,
OwnedBy: "aws",
Type: "kiro",
DisplayName: generateKiroDisplayName(km.ModelName, normalizedID),
Description: km.Description,
// Use MaxInputTokens from API if available, otherwise use default
ContextLength: getContextLength(km.MaxInputTokens),
MaxCompletionTokens: DefaultKiroMaxCompletionTokens,
// All Kiro models support thinking
Thinking: cloneThinkingSupport(DefaultKiroThinkingSupport),
}
result = append(result, info)
}
return result
}
// GenerateAgenticVariants creates -agentic variants for each model.
// Agentic variants are optimized for coding agents with chunked writes.
//
// Parameters:
// - models: Base models to generate variants for
//
// Returns:
// - []*ModelInfo: Combined list of base models and their agentic variants
func GenerateAgenticVariants(models []*ModelInfo) []*ModelInfo {
if len(models) == 0 {
return nil
}
// Pre-allocate result with capacity for both base models and variants
result := make([]*ModelInfo, 0, len(models)*2)
for _, model := range models {
if model == nil {
continue
}
// Add the base model first
result = append(result, model)
// Skip if model already has -agentic suffix
if strings.HasSuffix(model.ID, "-agentic") {
continue
}
// Skip special models that shouldn't have agentic variants
if model.ID == "kiro-auto" {
continue
}
// Create agentic variant
agenticModel := &ModelInfo{
ID: model.ID + "-agentic",
Object: model.Object,
Created: model.Created,
OwnedBy: model.OwnedBy,
Type: model.Type,
DisplayName: model.DisplayName + " (Agentic)",
Description: generateAgenticDescription(model.Description),
ContextLength: model.ContextLength,
MaxCompletionTokens: model.MaxCompletionTokens,
Thinking: cloneThinkingSupport(model.Thinking),
}
result = append(result, agenticModel)
}
return result
}
// MergeWithStaticMetadata merges dynamic models with static metadata.
// Static metadata takes priority for any overlapping fields.
// This allows manual overrides for specific models while keeping dynamic discovery.
//
// Parameters:
// - dynamicModels: Models from Kiro API (converted to ModelInfo)
// - staticModels: Predefined model metadata (from GetKiroModels())
//
// Returns:
// - []*ModelInfo: Merged model list with static metadata taking priority
func MergeWithStaticMetadata(dynamicModels, staticModels []*ModelInfo) []*ModelInfo {
if len(dynamicModels) == 0 && len(staticModels) == 0 {
return nil
}
// Build a map of static models for quick lookup
staticMap := make(map[string]*ModelInfo, len(staticModels))
for _, sm := range staticModels {
if sm != nil && sm.ID != "" {
staticMap[sm.ID] = sm
}
}
// Build result, preferring static metadata where available
seenIDs := make(map[string]struct{})
result := make([]*ModelInfo, 0, len(dynamicModels)+len(staticModels))
// First, process dynamic models and merge with static if available
for _, dm := range dynamicModels {
if dm == nil || dm.ID == "" {
continue
}
// Skip duplicates
if _, seen := seenIDs[dm.ID]; seen {
continue
}
seenIDs[dm.ID] = struct{}{}
// Check if static metadata exists for this model
if sm, exists := staticMap[dm.ID]; exists {
// Static metadata takes priority - use static model
result = append(result, sm)
} else {
// No static metadata - use dynamic model
result = append(result, dm)
}
}
// Add any static models not in dynamic list
for _, sm := range staticModels {
if sm == nil || sm.ID == "" {
continue
}
if _, seen := seenIDs[sm.ID]; seen {
continue
}
seenIDs[sm.ID] = struct{}{}
result = append(result, sm)
}
return result
}
// normalizeKiroModelID converts Kiro API model IDs to internal format.
// Transformation rules:
// - Adds "kiro-" prefix if not present
// - Replaces dots with hyphens (e.g., 4.5 → 4-5)
// - Handles special cases like "auto" → "kiro-auto"
//
// Examples:
// - "claude-sonnet-4.5" → "kiro-claude-sonnet-4-5"
// - "claude-opus-4.5" → "kiro-claude-opus-4-5"
// - "auto" → "kiro-auto"
// - "kiro-claude-sonnet-4-5" → "kiro-claude-sonnet-4-5" (unchanged)
func normalizeKiroModelID(modelID string) string {
if modelID == "" {
return ""
}
// Trim whitespace
modelID = strings.TrimSpace(modelID)
// Replace dots with hyphens (e.g., 4.5 → 4-5)
normalized := strings.ReplaceAll(modelID, ".", "-")
// Add kiro- prefix if not present
if !strings.HasPrefix(normalized, "kiro-") {
normalized = "kiro-" + normalized
}
return normalized
}
// generateKiroDisplayName creates a human-readable display name.
// Uses the API-provided model name if available, otherwise generates from ID.
func generateKiroDisplayName(modelName, normalizedID string) string {
if modelName != "" {
return "Kiro " + modelName
}
// Generate from normalized ID by removing kiro- prefix and formatting
displayID := strings.TrimPrefix(normalizedID, "kiro-")
// Capitalize first letter of each word
words := strings.Split(displayID, "-")
for i, word := range words {
if len(word) > 0 {
words[i] = strings.ToUpper(word[:1]) + word[1:]
}
}
return "Kiro " + strings.Join(words, " ")
}
// generateAgenticDescription creates description for agentic variants.
func generateAgenticDescription(baseDescription string) string {
if baseDescription == "" {
return "Optimized for coding agents with chunked writes"
}
return baseDescription + " (Agentic mode: chunked writes)"
}
// getContextLength returns the context length, using default if not provided.
func getContextLength(maxInputTokens int) int {
if maxInputTokens > 0 {
return maxInputTokens
}
return DefaultKiroContextLength
}
// cloneThinkingSupport creates a deep copy of ThinkingSupport.
// Returns nil if input is nil.
func cloneThinkingSupport(ts *ThinkingSupport) *ThinkingSupport {
if ts == nil {
return nil
}
clone := &ThinkingSupport{
Min: ts.Min,
Max: ts.Max,
ZeroAllowed: ts.ZeroAllowed,
DynamicAllowed: ts.DynamicAllowed,
}
// Deep copy Levels slice if present
if len(ts.Levels) > 0 {
clone.Levels = make([]string, len(ts.Levels))
copy(clone.Levels, ts.Levels)
}
return clone
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,846 @@
// Package registry provides model definitions for various AI service providers.
// This file stores the static model metadata catalog.
package registry
// GetClaudeModels returns the standard Claude model definitions
func GetClaudeModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "claude-haiku-4-5-20251001",
Object: "model",
Created: 1759276800, // 2025-10-01
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.5 Haiku",
ContextLength: 200000,
MaxCompletionTokens: 64000,
// Thinking: not supported for Haiku models
},
{
ID: "claude-sonnet-4-5-20250929",
Object: "model",
Created: 1759104000, // 2025-09-29
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.5 Sonnet",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-opus-4-5-20251101",
Object: "model",
Created: 1761955200, // 2025-11-01
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.5 Opus",
Description: "Premium model combining maximum intelligence with practical performance",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-opus-4-1-20250805",
Object: "model",
Created: 1722945600, // 2025-08-05
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.1 Opus",
ContextLength: 200000,
MaxCompletionTokens: 32000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
},
{
ID: "claude-opus-4-20250514",
Object: "model",
Created: 1715644800, // 2025-05-14
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4 Opus",
ContextLength: 200000,
MaxCompletionTokens: 32000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
},
{
ID: "claude-sonnet-4-20250514",
Object: "model",
Created: 1715644800, // 2025-05-14
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4 Sonnet",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
},
{
ID: "claude-3-7-sonnet-20250219",
Object: "model",
Created: 1708300800, // 2025-02-19
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 3.7 Sonnet",
ContextLength: 128000,
MaxCompletionTokens: 8192,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
},
{
ID: "claude-3-5-haiku-20241022",
Object: "model",
Created: 1729555200, // 2024-10-22
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 3.5 Haiku",
ContextLength: 128000,
MaxCompletionTokens: 8192,
// Thinking: not supported for Haiku models
},
}
}
// GetGeminiModels returns the standard Gemini model definitions
func GetGeminiModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "gemini-2.5-pro",
Object: "model",
Created: 1750118400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-pro",
Version: "2.5",
DisplayName: "Gemini 2.5 Pro",
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-2.5-flash",
Object: "model",
Created: 1750118400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-flash",
Version: "001",
DisplayName: "Gemini 2.5 Flash",
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "gemini-2.5-flash-lite",
Object: "model",
Created: 1753142400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-flash-lite",
Version: "2.5",
DisplayName: "Gemini 2.5 Flash Lite",
Description: "Our smallest and most cost effective model, built for at scale usage.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "gemini-3-pro-preview",
Object: "model",
Created: 1737158400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3-pro-preview",
Version: "3.0",
DisplayName: "Gemini 3 Pro Preview",
Description: "Gemini 3 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
Created: 1765929600,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3-flash-preview",
Version: "3.0",
DisplayName: "Gemini 3 Flash Preview",
Description: "Gemini 3 Flash Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
},
{
ID: "gemini-3-pro-image-preview",
Object: "model",
Created: 1737158400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3-pro-image-preview",
Version: "3.0",
DisplayName: "Gemini 3 Pro Image Preview",
Description: "Gemini 3 Pro Image Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
}
}
func GetGeminiVertexModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "gemini-2.5-pro",
Object: "model",
Created: 1750118400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-pro",
Version: "2.5",
DisplayName: "Gemini 2.5 Pro",
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-2.5-flash",
Object: "model",
Created: 1750118400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-flash",
Version: "001",
DisplayName: "Gemini 2.5 Flash",
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "gemini-2.5-flash-lite",
Object: "model",
Created: 1753142400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-flash-lite",
Version: "2.5",
DisplayName: "Gemini 2.5 Flash Lite",
Description: "Our smallest and most cost effective model, built for at scale usage.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "gemini-3-pro-preview",
Object: "model",
Created: 1737158400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3-pro-preview",
Version: "3.0",
DisplayName: "Gemini 3 Pro Preview",
Description: "Gemini 3 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
Created: 1765929600,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3-flash-preview",
Version: "3.0",
DisplayName: "Gemini 3 Flash Preview",
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
},
{
ID: "gemini-3-pro-image-preview",
Object: "model",
Created: 1737158400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3-pro-image-preview",
Version: "3.0",
DisplayName: "Gemini 3 Pro Image Preview",
Description: "Gemini 3 Pro Image Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
// Imagen image generation models - use :predict action
{
ID: "imagen-4.0-generate-001",
Object: "model",
Created: 1750000000,
OwnedBy: "google",
Type: "gemini",
Name: "models/imagen-4.0-generate-001",
Version: "4.0",
DisplayName: "Imagen 4.0 Generate",
Description: "Imagen 4.0 image generation model",
SupportedGenerationMethods: []string{"predict"},
},
{
ID: "imagen-4.0-ultra-generate-001",
Object: "model",
Created: 1750000000,
OwnedBy: "google",
Type: "gemini",
Name: "models/imagen-4.0-ultra-generate-001",
Version: "4.0",
DisplayName: "Imagen 4.0 Ultra Generate",
Description: "Imagen 4.0 Ultra high-quality image generation model",
SupportedGenerationMethods: []string{"predict"},
},
{
ID: "imagen-3.0-generate-002",
Object: "model",
Created: 1740000000,
OwnedBy: "google",
Type: "gemini",
Name: "models/imagen-3.0-generate-002",
Version: "3.0",
DisplayName: "Imagen 3.0 Generate",
Description: "Imagen 3.0 image generation model",
SupportedGenerationMethods: []string{"predict"},
},
{
ID: "imagen-3.0-fast-generate-001",
Object: "model",
Created: 1740000000,
OwnedBy: "google",
Type: "gemini",
Name: "models/imagen-3.0-fast-generate-001",
Version: "3.0",
DisplayName: "Imagen 3.0 Fast Generate",
Description: "Imagen 3.0 fast image generation model",
SupportedGenerationMethods: []string{"predict"},
},
{
ID: "imagen-4.0-fast-generate-001",
Object: "model",
Created: 1750000000,
OwnedBy: "google",
Type: "gemini",
Name: "models/imagen-4.0-fast-generate-001",
Version: "4.0",
DisplayName: "Imagen 4.0 Fast Generate",
Description: "Imagen 4.0 fast image generation model",
SupportedGenerationMethods: []string{"predict"},
},
}
}
// GetGeminiCLIModels returns the standard Gemini model definitions
func GetGeminiCLIModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "gemini-2.5-pro",
Object: "model",
Created: 1750118400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-pro",
Version: "2.5",
DisplayName: "Gemini 2.5 Pro",
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-2.5-flash",
Object: "model",
Created: 1750118400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-flash",
Version: "001",
DisplayName: "Gemini 2.5 Flash",
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "gemini-2.5-flash-lite",
Object: "model",
Created: 1753142400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-flash-lite",
Version: "2.5",
DisplayName: "Gemini 2.5 Flash Lite",
Description: "Our smallest and most cost effective model, built for at scale usage.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "gemini-3-pro-preview",
Object: "model",
Created: 1737158400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3-pro-preview",
Version: "3.0",
DisplayName: "Gemini 3 Pro Preview",
Description: "Our most intelligent model with SOTA reasoning and multimodal understanding, and powerful agentic and vibe coding capabilities",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
Created: 1765929600,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3-flash-preview",
Version: "3.0",
DisplayName: "Gemini 3 Flash Preview",
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
},
}
}
// GetAIStudioModels returns the Gemini model definitions for AI Studio integrations
func GetAIStudioModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "gemini-2.5-pro",
Object: "model",
Created: 1750118400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-pro",
Version: "2.5",
DisplayName: "Gemini 2.5 Pro",
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-2.5-flash",
Object: "model",
Created: 1750118400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-flash",
Version: "001",
DisplayName: "Gemini 2.5 Flash",
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "gemini-2.5-flash-lite",
Object: "model",
Created: 1753142400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-flash-lite",
Version: "2.5",
DisplayName: "Gemini 2.5 Flash Lite",
Description: "Our smallest and most cost effective model, built for at scale usage.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "gemini-3-pro-preview",
Object: "model",
Created: 1737158400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3-pro-preview",
Version: "3.0",
DisplayName: "Gemini 3 Pro Preview",
Description: "Gemini 3 Pro Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-3-flash-preview",
Object: "model",
Created: 1765929600,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3-flash-preview",
Version: "3.0",
DisplayName: "Gemini 3 Flash Preview",
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-pro-latest",
Object: "model",
Created: 1750118400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-pro-latest",
Version: "2.5",
DisplayName: "Gemini Pro Latest",
Description: "Latest release of Gemini Pro",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-flash-latest",
Object: "model",
Created: 1750118400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-flash-latest",
Version: "2.5",
DisplayName: "Gemini Flash Latest",
Description: "Latest release of Gemini Flash",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "gemini-flash-lite-latest",
Object: "model",
Created: 1753142400,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-flash-lite-latest",
Version: "2.5",
DisplayName: "Gemini Flash-Lite Latest",
Description: "Latest release of Gemini Flash-Lite",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 512, Max: 24576, ZeroAllowed: true, DynamicAllowed: true},
},
// {
// ID: "gemini-2.5-flash-image-preview",
// Object: "model",
// Created: 1756166400,
// OwnedBy: "google",
// Type: "gemini",
// Name: "models/gemini-2.5-flash-image-preview",
// Version: "2.5",
// DisplayName: "Gemini 2.5 Flash Image Preview",
// Description: "State-of-the-art image generation and editing model.",
// InputTokenLimit: 1048576,
// OutputTokenLimit: 8192,
// SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
// // image models don't support thinkingConfig; leave Thinking nil
// },
{
ID: "gemini-2.5-flash-image",
Object: "model",
Created: 1759363200,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-flash-image",
Version: "2.5",
DisplayName: "Gemini 2.5 Flash Image",
Description: "State-of-the-art image generation and editing model.",
InputTokenLimit: 1048576,
OutputTokenLimit: 8192,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
// image models don't support thinkingConfig; leave Thinking nil
},
}
}
// GetOpenAIModels returns the standard OpenAI model definitions
func GetOpenAIModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "gpt-5",
Object: "model",
Created: 1754524800,
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5-2025-08-07",
DisplayName: "GPT 5",
Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"minimal", "low", "medium", "high"}},
},
{
ID: "gpt-5-codex",
Object: "model",
Created: 1757894400,
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5-2025-09-15",
DisplayName: "GPT 5 Codex",
Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "gpt-5-codex-mini",
Object: "model",
Created: 1762473600,
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5-2025-11-07",
DisplayName: "GPT 5 Codex Mini",
Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "gpt-5.1",
Object: "model",
Created: 1762905600,
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-2025-11-12",
DisplayName: "GPT 5",
Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
},
{
ID: "gpt-5.1-codex",
Object: "model",
Created: 1762905600,
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-2025-11-12",
DisplayName: "GPT 5.1 Codex",
Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "gpt-5.1-codex-mini",
Object: "model",
Created: 1762905600,
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-2025-11-12",
DisplayName: "GPT 5.1 Codex Mini",
Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "gpt-5.1-codex-max",
Object: "model",
Created: 1763424000,
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-max",
DisplayName: "GPT 5.1 Codex Max",
Description: "Stable version of GPT 5.1 Codex Max",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.2",
Object: "model",
Created: 1765440000,
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.2",
DisplayName: "GPT 5.2",
Description: "Stable version of GPT 5.2",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.2-codex",
Object: "model",
Created: 1765440000,
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.2",
DisplayName: "GPT 5.2 Codex",
Description: "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
},
}
}
// GetQwenModels returns the standard Qwen model definitions
func GetQwenModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "qwen3-coder-plus",
Object: "model",
Created: 1753228800,
OwnedBy: "qwen",
Type: "qwen",
Version: "3.0",
DisplayName: "Qwen3 Coder Plus",
Description: "Advanced code generation and understanding model",
ContextLength: 32768,
MaxCompletionTokens: 8192,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
{
ID: "qwen3-coder-flash",
Object: "model",
Created: 1753228800,
OwnedBy: "qwen",
Type: "qwen",
Version: "3.0",
DisplayName: "Qwen3 Coder Flash",
Description: "Fast code generation model",
ContextLength: 8192,
MaxCompletionTokens: 2048,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
{
ID: "vision-model",
Object: "model",
Created: 1758672000,
OwnedBy: "qwen",
Type: "qwen",
Version: "3.0",
DisplayName: "Qwen3 Vision Model",
Description: "Vision model model",
ContextLength: 32768,
MaxCompletionTokens: 2048,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
}
}
// iFlowThinkingSupport is a shared ThinkingSupport configuration for iFlow models
// that support thinking mode via chat_template_kwargs.enable_thinking (boolean toggle).
// Uses level-based configuration so standard normalization flows apply before conversion.
var iFlowThinkingSupport = &ThinkingSupport{
Levels: []string{"none", "auto", "minimal", "low", "medium", "high", "xhigh"},
}
// GetIFlowModels returns supported models for iFlow OAuth accounts.
func GetIFlowModels() []*ModelInfo {
entries := []struct {
ID string
DisplayName string
Description string
Created int64
Thinking *ThinkingSupport
}{
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
{ID: "deepseek-v3", DisplayName: "DeepSeek-V3-671B", Description: "DeepSeek V3 671B", Created: 1734307200},
{ID: "qwen3-32b", DisplayName: "Qwen3-32B", Description: "Qwen3 32B", Created: 1747094400},
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
}
models := make([]*ModelInfo, 0, len(entries))
for _, entry := range entries {
models = append(models, &ModelInfo{
ID: entry.ID,
Object: "model",
Created: entry.Created,
OwnedBy: "iflow",
Type: "iflow",
DisplayName: entry.DisplayName,
Description: entry.Description,
Thinking: entry.Thinking,
})
}
return models
}
// AntigravityModelConfig captures static antigravity model overrides, including
// Thinking budget limits and provider max completion tokens.
type AntigravityModelConfig struct {
Thinking *ThinkingSupport
MaxCompletionTokens int
}
// GetAntigravityModelConfig returns static configuration for antigravity models.
// Keys use upstream model names returned by the Antigravity models endpoint.
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
return map[string]*AntigravityModelConfig{
// "rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}},
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
"gpt-oss-120b-medium": {},
"tab_flash_lite_preview": {},
}
}

View File

@@ -4,6 +4,7 @@
package registry
import (
"context"
"fmt"
"sort"
"strings"
@@ -46,10 +47,17 @@ type ModelInfo struct {
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
// SupportedParameters lists supported parameters
SupportedParameters []string `json:"supported_parameters,omitempty"`
// SupportedEndpoints lists supported API endpoints (e.g., "/chat/completions", "/responses").
SupportedEndpoints []string `json:"supported_endpoints,omitempty"`
// Thinking holds provider-specific reasoning/thinking budget capabilities.
// This is optional and currently used for Gemini thinking budget normalization.
Thinking *ThinkingSupport `json:"thinking,omitempty"`
// UserDefined indicates this model was defined through config file's models[]
// array (e.g., openai-compatibility.*.models[], *-api-key.models[]).
// UserDefined models have thinking configuration passed through without validation.
UserDefined bool `json:"-"`
}
// ThinkingSupport describes a model family's supported internal reasoning budget range.
@@ -72,6 +80,8 @@ type ThinkingSupport struct {
type ModelRegistration struct {
// Info contains the model metadata
Info *ModelInfo
// InfoByProvider maps provider identifiers to specific ModelInfo to support differing capabilities.
InfoByProvider map[string]*ModelInfo
// Count is the number of active clients that can provide this model
Count int
// LastUpdated tracks when this registration was last modified
@@ -84,6 +94,13 @@ type ModelRegistration struct {
SuspendedClients map[string]string
}
// ModelRegistryHook provides optional callbacks for external integrations to track model list changes.
// Hook implementations must be non-blocking and resilient; calls are executed asynchronously and panics are recovered.
type ModelRegistryHook interface {
OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo)
OnModelsUnregistered(ctx context.Context, provider, clientID string)
}
// ModelRegistry manages the global registry of available models
type ModelRegistry struct {
// models maps model ID to registration information
@@ -97,6 +114,8 @@ type ModelRegistry struct {
clientProviders map[string]string
// mutex ensures thread-safe access to the registry
mutex *sync.RWMutex
// hook is an optional callback sink for model registration changes
hook ModelRegistryHook
}
// Global model registry instance
@@ -117,6 +136,71 @@ func GetGlobalRegistry() *ModelRegistry {
return globalRegistry
}
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
return nil
}
p := ""
if len(provider) > 0 {
p = strings.ToLower(strings.TrimSpace(provider[0]))
}
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
return info
}
return LookupStaticModelInfo(modelID)
}
// SetHook sets an optional hook for observing model registration changes.
func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
if r == nil {
return
}
r.mutex.Lock()
defer r.mutex.Unlock()
r.hook = hook
}
const defaultModelRegistryHookTimeout = 5 * time.Second
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
hook := r.hook
if hook == nil {
return
}
modelsCopy := cloneModelInfosUnique(models)
go func() {
defer func() {
if recovered := recover(); recovered != nil {
log.Errorf("model registry hook OnModelsRegistered panic: %v", recovered)
}
}()
ctx, cancel := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout)
defer cancel()
hook.OnModelsRegistered(ctx, provider, clientID, modelsCopy)
}()
}
func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
hook := r.hook
if hook == nil {
return
}
go func() {
defer func() {
if recovered := recover(); recovered != nil {
log.Errorf("model registry hook OnModelsUnregistered panic: %v", recovered)
}
}()
ctx, cancel := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout)
defer cancel()
hook.OnModelsUnregistered(ctx, provider, clientID)
}()
}
// RegisterClient registers a client and its supported models
// Parameters:
// - clientID: Unique identifier for the client
@@ -177,6 +261,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
} else {
delete(r.clientProviders, clientID)
}
r.triggerModelsRegistered(provider, clientID, models)
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
misc.LogCredentialSeparator()
return
@@ -219,6 +304,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
if count, okProv := reg.Providers[oldProvider]; okProv {
if count <= toRemove {
delete(reg.Providers, oldProvider)
if reg.InfoByProvider != nil {
delete(reg.InfoByProvider, oldProvider)
}
} else {
reg.Providers[oldProvider] = count - toRemove
}
@@ -268,6 +356,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
model := newModels[id]
if reg, ok := r.models[id]; ok {
reg.Info = cloneModelInfo(model)
if provider != "" {
if reg.InfoByProvider == nil {
reg.InfoByProvider = make(map[string]*ModelInfo)
}
reg.InfoByProvider[provider] = cloneModelInfo(model)
}
reg.LastUpdated = now
if reg.QuotaExceededClients != nil {
delete(reg.QuotaExceededClients, clientID)
@@ -310,6 +404,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
delete(r.clientProviders, clientID)
}
r.triggerModelsRegistered(provider, clientID, models)
if len(added) == 0 && len(removed) == 0 && !providerChanged {
// Only metadata (e.g., display name) changed; skip separator when no log output.
return
@@ -330,11 +425,15 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
if existing.SuspendedClients == nil {
existing.SuspendedClients = make(map[string]string)
}
if existing.InfoByProvider == nil {
existing.InfoByProvider = make(map[string]*ModelInfo)
}
if provider != "" {
if existing.Providers == nil {
existing.Providers = make(map[string]int)
}
existing.Providers[provider]++
existing.InfoByProvider[provider] = cloneModelInfo(model)
}
log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count)
return
@@ -342,6 +441,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
registration := &ModelRegistration{
Info: cloneModelInfo(model),
InfoByProvider: make(map[string]*ModelInfo),
Count: 1,
LastUpdated: now,
QuotaExceededClients: make(map[string]*time.Time),
@@ -349,6 +449,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
}
if provider != "" {
registration.Providers = map[string]int{provider: 1}
registration.InfoByProvider[provider] = cloneModelInfo(model)
}
r.models[modelID] = registration
log.Debugf("Registered new model %s from provider %s", modelID, provider)
@@ -374,6 +475,9 @@ func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider stri
if count, ok := registration.Providers[provider]; ok {
if count <= 1 {
delete(registration.Providers, provider)
if registration.InfoByProvider != nil {
delete(registration.InfoByProvider, provider)
}
} else {
registration.Providers[provider] = count - 1
}
@@ -397,9 +501,31 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
if len(model.SupportedParameters) > 0 {
copyModel.SupportedParameters = append([]string(nil), model.SupportedParameters...)
}
if len(model.SupportedEndpoints) > 0 {
copyModel.SupportedEndpoints = append([]string(nil), model.SupportedEndpoints...)
}
return &copyModel
}
func cloneModelInfosUnique(models []*ModelInfo) []*ModelInfo {
if len(models) == 0 {
return nil
}
cloned := make([]*ModelInfo, 0, len(models))
seen := make(map[string]struct{}, len(models))
for _, model := range models {
if model == nil || model.ID == "" {
continue
}
if _, exists := seen[model.ID]; exists {
continue
}
seen[model.ID] = struct{}{}
cloned = append(cloned, cloneModelInfo(model))
}
return cloned
}
// UnregisterClient removes a client and decrements counts for its models
// Parameters:
// - clientID: Unique identifier for the client to remove
@@ -436,6 +562,9 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
if count, ok := registration.Providers[provider]; ok {
if count <= 1 {
delete(registration.Providers, provider)
if registration.InfoByProvider != nil {
delete(registration.InfoByProvider, provider)
}
} else {
registration.Providers[provider] = count - 1
}
@@ -460,6 +589,7 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
log.Debugf("Unregistered client %s", clientID)
// Separator line after completing client unregistration (after the summary line)
misc.LogCredentialSeparator()
r.triggerModelsUnregistered(provider, clientID)
}
// SetModelQuotaExceeded marks a model as quota exceeded for a specific client
@@ -625,6 +755,131 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
return models
}
// GetAvailableModelsByProvider returns models available for the given provider identifier.
// Parameters:
// - provider: Provider identifier (e.g., "codex", "gemini", "antigravity")
//
// Returns:
// - []*ModelInfo: List of available models for the provider
func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelInfo {
provider = strings.ToLower(strings.TrimSpace(provider))
if provider == "" {
return nil
}
r.mutex.RLock()
defer r.mutex.RUnlock()
type providerModel struct {
count int
info *ModelInfo
}
providerModels := make(map[string]*providerModel)
for clientID, clientProvider := range r.clientProviders {
if clientProvider != provider {
continue
}
modelIDs := r.clientModels[clientID]
if len(modelIDs) == 0 {
continue
}
clientInfos := r.clientModelInfos[clientID]
for _, modelID := range modelIDs {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
continue
}
entry := providerModels[modelID]
if entry == nil {
entry = &providerModel{}
providerModels[modelID] = entry
}
entry.count++
if entry.info == nil {
if clientInfos != nil {
if info := clientInfos[modelID]; info != nil {
entry.info = info
}
}
if entry.info == nil {
if reg, ok := r.models[modelID]; ok && reg != nil && reg.Info != nil {
entry.info = reg.Info
}
}
}
}
}
if len(providerModels) == 0 {
return nil
}
quotaExpiredDuration := 5 * time.Minute
now := time.Now()
result := make([]*ModelInfo, 0, len(providerModels))
for modelID, entry := range providerModels {
if entry == nil || entry.count <= 0 {
continue
}
registration, ok := r.models[modelID]
expiredClients := 0
cooldownSuspended := 0
otherSuspended := 0
if ok && registration != nil {
if registration.QuotaExceededClients != nil {
for clientID, quotaTime := range registration.QuotaExceededClients {
if clientID == "" {
continue
}
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
continue
}
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
expiredClients++
}
}
}
if registration.SuspendedClients != nil {
for clientID, reason := range registration.SuspendedClients {
if clientID == "" {
continue
}
if p, okProvider := r.clientProviders[clientID]; !okProvider || p != provider {
continue
}
if strings.EqualFold(reason, "quota") {
cooldownSuspended++
continue
}
otherSuspended++
}
}
}
availableClients := entry.count
effectiveClients := availableClients - expiredClients - otherSuspended
if effectiveClients < 0 {
effectiveClients = 0
}
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
if entry.info != nil {
result = append(result, entry.info)
continue
}
if ok && registration != nil && registration.Info != nil {
result = append(result, registration.Info)
}
}
}
return result
}
// GetModelCount returns the number of available clients for a specific model
// Parameters:
// - modelID: The model ID to check
@@ -716,12 +971,22 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string {
return result
}
// GetModelInfo returns the registered ModelInfo for the given model ID, if present.
// Returns nil if the model is unknown to the registry.
func (r *ModelRegistry) GetModelInfo(modelID string) *ModelInfo {
// GetModelInfo returns ModelInfo, prioritizing provider-specific definition if available.
func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
r.mutex.RLock()
defer r.mutex.RUnlock()
if reg, ok := r.models[modelID]; ok && reg != nil {
// Try provider specific definition first
if provider != "" && reg.InfoByProvider != nil {
if reg.Providers != nil {
if count, ok := reg.Providers[provider]; ok && count > 0 {
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
return info
}
}
}
}
// Fallback to global info (last registered)
return reg.Info
}
return nil
@@ -764,6 +1029,9 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
if len(model.SupportedParameters) > 0 {
result["supported_parameters"] = model.SupportedParameters
}
if len(model.SupportedEndpoints) > 0 {
result["supported_endpoints"] = model.SupportedEndpoints
}
return result
case "claude", "kiro", "antigravity":
@@ -774,10 +1042,10 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
"owned_by": model.OwnedBy,
}
if model.Created > 0 {
result["created"] = model.Created
result["created_at"] = model.Created
}
if model.Type != "" {
result["type"] = model.Type
result["type"] = "model"
}
if model.DisplayName != "" {
result["display_name"] = model.DisplayName

View File

@@ -0,0 +1,204 @@
package registry
import (
"context"
"sync"
"testing"
"time"
)
func newTestModelRegistry() *ModelRegistry {
return &ModelRegistry{
models: make(map[string]*ModelRegistration),
clientModels: make(map[string][]string),
clientModelInfos: make(map[string]map[string]*ModelInfo),
clientProviders: make(map[string]string),
mutex: &sync.RWMutex{},
}
}
type registeredCall struct {
provider string
clientID string
models []*ModelInfo
}
type unregisteredCall struct {
provider string
clientID string
}
type capturingHook struct {
registeredCh chan registeredCall
unregisteredCh chan unregisteredCall
}
func (h *capturingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
h.registeredCh <- registeredCall{provider: provider, clientID: clientID, models: models}
}
func (h *capturingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
h.unregisteredCh <- unregisteredCall{provider: provider, clientID: clientID}
}
func TestModelRegistryHook_OnModelsRegisteredCalled(t *testing.T) {
r := newTestModelRegistry()
hook := &capturingHook{
registeredCh: make(chan registeredCall, 1),
unregisteredCh: make(chan unregisteredCall, 1),
}
r.SetHook(hook)
inputModels := []*ModelInfo{
{ID: "m1", DisplayName: "Model One"},
{ID: "m2", DisplayName: "Model Two"},
}
r.RegisterClient("client-1", "OpenAI", inputModels)
select {
case call := <-hook.registeredCh:
if call.provider != "openai" {
t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai")
}
if call.clientID != "client-1" {
t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1")
}
if len(call.models) != 2 {
t.Fatalf("models length mismatch: got %d, want %d", len(call.models), 2)
}
if call.models[0] == nil || call.models[0].ID != "m1" {
t.Fatalf("models[0] mismatch: got %#v, want ID=%q", call.models[0], "m1")
}
if call.models[1] == nil || call.models[1].ID != "m2" {
t.Fatalf("models[1] mismatch: got %#v, want ID=%q", call.models[1], "m2")
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsRegistered hook call")
}
}
func TestModelRegistryHook_OnModelsUnregisteredCalled(t *testing.T) {
r := newTestModelRegistry()
hook := &capturingHook{
registeredCh: make(chan registeredCall, 1),
unregisteredCh: make(chan unregisteredCall, 1),
}
r.SetHook(hook)
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
select {
case <-hook.registeredCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsRegistered hook call")
}
r.UnregisterClient("client-1")
select {
case call := <-hook.unregisteredCh:
if call.provider != "openai" {
t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai")
}
if call.clientID != "client-1" {
t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1")
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsUnregistered hook call")
}
}
type blockingHook struct {
started chan struct{}
unblock chan struct{}
}
func (h *blockingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
select {
case <-h.started:
default:
close(h.started)
}
<-h.unblock
}
func (h *blockingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {}
func TestModelRegistryHook_DoesNotBlockRegisterClient(t *testing.T) {
r := newTestModelRegistry()
hook := &blockingHook{
started: make(chan struct{}),
unblock: make(chan struct{}),
}
r.SetHook(hook)
defer close(hook.unblock)
done := make(chan struct{})
go func() {
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
close(done)
}()
select {
case <-hook.started:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for hook to start")
}
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Fatal("RegisterClient appears to be blocked by hook")
}
if !r.ClientSupportsModel("client-1", "m1") {
t.Fatal("model registration failed; expected client to support model")
}
}
type panicHook struct {
registeredCalled chan struct{}
unregisteredCalled chan struct{}
}
func (h *panicHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
if h.registeredCalled != nil {
h.registeredCalled <- struct{}{}
}
panic("boom")
}
func (h *panicHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
if h.unregisteredCalled != nil {
h.unregisteredCalled <- struct{}{}
}
panic("boom")
}
func TestModelRegistryHook_PanicDoesNotAffectRegistry(t *testing.T) {
r := newTestModelRegistry()
hook := &panicHook{
registeredCalled: make(chan struct{}, 1),
unregisteredCalled: make(chan struct{}, 1),
}
r.SetHook(hook)
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
select {
case <-hook.registeredCalled:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsRegistered hook call")
}
if !r.ClientSupportsModel("client-1", "m1") {
t.Fatal("model registration failed; expected client to support model")
}
r.UnregisterClient("client-1")
select {
case <-hook.unregisteredCalled:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsUnregistered hook call")
}
}

Some files were not shown because too many files have changed in this diff Show More