Compare commits

...

113 Commits

Author SHA1 Message Date
Luis Pater
b27a175fef Merge pull request #17 from Ravens2121/master
由AI进行更改修复了Kiro供应商的Claude协议与OpenAI协议。(对比AIClient-2-API项目进行变更)
2025-12-11 01:31:47 +08:00
Ravens
8d5f89ccfd fix(kiro): fix translator format mismatch for OpenAI protocol
Amp-Thread-ID: https://ampcode.com/threads/T-019b092b-f2de-72a1-b428-72511c0de628
Co-authored-by: Amp <amp@ampcode.com>
2025-12-11 01:15:00 +08:00
Ravens
084e2666cb fix(kiro): add SSE event: prefix for Claude client compatibility
Amp-Thread-ID: https://ampcode.com/threads/T-019b08fc-ff96-766e-a942-63dd35ed28c6
Co-authored-by: Amp <amp@ampcode.com>
2025-12-11 00:14:20 +08:00
Luis Pater
2eb2dbb266 Merge branch 'router-for-me:main' into main 2025-12-10 23:56:59 +08:00
Luis Pater
e717939edb Fixed: #478
feat(antigravity): add support for inline image data in client responses
2025-12-10 23:55:53 +08:00
Ravens2121
7758a86d1e Merge branch 'router-for-me:main' into main 2025-12-10 23:25:23 +08:00
Luis Pater
9f72a875f8 Merge branch 'router-for-me:main' into main 2025-12-10 16:54:31 +08:00
Luis Pater
94d61c7b2b fix(logging): update response aggregation logic to include all attempts 2025-12-10 16:53:48 +08:00
Luis Pater
f999650322 Merge branch 'router-for-me:main' into main 2025-12-10 16:04:18 +08:00
Luis Pater
1249b07eb8 feat(responses): add unique identifiers for responses, function calls, and tool uses 2025-12-10 16:02:54 +08:00
Luis Pater
38319b0483 Merge branch 'router-for-me:main' into main 2025-12-10 15:29:19 +08:00
Luis Pater
6b37f33d31 feat(antigravity): add unique identifier for tool use blocks in response 2025-12-10 15:27:57 +08:00
Luis Pater
af543238aa Merge pull request #16 from fuko2935/fix/remove-unstable-kiro-auto
fix(registry): remove unstable kiro-auto model
2025-12-10 00:23:31 +08:00
Luis Pater
2de27c560b Merge pull request #15 from StriveMario/feature/fix-kiro-refreshtoken
fix kiro cannot refresh the token
2025-12-10 00:22:36 +08:00
Luis Pater
773ed6cc64 Merge branch 'router-for-me:main' into main 2025-12-10 00:17:43 +08:00
fuko2935
a594338bc5 fix(registry): remove unstable kiro-auto model
- Removes kiro-auto from static model registry
- Removes kiro-auto mapping from executor
- Fixes compatibility issues reported in #7

Fixes #7
2025-12-09 19:14:40 +03:00
Luis Pater
f25f419e5a fix(antigravity): remove references to autopush endpoint and update fallback logic 2025-12-10 00:13:20 +08:00
Luis Pater
1fd1ccca17 Merge branch 'router-for-me:main' into main 2025-12-09 21:13:08 +08:00
Luis Pater
b7e382008f Merge pull request #465 from router-for-me/think
Move thinking budget normalization from translators to executor
2025-12-09 21:10:33 +08:00
hkfires
70d6b95097 feat(amp): add /news.rss proxy route 2025-12-09 21:05:06 +08:00
hkfires
9b202b6c1c fix(executor): centralize default thinking config 2025-12-09 21:05:06 +08:00
hkfires
6a66b6801a feat(executor): enforce minimum thinking budget for antigravity models 2025-12-09 21:05:06 +08:00
hkfires
5b6d201408 refactor(translator): remove thinking budget normalization across all translators 2025-12-09 21:05:06 +08:00
hkfires
5ec9b5e5a9 feat(executor): normalize thinking budget across all Gemini executors 2025-12-09 21:05:06 +08:00
Luis Pater
5db3b58717 Merge pull request #470 from router-for-me/agry
fix(gemini): normalize model listing output
2025-12-09 21:00:29 +08:00
Mario
1fa5514d56 fix kiro cannot refresh the token 2025-12-09 20:13:16 +08:00
hkfires
347769b3e3 fix(openai-compat): use model id for auth model display 2025-12-09 18:09:14 +08:00
hkfires
3cfe7008a2 fix(registry): update gpt 5.1 model names 2025-12-09 17:55:21 +08:00
Luis Pater
c353606860 Merge pull request #14 from router-for-me/plus
v6.5.59
2025-12-09 17:51:46 +08:00
Luis Pater
2ba31ecc2d Merge branch 'main' into plus 2025-12-09 17:51:18 +08:00
hkfires
da23ddb061 fix(gemini): normalize model listing output 2025-12-09 17:34:15 +08:00
Luis Pater
39b6b3b289 Fixed: #463
fix(antigravity): remove `$ref` and `$defs` from JSON during key deletion
2025-12-09 17:32:17 +08:00
Luis Pater
c600519fa4 refactor(logging): replace log.Fatalf with log.Errorf and add error handling paths 2025-12-09 17:16:30 +08:00
hkfires
e5312fb5a2 feat(antigravity): support canonical names for antigravity models 2025-12-09 16:54:13 +08:00
Luis Pater
36380846a7 Merge branch 'router-for-me:main' into main 2025-12-09 10:03:16 +08:00
Luis Pater
92df0cada9 Merge pull request #461 from router-for-me/aistudio
feat(aistudio): normalize thinking budget in request translation
2025-12-09 08:41:46 +08:00
hkfires
96b55acff8 feat(aistudio): normalize thinking budget in request translation 2025-12-09 08:27:44 +08:00
Luis Pater
608cd8ee3d Merge pull request #13 from router-for-me/v6.5.57
v6.5.57
2025-12-08 23:33:52 +08:00
Luis Pater
9f41894573 Merge branch 'main' into v6.5.57 2025-12-08 23:33:39 +08:00
Luis Pater
bb45fee1cf Merge remote-tracking branch 'origin/dev' into dev 2025-12-08 23:28:22 +08:00
Luis Pater
af00304b0c fix(antigravity): remove exclusiveMaximum from JSON during key deletion 2025-12-08 23:28:01 +08:00
vuonglv(Andy)
5c3a013cd1 feat(config): add configurable host binding for server (#454)
* feat(config): add configurable host binding for server
2025-12-08 23:16:39 +08:00
Luis Pater
ab9e9442ec v6.5.56 (#12)
* feat(api): add comprehensive ampcode management endpoints

Add new REST API endpoints under /v0/management/ampcode for managing
ampcode configuration including upstream URL, API key, localhost
restriction, model mappings, and force model mappings settings.

- Move force-model-mappings from config_basic to config_lists
- Add GET/PUT/PATCH/DELETE endpoints for all ampcode settings
- Support model mapping CRUD with upsert (PATCH) capability
- Add comprehensive test coverage for all ampcode endpoints

* refactor(api): simplify request body parsing in ampcode handlers

* feat(logging): add upstream API request/response capture to streaming logs

* style(logging): remove redundant separator line from response section

* feat(antigravity): enforce thinking budget limits for Claude models

* refactor(logging): remove unused variable in `ensureAttempt` and redundant function call

---------

Co-authored-by: hkfires <10558748+hkfires@users.noreply.github.com>
2025-12-08 22:32:29 +08:00
Luis Pater
6ad188921c refactor(logging): remove unused variable in ensureAttempt and redundant function call 2025-12-08 22:25:58 +08:00
Luis Pater
15ed98d6a9 Merge pull request #458 from router-for-me/agry
feat(antigravity): enforce thinking budget limits for Claude models
2025-12-08 20:55:52 +08:00
hkfires
a283545b6b feat(antigravity): enforce thinking budget limits for Claude models 2025-12-08 20:36:17 +08:00
Luis Pater
3efbd865a8 Merge pull request #457 from router-for-me/requestlog
style(logging): remove redundant separator line from response section
2025-12-08 18:21:24 +08:00
hkfires
aee659fb66 style(logging): remove redundant separator line from response section 2025-12-08 18:18:33 +08:00
Luis Pater
5aa386d8b9 Merge pull request #453 from router-for-me/amp
add ampcode management api
2025-12-08 17:42:13 +08:00
Luis Pater
0adc0ee6aa Merge pull request #455 from router-for-me/requestlog
feat(logging): add upstream API request/response capture to streaming logs
2025-12-08 17:40:10 +08:00
hkfires
92f13fc316 feat(logging): add upstream API request/response capture to streaming logs 2025-12-08 17:21:58 +08:00
hkfires
05cfa16e5f refactor(api): simplify request body parsing in ampcode handlers 2025-12-08 14:45:35 +08:00
hkfires
93a6e2d920 feat(api): add comprehensive ampcode management endpoints
Add new REST API endpoints under /v0/management/ampcode for managing
ampcode configuration including upstream URL, API key, localhost
restriction, model mappings, and force model mappings settings.

- Move force-model-mappings from config_basic to config_lists
- Add GET/PUT/PATCH/DELETE endpoints for all ampcode settings
- Support model mapping CRUD with upsert (PATCH) capability
- Add comprehensive test coverage for all ampcode endpoints
2025-12-08 12:03:00 +08:00
Luis Pater
6b49580716 Merge branch 'router-for-me:main' into main 2025-12-08 10:52:39 +08:00
Luis Pater
de77903915 Merge pull request #450 from router-for-me/amp
refactor(config): rename prioritize-model-mappings to force-model-mappings
2025-12-08 10:51:32 +08:00
hkfires
56ed0d8d90 refactor(config): rename prioritize-model-mappings to force-model-mappings 2025-12-08 10:44:39 +08:00
Luis Pater
0d4f32a881 Merge branch 'router-for-me:main' into main 2025-12-08 10:20:08 +08:00
Luis Pater
42e818ce05 Merge pull request #435 from heyhuynhgiabuu/fix/amp-model-mapping-priority
fix: prioritize model mappings over local providers for Amp CLI
2025-12-08 10:17:19 +08:00
Luis Pater
2d4c54ba54 Merge pull request #448 from router-for-me/iflow
Iflow
2025-12-08 09:50:05 +08:00
hkfires
e9eb4db8bb feat(auth): refresh API key during cookie authentication 2025-12-08 09:48:31 +08:00
Luis Pater
d26ed069fa Merge pull request #441 from huynguyen03dev/fix/claude-to-openai-whitespace-text
fix: filter whitespace-only text in Claude to OpenAI translation
2025-12-08 09:43:44 +08:00
huynhgiabuu
afcab5efda feat: add prioritize-model-mappings config option
Add a configuration option to control whether model mappings take
precedence over local API keys for Amp CLI requests.

- Add PrioritizeModelMappings field to AmpCode config struct
- When false (default): Local API keys take precedence (original behavior)
- When true: Model mappings take precedence over local API keys
- Add management API endpoints GET/PUT /prioritize-model-mappings

This allows users who want mapping priority to enable it explicitly
while preserving backward compatibility.

Config example:
  ampcode:
    model-mappings:
      - from: claude-opus-4-5-20251101
        to: gemini-claude-opus-4-5-thinking
    prioritize-model-mappings: true
2025-12-07 22:47:43 +07:00
Ravens2121
1770c491db Merge branch 'router-for-me:main' into main 2025-12-07 22:25:42 +08:00
Ravens2121
a0c6cffb0d fix(kiro):修复 base64 图片格式转换问题 (#10) 2025-12-07 21:55:13 +08:00
Your Name
2bf9e08b31 style(kiro): convert Chinese comments to English in base64 image handling 2025-12-07 21:51:24 +08:00
Ravens2121
f56bfaa689 Merge branch 'router-for-me:main' into main 2025-12-07 21:35:16 +08:00
Your Name
5d716dc796 fix(kiro): 修复 base64 图片格式转换问题 2025-12-07 21:34:44 +08:00
Luis Pater
f81ff16022 Merge pull request #8 from Ravens2121/main
feat: 添加Kiro渠道图片支持功能,借鉴justlovemaki/AIClient-2-API实现
2025-12-07 20:18:31 +08:00
Luis Pater
6cf1d8a947 Merge pull request #444 from router-for-me/agry
feat(registry): add explicit thinking support config for antigravity models
2025-12-07 19:38:43 +08:00
hkfires
a174d015f2 feat(openai): handle thinking.budget_tokens from Anthropic-style requests 2025-12-07 19:14:05 +08:00
hkfires
9c09128e00 feat(registry): add explicit thinking support config for antigravity models 2025-12-07 19:12:55 +08:00
Your Name
68cbe20664 feat: 添加Kiro渠道图片支持功能,借鉴justlovemaki/AIClient-2-API实现 2025-12-07 17:56:06 +08:00
huynguyen03.dev
549c0c2c5a fix: filter whitespace-only text content in Claude to OpenAI translation
Remove redundant existence check since TrimSpace handles empty strings
2025-12-07 16:08:12 +07:00
huynguyen03.dev
f092801b61 fix: filter whitespace-only text in Claude to OpenAI translation
Skip text content blocks that are empty or contain only whitespace
when translating Claude messages to OpenAI format. This fixes GLM-4.6
and other strict OpenAI-compatible providers that reject empty text
with error 'text cannot be empty'.
2025-12-07 15:39:58 +07:00
Luis Pater
15353a6b6a Merge branch 'router-for-me:main' into main 2025-12-07 13:37:43 +08:00
Luis Pater
1b638b3629 Merge pull request #432 from huynguyen03dev/fix/amp-gemini-model-mapping
fix(amp): pass mapped model to gemini bridge via context
2025-12-07 13:33:28 +08:00
Luis Pater
6f5f81753d Merge pull request #439 from router-for-me/log
feat(logging): add version info to request log output
2025-12-07 13:31:06 +08:00
Luis Pater
76af454034 **feat(antigravity): enhance handling of "thinking" content and refine Claude model response processing** 2025-12-07 13:19:12 +08:00
hkfires
e54d2f6b2a feat(logging): add version info to request log output 2025-12-07 12:49:14 +08:00
huynguyen03.dev
bfc738b76a refactor: remove duplicate provider check in gemini v1beta1 route
Simplifies routing logic by delegating all provider/mapping/proxy
decisions to FallbackHandler. Previously, the route checked for
provider/mapping availability before calling the handler, then
FallbackHandler performed the same checks again.

Changes:
- Remove model extraction and provider checking from route (lines 182-201)
- Route now only checks if request is POST with /models/ path
- FallbackHandler handles provider -> mapping -> proxy fallback
- Remove unused internal/util import

Benefits:
- Eliminates duplicate checks (addresses PR review feedback #2)
- Centralizes all provider/mapping logic in FallbackHandler
- Reduces routing code by ~20 lines
- Aligns with how other /api/provider routes work

Performance: No impact (checks still happen once in FallbackHandler)
2025-12-07 10:54:58 +07:00
huynguyen03.dev
396899a530 refactor: improve gemini bridge testability and code quality
- Change createGeminiBridgeHandler to accept gin.HandlerFunc instead of *gemini.GeminiAPIHandler
  This allows tests to inject mock handlers instead of duplicating bridge logic
- Replace magic number 8 with len(modelsPrefix) for better maintainability
- Remove redundant test case that doesn't test edge case in production
- Update routes.go to pass geminiHandlers.GeminiHandler directly

Addresses PR review feedback on test architecture and code clarity.

Amp-Thread-ID: https://ampcode.com/threads/T-1ae2c691-e434-4b99-a49a-10cabd3544db
2025-12-07 10:15:42 +07:00
Luis Pater
04f0070a80 Merge branch 'router-for-me:main' into main 2025-12-07 02:38:32 +08:00
Luis Pater
f383840cf9 fix(antigravity): update toolNode role from "tool" to "user" in chat completions 2025-12-07 02:37:46 +08:00
Luis Pater
239fc4a8c4 Merge branch 'router-for-me:main' into main 2025-12-07 01:58:03 +08:00
Luis Pater
fd29ab418a Fixed: #424
**feat(antigravity): add support for maxOutputTokens and refine Claude model handling**
2025-12-07 01:55:57 +08:00
Luis Pater
df91408919 Merge branch 'router-for-me:main' into main 2025-12-07 01:49:20 +08:00
Luis Pater
7a628426dc Fixed: #433
refactor(translator): normalize finish reason casing across all OpenAI response handlers
2025-12-07 01:48:24 +08:00
Luis Pater
aa6c7facab Merge pull request #5 from router-for-me/plus
Sync
2025-12-07 01:23:53 +08:00
Luis Pater
8ba4c7c7ed Merge branch 'main' into plus 2025-12-07 01:23:36 +08:00
Luis Pater
56b4d7a76e docs(readme): add ProxyPal CLIProxyAPI GUI to project list 2025-12-07 01:13:30 +08:00
Luis Pater
b211c3546d Merge pull request #429 from heyhuynhgiabuu/feature/add-proxypal
docs: add ProxyPal to 'Who is with us?' section
2025-12-07 01:10:44 +08:00
huynguyen03.dev
edc654edf9 refactor: simplify provider check logic in amp routes
Amp-Thread-ID: https://ampcode.com/threads/T-a18fd71c-32ce-4c29-93d7-09f082740e51
2025-12-06 22:07:40 +07:00
huynguyen03.dev
08586334af fix(amp): pass mapped model to gemini bridge via context
Gemini handler extracts model from URL path, not JSON body, so
rewriting the request body alone wasn't sufficient for model mapping.

- Add MappedModelContextKey constant for context passing
- Update routes.go to use NewFallbackHandlerWithMapper
- Add check for valid mapping before routing to local handler
- Add tests for gemini bridge model mapping
2025-12-06 18:59:44 +07:00
Luis Pater
a4804b358f Merge branch 'router-for-me:main' into main 2025-12-06 15:15:27 +08:00
Luis Pater
7ea14479fb Merge pull request #428 from router-for-me/amp
feat(amp): add response rewriter for model name substitution in responses
2025-12-06 15:14:10 +08:00
hkfires
54af96d321 fix(amp): restore request body before fallback handler execution 2025-12-06 15:09:25 +08:00
hkfires
22579155c5 refactor(amp): consolidate and simplify model mapping debug logs 2025-12-06 14:54:38 +08:00
Huynh Gia Buu
c04c3832a4 Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-06 13:48:08 +07:00
huynhgiabuu
5ffbd54755 docs: add ProxyPal to 'Who is with us?' section 2025-12-06 13:45:49 +07:00
hkfires
5d12d4ce33 feat(amp): add response rewriter for model name substitution in responses 2025-12-06 14:15:44 +08:00
Luis Pater
b73e53d6c4 docs: update README to fix formatting for GitHub Copilot and Kiro support sections 2025-12-06 12:14:23 +08:00
Luis Pater
b06463c6d9 docs: update README to include Kiro (AWS CodeWhisperer) support integration details 2025-12-06 12:04:36 +08:00
Luis Pater
5eb8453e91 Merge pull request #3 from fuko2935/feature/kiro-integration
Feature/kiro integration
2025-12-06 12:00:48 +08:00
Luis Pater
f77c22e6ff Merge branch 'main' into feature/kiro-integration 2025-12-06 11:52:59 +08:00
Mansi
df83ba877f refactor: replace custom stringContains with strings.Contains
- Remove custom stringContains and findSubstring helper functions
- Use standard library strings.Contains for better maintainability
- No functional change, just cleaner code

Addresses Gemini Code Assist review feedback
2025-12-05 23:16:48 +03:00
Mansi
9583f6b1c5 refactor: extract setKiroIncognitoMode helper to reduce code duplication
- Added setKiroIncognitoMode() helper function to handle Kiro auth incognito mode setting
- Replaced 3 duplicate code blocks (21 lines) with single function calls (3 lines)
- Kiro auth defaults to incognito mode for multi-account support
- Users can override with --incognito or --no-incognito flags

This addresses the code duplication noted in PR #1 review.
2025-12-05 23:06:07 +03:00
Mansi
02d8a1cfec feat(kiro): add AWS Builder ID authentication support
- Add --kiro-aws-login flag for AWS Builder ID device code flow
- Add DoKiroAWSLogin function for AWS SSO OIDC authentication
- Complete Kiro integration with AWS, Google OAuth, and social auth
- Add kiro executor, translator, and SDK components
- Update browser support for Kiro authentication flows
2025-12-05 22:46:24 +03:00
Luis Pater
92f033dec0 Merge branch 'router-for-me:main' into main 2025-12-06 01:33:34 +08:00
Luis Pater
0ebabf5152 feat(antigravity): add FetchAntigravityProjectID function and integrate project ID retrieval 2025-12-06 01:32:12 +08:00
Luis Pater
4b01ecba2e Merge branch 'router-for-me:main' into main 2025-12-06 01:12:43 +08:00
Luis Pater
d7564173dd fix(antigravity): restore production base URL in the executor 2025-12-06 01:11:37 +08:00
Luis Pater
f241124599 Merge branch 'router-for-me:main' into main 2025-12-06 00:43:02 +08:00
Luis Pater
c44c46dd80 Fixed: #421
feat(antigravity): implement project ID retrieval and integration in payload processing
2025-12-06 00:40:55 +08:00
86 changed files with 9851 additions and 533 deletions

2
.gitignore vendored
View File

@@ -1,5 +1,6 @@
# Binaries
cli-proxy-api
cliproxy
*.exe
# Configuration
@@ -31,6 +32,7 @@ GEMINI.md
.vscode/*
.claude/*
.serena/*
.mcp/cache/
# macOS
.DS_Store

View File

@@ -10,7 +10,8 @@ The Plus release stays in lockstep with the mainline features.
## Differences from the Mainline
- Added GitHub Copilot support (OAuth login), provided by [em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)
## Contributing

View File

@@ -10,7 +10,8 @@
## 与主线版本版本差异
- 新增 GitHub Copilot 支持OAuth 登录),由[em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
- 新增 GitHub Copilot 支持OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)提供
## 贡献

View File

@@ -47,6 +47,19 @@ func init() {
buildinfo.BuildDate = BuildDate
}
// setKiroIncognitoMode sets the incognito browser mode for Kiro authentication.
// Kiro defaults to incognito mode for multi-account support.
// Users can explicitly override with --incognito or --no-incognito flags.
func setKiroIncognitoMode(cfg *config.Config, useIncognito, noIncognito bool) {
if useIncognito {
cfg.IncognitoBrowser = true
} else if noIncognito {
cfg.IncognitoBrowser = false
} else {
cfg.IncognitoBrowser = true // Kiro default
}
}
// main is the entry point of the application.
// It parses command-line flags, loads configuration, and starts the appropriate
// service based on the provided flags (login, codex-login, or server mode).
@@ -62,11 +75,17 @@ func main() {
var iflowCookie bool
var noBrowser bool
var antigravityLogin bool
var kiroLogin bool
var kiroGoogleLogin bool
var kiroAWSLogin bool
var kiroImport bool
var githubCopilotLogin bool
var projectID string
var vertexImport string
var configPath string
var password string
var noIncognito bool
var useIncognito bool
// Define command-line flags for different operation modes.
flag.BoolVar(&login, "login", false, "Login Google Account")
@@ -76,7 +95,13 @@ func main() {
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
@@ -141,7 +166,8 @@ func main() {
wd, err := os.Getwd()
if err != nil {
log.Fatalf("failed to get working directory: %v", err)
log.Errorf("failed to get working directory: %v", err)
return
}
// Load environment variables from .env if present.
@@ -235,13 +261,15 @@ func main() {
})
cancel()
if err != nil {
log.Fatalf("failed to initialize postgres token store: %v", err)
log.Errorf("failed to initialize postgres token store: %v", err)
return
}
examplePath := filepath.Join(wd, "config.example.yaml")
ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
if errBootstrap := pgStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
cancel()
log.Fatalf("failed to bootstrap postgres-backed config: %v", errBootstrap)
log.Errorf("failed to bootstrap postgres-backed config: %v", errBootstrap)
return
}
cancel()
configFilePath = pgStoreInst.ConfigPath()
@@ -264,7 +292,8 @@ func main() {
if strings.Contains(resolvedEndpoint, "://") {
parsed, errParse := url.Parse(resolvedEndpoint)
if errParse != nil {
log.Fatalf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse)
log.Errorf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse)
return
}
switch strings.ToLower(parsed.Scheme) {
case "http":
@@ -272,10 +301,12 @@ func main() {
case "https":
useSSL = true
default:
log.Fatalf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme)
log.Errorf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme)
return
}
if parsed.Host == "" {
log.Fatalf("object store endpoint %q is missing host information", objectStoreEndpoint)
log.Errorf("object store endpoint %q is missing host information", objectStoreEndpoint)
return
}
resolvedEndpoint = parsed.Host
if parsed.Path != "" && parsed.Path != "/" {
@@ -294,13 +325,15 @@ func main() {
}
objectStoreInst, err = store.NewObjectTokenStore(objCfg)
if err != nil {
log.Fatalf("failed to initialize object token store: %v", err)
log.Errorf("failed to initialize object token store: %v", err)
return
}
examplePath := filepath.Join(wd, "config.example.yaml")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
if errBootstrap := objectStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
cancel()
log.Fatalf("failed to bootstrap object-backed config: %v", errBootstrap)
log.Errorf("failed to bootstrap object-backed config: %v", errBootstrap)
return
}
cancel()
configFilePath = objectStoreInst.ConfigPath()
@@ -325,7 +358,8 @@ func main() {
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
gitStoreInst.SetBaseDir(authDir)
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
log.Fatalf("failed to prepare git token store: %v", errRepo)
log.Errorf("failed to prepare git token store: %v", errRepo)
return
}
configFilePath = gitStoreInst.ConfigPath()
if configFilePath == "" {
@@ -334,17 +368,21 @@ func main() {
if _, statErr := os.Stat(configFilePath); errors.Is(statErr, fs.ErrNotExist) {
examplePath := filepath.Join(wd, "config.example.yaml")
if _, errExample := os.Stat(examplePath); errExample != nil {
log.Fatalf("failed to find template config file: %v", errExample)
log.Errorf("failed to find template config file: %v", errExample)
return
}
if errCopy := misc.CopyConfigTemplate(examplePath, configFilePath); errCopy != nil {
log.Fatalf("failed to bootstrap git-backed config: %v", errCopy)
log.Errorf("failed to bootstrap git-backed config: %v", errCopy)
return
}
if errCommit := gitStoreInst.PersistConfig(context.Background()); errCommit != nil {
log.Fatalf("failed to commit initial git-backed config: %v", errCommit)
log.Errorf("failed to commit initial git-backed config: %v", errCommit)
return
}
log.Infof("git-backed config initialized from template: %s", configFilePath)
} else if statErr != nil {
log.Fatalf("failed to inspect git-backed config: %v", statErr)
log.Errorf("failed to inspect git-backed config: %v", statErr)
return
}
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
if err == nil {
@@ -357,13 +395,15 @@ func main() {
} else {
wd, err = os.Getwd()
if err != nil {
log.Fatalf("failed to get working directory: %v", err)
log.Errorf("failed to get working directory: %v", err)
return
}
configFilePath = filepath.Join(wd, "config.yaml")
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
}
if err != nil {
log.Fatalf("failed to load config: %v", err)
log.Errorf("failed to load config: %v", err)
return
}
if cfg == nil {
cfg = &config.Config{}
@@ -393,7 +433,8 @@ func main() {
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
log.Fatalf("failed to configure log output: %v", err)
log.Errorf("failed to configure log output: %v", err)
return
}
log.Infof("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
@@ -402,7 +443,8 @@ func main() {
util.SetLogLevel(cfg)
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
log.Fatalf("failed to resolve auth directory: %v", errResolveAuthDir)
log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
return
} else {
cfg.AuthDir = resolvedAuthDir
}
@@ -453,6 +495,26 @@ func main() {
cmd.DoIFlowLogin(cfg, options)
} else if iflowCookie {
cmd.DoIFlowCookieAuth(cfg, options)
} else if kiroLogin {
// For Kiro auth, default to incognito mode for multi-account support
// Users can explicitly override with --no-incognito
// Note: This config mutation is safe - auth commands exit after completion
// and don't share config with StartService (which is in the else branch)
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
cmd.DoKiroLogin(cfg, options)
} else if kiroGoogleLogin {
// For Kiro auth, default to incognito mode for multi-account support
// Users can explicitly override with --no-incognito
// Note: This config mutation is safe - auth commands exit after completion
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
cmd.DoKiroGoogleLogin(cfg, options)
} else if kiroAWSLogin {
// For Kiro auth, default to incognito mode for multi-account support
// Users can explicitly override with --no-incognito
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
cmd.DoKiroAWSLogin(cfg, options)
} else if kiroImport {
cmd.DoKiroImport(cfg, options)
} else {
// In cloud deploy mode without config file, just wait for shutdown signals
if isCloudDeploy && !configFileExists {

View File

@@ -1,3 +1,7 @@
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
host: ""
# Server port
port: 8317
@@ -32,6 +36,11 @@ api-keys:
# Enable debug logging
debug: false
# Open OAuth URLs in incognito/private browser mode.
# Useful when you want to login with a different account without logging out from your current session.
# Default: false (but Kiro auth defaults to true for multi-account support)
incognito-browser: true
# When true, write application logs to rotating files instead of stdout
logging-to-file: false
@@ -99,6 +108,16 @@ ws-auth: false
# - "*-think" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
# Kiro (AWS CodeWhisperer) configuration
# Note: Kiro API currently only operates in us-east-1 region
#kiro:
# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file
# agent-task-type: "" # optional: "vibe" or empty (API default)
# - access-token: "aoaAAAAA..." # or provide tokens directly
# refresh-token: "aorAAAAA..."
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override
# OpenAI compatibility providers
# openai-compatibility:
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
@@ -134,6 +153,8 @@ ws-auth: false
# upstream-api-key: ""
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended)
# restrict-management-to-localhost: true
# # Force model mappings to run before checking local API keys (default: false)
# force-model-mappings: false
# # Amp Model Mappings
# # Route unavailable Amp models to alternative models available in your local proxy.
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)

3
go.mod
View File

@@ -13,14 +13,15 @@ require (
github.com/joho/godotenv v1.5.1
github.com/klauspost/compress v1.17.4
github.com/minio/minio-go/v7 v7.0.66
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/sirupsen/logrus v1.9.3
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/tiktoken-go/tokenizer v0.7.0
golang.org/x/crypto v0.43.0
golang.org/x/net v0.46.0
golang.org/x/oauth2 v0.30.0
golang.org/x/term v0.36.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
)

5
go.sum
View File

@@ -116,6 +116,8 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
@@ -126,8 +128,6 @@ github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -169,6 +169,7 @@ golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKl
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=

View File

@@ -36,9 +36,32 @@ import (
)
var (
oauthStatus = make(map[string]string)
oauthStatus = make(map[string]string)
oauthStatusMutex sync.RWMutex
)
// getOAuthStatus safely retrieves an OAuth status
func getOAuthStatus(key string) (string, bool) {
oauthStatusMutex.RLock()
defer oauthStatusMutex.RUnlock()
status, ok := oauthStatus[key]
return status, ok
}
// setOAuthStatus safely sets an OAuth status
func setOAuthStatus(key string, status string) {
oauthStatusMutex.Lock()
defer oauthStatusMutex.Unlock()
oauthStatus[key] = status
}
// deleteOAuthStatus safely deletes an OAuth status
func deleteOAuthStatus(key string) {
oauthStatusMutex.Lock()
defer oauthStatusMutex.Unlock()
delete(oauthStatus, key)
}
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
const (
@@ -713,14 +736,16 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
// Generate PKCE codes
pkceCodes, err := claude.GeneratePKCECodes()
if err != nil {
log.Fatalf("Failed to generate PKCE codes: %v", err)
log.Errorf("Failed to generate PKCE codes: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
return
}
// Generate random state parameter
state, err := misc.GenerateRandomState()
if err != nil {
log.Fatalf("Failed to generate state parameter: %v", err)
log.Errorf("Failed to generate state parameter: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
return
}
@@ -730,7 +755,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
// Generate authorization URL (then override redirect_uri to reuse server port)
authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes)
if err != nil {
log.Fatalf("Failed to generate authorization URL: %v", err)
log.Errorf("Failed to generate authorization URL: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
return
}
@@ -760,7 +786,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
deadline := time.Now().Add(timeout)
for {
if time.Now().After(deadline) {
oauthStatus[state] = "Timeout waiting for OAuth callback"
setOAuthStatus(state, "Timeout waiting for OAuth callback")
return nil, fmt.Errorf("timeout waiting for OAuth callback")
}
data, errRead := os.ReadFile(path)
@@ -785,13 +811,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
if errStr := resultMap["error"]; errStr != "" {
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
log.Error(claude.GetUserFriendlyMessage(oauthErr))
oauthStatus[state] = "Bad request"
setOAuthStatus(state, "Bad request")
return
}
if resultMap["state"] != state {
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
log.Error(claude.GetUserFriendlyMessage(authErr))
oauthStatus[state] = "State code error"
setOAuthStatus(state, "State code error")
return
}
@@ -824,7 +850,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
if errDo != nil {
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
oauthStatus[state] = "Failed to exchange authorization code for tokens"
setOAuthStatus(state, "Failed to exchange authorization code for tokens")
return
}
defer func() {
@@ -835,7 +861,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)
setOAuthStatus(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
return
}
var tResp struct {
@@ -848,7 +874,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
}
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
log.Errorf("failed to parse token response: %v", errU)
oauthStatus[state] = "Failed to parse token response"
setOAuthStatus(state, "Failed to parse token response")
return
}
bundle := &claude.ClaudeAuthBundle{
@@ -872,8 +898,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Fatalf("Failed to save authentication tokens: %v", errSave)
oauthStatus[state] = "Failed to save authentication tokens"
log.Errorf("Failed to save authentication tokens: %v", errSave)
setOAuthStatus(state, "Failed to save authentication tokens")
return
}
@@ -882,10 +908,10 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
fmt.Println("API key obtained and saved")
}
fmt.Println("You can now use Claude services through this CLI")
delete(oauthStatus, state)
deleteOAuthStatus(state)
}()
oauthStatus[state] = ""
setOAuthStatus(state, "")
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
@@ -944,7 +970,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
for {
if time.Now().After(deadline) {
log.Error("oauth flow timed out")
oauthStatus[state] = "OAuth flow timed out"
setOAuthStatus(state, "OAuth flow timed out")
return
}
if data, errR := os.ReadFile(waitFile); errR == nil {
@@ -953,13 +979,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
_ = os.Remove(waitFile)
if errStr := m["error"]; errStr != "" {
log.Errorf("Authentication failed: %s", errStr)
oauthStatus[state] = "Authentication failed"
setOAuthStatus(state, "Authentication failed")
return
}
authCode = m["code"]
if authCode == "" {
log.Errorf("Authentication failed: code not found")
oauthStatus[state] = "Authentication failed: code not found"
setOAuthStatus(state, "Authentication failed: code not found")
return
}
break
@@ -971,7 +997,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
token, err := conf.Exchange(ctx, authCode)
if err != nil {
log.Errorf("Failed to exchange token: %v", err)
oauthStatus[state] = "Failed to exchange token"
setOAuthStatus(state, "Failed to exchange token")
return
}
@@ -982,7 +1008,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
if errNewRequest != nil {
log.Errorf("Could not get user info: %v", errNewRequest)
oauthStatus[state] = "Could not get user info"
setOAuthStatus(state, "Could not get user info")
return
}
req.Header.Set("Content-Type", "application/json")
@@ -991,7 +1017,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
resp, errDo := authHTTPClient.Do(req)
if errDo != nil {
log.Errorf("Failed to execute request: %v", errDo)
oauthStatus[state] = "Failed to execute request"
setOAuthStatus(state, "Failed to execute request")
return
}
defer func() {
@@ -1003,7 +1029,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
bodyBytes, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)
setOAuthStatus(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode))
return
}
@@ -1012,7 +1038,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
fmt.Printf("Authenticated user email: %s\n", email)
} else {
fmt.Println("Failed to get user email from token")
oauthStatus[state] = "Failed to get user email from token"
setOAuthStatus(state, "Failed to get user email from token")
}
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
@@ -1020,7 +1046,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
jsonData, _ := json.Marshal(token)
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
oauthStatus[state] = "Failed to unmarshal token"
setOAuthStatus(state, "Failed to unmarshal token")
return
}
@@ -1045,8 +1071,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
gemAuth := geminiAuth.NewGeminiAuth()
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
if errGetClient != nil {
log.Fatalf("failed to get authenticated client: %v", errGetClient)
oauthStatus[state] = "Failed to get authenticated client"
log.Errorf("failed to get authenticated client: %v", errGetClient)
setOAuthStatus(state, "Failed to get authenticated client")
return
}
fmt.Println("Authentication successful.")
@@ -1056,12 +1082,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
if errAll != nil {
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
setOAuthStatus(state, "Failed to complete Gemini CLI onboarding")
return
}
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
oauthStatus[state] = "Failed to verify Cloud AI API status"
setOAuthStatus(state, "Failed to verify Cloud AI API status")
return
}
ts.ProjectID = strings.Join(projects, ",")
@@ -1069,26 +1095,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
} else {
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
setOAuthStatus(state, "Failed to complete Gemini CLI onboarding")
return
}
if strings.TrimSpace(ts.ProjectID) == "" {
log.Error("Onboarding did not return a project ID")
oauthStatus[state] = "Failed to resolve project ID"
setOAuthStatus(state, "Failed to resolve project ID")
return
}
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
if errCheck != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
oauthStatus[state] = "Failed to verify Cloud AI API status"
setOAuthStatus(state, "Failed to verify Cloud AI API status")
return
}
ts.Checked = isChecked
if !isChecked {
log.Error("Cloud AI API is not enabled for the selected project")
oauthStatus[state] = "Cloud AI API not enabled"
setOAuthStatus(state, "Cloud AI API not enabled")
return
}
}
@@ -1110,16 +1136,16 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Fatalf("Failed to save token to file: %v", errSave)
oauthStatus[state] = "Failed to save token to file"
log.Errorf("Failed to save token to file: %v", errSave)
setOAuthStatus(state, "Failed to save token to file")
return
}
delete(oauthStatus, state)
deleteOAuthStatus(state)
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
}()
oauthStatus[state] = ""
setOAuthStatus(state, "")
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
@@ -1131,14 +1157,16 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
// Generate PKCE codes
pkceCodes, err := codex.GeneratePKCECodes()
if err != nil {
log.Fatalf("Failed to generate PKCE codes: %v", err)
log.Errorf("Failed to generate PKCE codes: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
return
}
// Generate random state parameter
state, err := misc.GenerateRandomState()
if err != nil {
log.Fatalf("Failed to generate state parameter: %v", err)
log.Errorf("Failed to generate state parameter: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
return
}
@@ -1148,7 +1176,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
// Generate authorization URL
authURL, err := openaiAuth.GenerateAuthURL(state, pkceCodes)
if err != nil {
log.Fatalf("Failed to generate authorization URL: %v", err)
log.Errorf("Failed to generate authorization URL: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
return
}
@@ -1180,7 +1209,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
if time.Now().After(deadline) {
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
log.Error(codex.GetUserFriendlyMessage(authErr))
oauthStatus[state] = "Timeout waiting for OAuth callback"
setOAuthStatus(state, "Timeout waiting for OAuth callback")
return
}
if data, errR := os.ReadFile(waitFile); errR == nil {
@@ -1190,12 +1219,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
if errStr := m["error"]; errStr != "" {
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
log.Error(codex.GetUserFriendlyMessage(oauthErr))
oauthStatus[state] = "Bad Request"
setOAuthStatus(state, "Bad Request")
return
}
if m["state"] != state {
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
oauthStatus[state] = "State code error"
setOAuthStatus(state, "State code error")
log.Error(codex.GetUserFriendlyMessage(authErr))
return
}
@@ -1226,14 +1255,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
resp, errDo := httpClient.Do(req)
if errDo != nil {
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
oauthStatus[state] = "Failed to exchange authorization code for tokens"
setOAuthStatus(state, "Failed to exchange authorization code for tokens")
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
return
}
defer func() { _ = resp.Body.Close() }()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)
setOAuthStatus(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
return
}
@@ -1244,7 +1273,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
ExpiresIn int `json:"expires_in"`
}
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
oauthStatus[state] = "Failed to parse token response"
setOAuthStatus(state, "Failed to parse token response")
log.Errorf("failed to parse token response: %v", errU)
return
}
@@ -1282,8 +1311,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
oauthStatus[state] = "Failed to save authentication tokens"
log.Fatalf("Failed to save authentication tokens: %v", errSave)
log.Errorf("Failed to save authentication tokens: %v", errSave)
setOAuthStatus(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
@@ -1291,10 +1320,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
fmt.Println("API key obtained and saved")
}
fmt.Println("You can now use Codex services through this CLI")
delete(oauthStatus, state)
deleteOAuthStatus(state)
}()
oauthStatus[state] = ""
setOAuthStatus(state, "")
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
@@ -1318,7 +1347,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
state, errState := misc.GenerateRandomState()
if errState != nil {
log.Fatalf("Failed to generate state parameter: %v", errState)
log.Errorf("Failed to generate state parameter: %v", errState)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
return
}
@@ -1360,7 +1390,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
for {
if time.Now().After(deadline) {
log.Error("oauth flow timed out")
oauthStatus[state] = "OAuth flow timed out"
setOAuthStatus(state, "OAuth flow timed out")
return
}
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
@@ -1369,18 +1399,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
_ = os.Remove(waitFile)
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
log.Errorf("Authentication failed: %s", errStr)
oauthStatus[state] = "Authentication failed"
setOAuthStatus(state, "Authentication failed")
return
}
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
log.Errorf("Authentication failed: state mismatch")
oauthStatus[state] = "Authentication failed: state mismatch"
setOAuthStatus(state, "Authentication failed: state mismatch")
return
}
authCode = strings.TrimSpace(payload["code"])
if authCode == "" {
log.Error("Authentication failed: code not found")
oauthStatus[state] = "Authentication failed: code not found"
setOAuthStatus(state, "Authentication failed: code not found")
return
}
break
@@ -1399,7 +1429,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
if errNewRequest != nil {
log.Errorf("Failed to build token request: %v", errNewRequest)
oauthStatus[state] = "Failed to build token request"
setOAuthStatus(state, "Failed to build token request")
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -1407,7 +1437,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
resp, errDo := httpClient.Do(req)
if errDo != nil {
log.Errorf("Failed to execute token request: %v", errDo)
oauthStatus[state] = "Failed to exchange token"
setOAuthStatus(state, "Failed to exchange token")
return
}
defer func() {
@@ -1419,7 +1449,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, _ := io.ReadAll(resp.Body)
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)
setOAuthStatus(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
return
}
@@ -1431,7 +1461,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
}
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
log.Errorf("Failed to parse token response: %v", errDecode)
oauthStatus[state] = "Failed to parse token response"
setOAuthStatus(state, "Failed to parse token response")
return
}
@@ -1440,7 +1470,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
if errInfoReq != nil {
log.Errorf("Failed to build user info request: %v", errInfoReq)
oauthStatus[state] = "Failed to build user info request"
setOAuthStatus(state, "Failed to build user info request")
return
}
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
@@ -1448,7 +1478,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
infoResp, errInfo := httpClient.Do(infoReq)
if errInfo != nil {
log.Errorf("Failed to execute user info request: %v", errInfo)
oauthStatus[state] = "Failed to execute user info request"
setOAuthStatus(state, "Failed to execute user info request")
return
}
defer func() {
@@ -1467,11 +1497,22 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
} else {
bodyBytes, _ := io.ReadAll(infoResp.Body)
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)
setOAuthStatus(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
return
}
}
projectID := ""
if strings.TrimSpace(tokenResp.AccessToken) != "" {
fetchedProjectID, errProject := sdkAuth.FetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
if errProject != nil {
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
} else {
projectID = fetchedProjectID
log.Infof("antigravity: obtained project ID %s", projectID)
}
}
now := time.Now()
metadata := map[string]any{
"type": "antigravity",
@@ -1484,6 +1525,9 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
if email != "" {
metadata["email"] = email
}
if projectID != "" {
metadata["project_id"] = projectID
}
fileName := sanitizeAntigravityFileName(email)
label := strings.TrimSpace(email)
@@ -1500,17 +1544,20 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Fatalf("Failed to save token to file: %v", errSave)
oauthStatus[state] = "Failed to save token to file"
log.Errorf("Failed to save token to file: %v", errSave)
setOAuthStatus(state, "Failed to save token to file")
return
}
delete(oauthStatus, state)
deleteOAuthStatus(state)
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
if projectID != "" {
fmt.Printf("Using GCP project: %s\n", projectID)
}
fmt.Println("You can now use Antigravity services through this CLI")
}()
oauthStatus[state] = ""
setOAuthStatus(state, "")
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
@@ -1526,7 +1573,8 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
// Generate authorization URL
deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx)
if err != nil {
log.Fatalf("Failed to generate authorization URL: %v", err)
log.Errorf("Failed to generate authorization URL: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
return
}
authURL := deviceFlow.VerificationURIComplete
@@ -1535,7 +1583,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
fmt.Println("Waiting for authentication...")
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
if errPollForToken != nil {
oauthStatus[state] = "Authentication failed"
setOAuthStatus(state, "Authentication failed")
fmt.Printf("Authentication failed: %v\n", errPollForToken)
return
}
@@ -1553,17 +1601,17 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Fatalf("Failed to save authentication tokens: %v", errSave)
oauthStatus[state] = "Failed to save authentication tokens"
log.Errorf("Failed to save authentication tokens: %v", errSave)
setOAuthStatus(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
fmt.Println("You can now use Qwen services through this CLI")
delete(oauthStatus, state)
deleteOAuthStatus(state)
}()
oauthStatus[state] = ""
setOAuthStatus(state, "")
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
@@ -1602,7 +1650,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
var resultMap map[string]string
for {
if time.Now().After(deadline) {
oauthStatus[state] = "Authentication failed"
setOAuthStatus(state, "Authentication failed")
fmt.Println("Authentication failed: timeout waiting for callback")
return
}
@@ -1615,26 +1663,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
}
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
oauthStatus[state] = "Authentication failed"
setOAuthStatus(state, "Authentication failed")
fmt.Printf("Authentication failed: %s\n", errStr)
return
}
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
oauthStatus[state] = "Authentication failed"
setOAuthStatus(state, "Authentication failed")
fmt.Println("Authentication failed: state mismatch")
return
}
code := strings.TrimSpace(resultMap["code"])
if code == "" {
oauthStatus[state] = "Authentication failed"
setOAuthStatus(state, "Authentication failed")
fmt.Println("Authentication failed: code missing")
return
}
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
if errExchange != nil {
oauthStatus[state] = "Authentication failed"
setOAuthStatus(state, "Authentication failed")
fmt.Printf("Authentication failed: %v\n", errExchange)
return
}
@@ -1656,8 +1704,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
oauthStatus[state] = "Failed to save authentication tokens"
log.Fatalf("Failed to save authentication tokens: %v", errSave)
log.Errorf("Failed to save authentication tokens: %v", errSave)
setOAuthStatus(state, "Failed to save authentication tokens")
return
}
@@ -1666,10 +1714,10 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
fmt.Println("API key obtained and saved")
}
fmt.Println("You can now use iFlow services through this CLI")
delete(oauthStatus, state)
deleteOAuthStatus(state)
}()
oauthStatus[state] = ""
setOAuthStatus(state, "")
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
}
@@ -2086,6 +2134,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
continue
}
}
_ = resp.Body.Close()
return false, fmt.Errorf("project activation required: %s", errMessage)
}
return true, nil
@@ -2093,7 +2142,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
func (h *Handler) GetAuthStatus(c *gin.Context) {
state := c.Query("state")
if err, ok := oauthStatus[state]; ok {
if err, ok := getOAuthStatus(state); ok {
if err != "" {
c.JSON(200, gin.H{"status": "error", "error": err})
} else {
@@ -2103,5 +2152,5 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
} else {
c.JSON(200, gin.H{"status": "ok"})
}
delete(oauthStatus, state)
deleteOAuthStatus(state)
}

View File

@@ -706,3 +706,155 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
}
entry.Models = normalized
}
// GetAmpCode returns the complete ampcode configuration.
func (h *Handler) GetAmpCode(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"ampcode": config.AmpCode{}})
return
}
c.JSON(200, gin.H{"ampcode": h.cfg.AmpCode})
}
// GetAmpUpstreamURL returns the ampcode upstream URL.
func (h *Handler) GetAmpUpstreamURL(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"upstream-url": ""})
return
}
c.JSON(200, gin.H{"upstream-url": h.cfg.AmpCode.UpstreamURL})
}
// PutAmpUpstreamURL updates the ampcode upstream URL.
func (h *Handler) PutAmpUpstreamURL(c *gin.Context) {
h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamURL = strings.TrimSpace(v) })
}
// DeleteAmpUpstreamURL clears the ampcode upstream URL.
func (h *Handler) DeleteAmpUpstreamURL(c *gin.Context) {
h.cfg.AmpCode.UpstreamURL = ""
h.persist(c)
}
// GetAmpUpstreamAPIKey returns the ampcode upstream API key.
func (h *Handler) GetAmpUpstreamAPIKey(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"upstream-api-key": ""})
return
}
c.JSON(200, gin.H{"upstream-api-key": h.cfg.AmpCode.UpstreamAPIKey})
}
// PutAmpUpstreamAPIKey updates the ampcode upstream API key.
func (h *Handler) PutAmpUpstreamAPIKey(c *gin.Context) {
h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamAPIKey = strings.TrimSpace(v) })
}
// DeleteAmpUpstreamAPIKey clears the ampcode upstream API key.
func (h *Handler) DeleteAmpUpstreamAPIKey(c *gin.Context) {
h.cfg.AmpCode.UpstreamAPIKey = ""
h.persist(c)
}
// GetAmpRestrictManagementToLocalhost returns the localhost restriction setting.
func (h *Handler) GetAmpRestrictManagementToLocalhost(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"restrict-management-to-localhost": true})
return
}
c.JSON(200, gin.H{"restrict-management-to-localhost": h.cfg.AmpCode.RestrictManagementToLocalhost})
}
// PutAmpRestrictManagementToLocalhost updates the localhost restriction setting.
func (h *Handler) PutAmpRestrictManagementToLocalhost(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.RestrictManagementToLocalhost = v })
}
// GetAmpModelMappings returns the ampcode model mappings.
func (h *Handler) GetAmpModelMappings(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"model-mappings": []config.AmpModelMapping{}})
return
}
c.JSON(200, gin.H{"model-mappings": h.cfg.AmpCode.ModelMappings})
}
// PutAmpModelMappings replaces all ampcode model mappings.
func (h *Handler) PutAmpModelMappings(c *gin.Context) {
var body struct {
Value []config.AmpModelMapping `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
h.cfg.AmpCode.ModelMappings = body.Value
h.persist(c)
}
// PatchAmpModelMappings adds or updates model mappings.
func (h *Handler) PatchAmpModelMappings(c *gin.Context) {
var body struct {
Value []config.AmpModelMapping `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
existing := make(map[string]int)
for i, m := range h.cfg.AmpCode.ModelMappings {
existing[strings.TrimSpace(m.From)] = i
}
for _, newMapping := range body.Value {
from := strings.TrimSpace(newMapping.From)
if idx, ok := existing[from]; ok {
h.cfg.AmpCode.ModelMappings[idx] = newMapping
} else {
h.cfg.AmpCode.ModelMappings = append(h.cfg.AmpCode.ModelMappings, newMapping)
existing[from] = len(h.cfg.AmpCode.ModelMappings) - 1
}
}
h.persist(c)
}
// DeleteAmpModelMappings removes specified model mappings by "from" field.
func (h *Handler) DeleteAmpModelMappings(c *gin.Context) {
var body struct {
Value []string `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil || len(body.Value) == 0 {
h.cfg.AmpCode.ModelMappings = nil
h.persist(c)
return
}
toRemove := make(map[string]bool)
for _, from := range body.Value {
toRemove[strings.TrimSpace(from)] = true
}
newMappings := make([]config.AmpModelMapping, 0, len(h.cfg.AmpCode.ModelMappings))
for _, m := range h.cfg.AmpCode.ModelMappings {
if !toRemove[strings.TrimSpace(m.From)] {
newMappings = append(newMappings, m)
}
}
h.cfg.AmpCode.ModelMappings = newMappings
h.persist(c)
}
// GetAmpForceModelMappings returns whether model mappings are forced.
func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"force-model-mappings": false})
return
}
c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings})
}
// PutAmpForceModelMappings updates the force model mappings setting.
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
}

View File

@@ -240,16 +240,6 @@ func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) {
Value *bool `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
var m map[string]any
if err2 := c.ShouldBindJSON(&m); err2 == nil {
for _, v := range m {
if b, ok := v.(bool); ok {
set(b)
h.persist(c)
return
}
}
}
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}

View File

@@ -232,7 +232,16 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
w.streamDone = nil
}
// Write API Request and Response to the streaming log before closing
if w.streamWriter != nil {
apiRequest := w.extractAPIRequest(c)
if len(apiRequest) > 0 {
_ = w.streamWriter.WriteAPIRequest(apiRequest)
}
apiResponse := w.extractAPIResponse(c)
if len(apiResponse) > 0 {
_ = w.streamWriter.WriteAPIResponse(apiResponse)
}
if err := w.streamWriter.Close(); err != nil {
w.streamWriter = nil
return err

View File

@@ -100,6 +100,16 @@ func (m *AmpModule) Name() string {
return "amp-routing"
}
// forceModelMappings returns whether model mappings should take precedence over local API keys
func (m *AmpModule) forceModelMappings() bool {
m.configMu.RLock()
defer m.configMu.RUnlock()
if m.lastConfig == nil {
return false
}
return m.lastConfig.ForceModelMappings
}
// Register sets up Amp routes if configured.
// This implements the RouteModuleV2 interface with Context.
// Routes are registered only once via sync.Once for idempotent behavior.

View File

@@ -2,7 +2,6 @@ package amp
import (
"bytes"
"encoding/json"
"io"
"net/http/httputil"
"strings"
@@ -11,6 +10,8 @@ import (
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// AmpRouteType represents the type of routing decision made for an Amp request
@@ -27,6 +28,9 @@ const (
RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
)
// MappedModelContextKey is the Gin context key for passing mapped model names.
const MappedModelContextKey = "mapped_model"
// logAmpRouting logs the routing decision for an Amp request with structured fields
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
fields := log.Fields{
@@ -73,23 +77,29 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
// when the model's provider is not available in CLIProxyAPI
type FallbackHandler struct {
getProxy func() *httputil.ReverseProxy
modelMapper ModelMapper
getProxy func() *httputil.ReverseProxy
modelMapper ModelMapper
forceModelMappings func() bool
}
// NewFallbackHandler creates a new fallback handler wrapper
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
return &FallbackHandler{
getProxy: getProxy,
getProxy: getProxy,
forceModelMappings: func() bool { return false },
}
}
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper) *FallbackHandler {
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
if forceModelMappings == nil {
forceModelMappings = func() bool { return false }
}
return &FallbackHandler{
getProxy: getProxy,
modelMapper: mapper,
getProxy: getProxy,
modelMapper: mapper,
forceModelMappings: forceModelMappings,
}
}
@@ -126,32 +136,65 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
// Normalize model (handles Gemini thinking suffixes)
normalizedModel, _ := util.NormalizeGeminiThinkingModel(modelName)
// Check if we have providers for this model
providers := util.GetProviderName(normalizedModel)
// Track resolved model for logging (may change if mapping is applied)
resolvedModel := normalizedModel
usedMapping := false
var providers []string
if len(providers) == 0 {
// No providers configured - check if we have a model mapping
// Check if model mappings should be forced ahead of local API keys
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
if forceMappings {
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
// This allows users to route Amp requests to their preferred OAuth providers
if fh.modelMapper != nil {
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
// Mapping found - rewrite the model in request body
bodyBytes = rewriteModelInBody(bodyBytes, mappedModel)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
resolvedModel = mappedModel
usedMapping = true
// Get providers for the mapped model
providers = util.GetProviderName(mappedModel)
// Continue to handler with remapped model
goto handleRequest
// Mapping found - check if we have a provider for the mapped model
mappedProviders := util.GetProviderName(mappedModel)
if len(mappedProviders) > 0 {
// Mapping found and provider available - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Store mapped model in context for handlers that check it (like gemini bridge)
c.Set(MappedModelContextKey, mappedModel)
resolvedModel = mappedModel
usedMapping = true
providers = mappedProviders
}
}
}
// No mapping found - check if we have a proxy for fallback
// If no mapping applied, check for local providers
if !usedMapping {
providers = util.GetProviderName(normalizedModel)
}
} else {
// DEFAULT MODE: Check local providers first, then mappings as fallback
providers = util.GetProviderName(normalizedModel)
if len(providers) == 0 {
// No providers configured - check if we have a model mapping
if fh.modelMapper != nil {
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
// Mapping found - check if we have a provider for the mapped model
mappedProviders := util.GetProviderName(mappedModel)
if len(mappedProviders) > 0 {
// Mapping found and provider available - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Store mapped model in context for handlers that check it (like gemini bridge)
c.Set(MappedModelContextKey, mappedModel)
resolvedModel = mappedModel
usedMapping = true
providers = mappedProviders
}
}
}
}
}
// If no providers available, fallback to ampcode.com
if len(providers) == 0 {
proxy := fh.getProxy()
if proxy != nil {
// Log: Forwarding to ampcode.com (uses Amp credits)
@@ -169,8 +212,6 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath)
}
handleRequest:
// Log the routing decision
providerName := ""
if len(providers) > 0 {
@@ -179,59 +220,62 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
if usedMapping {
// Log: Model was mapped to another model
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
rewriter := NewResponseRewriter(c.Writer, normalizedModel)
c.Writer = rewriter
// Filter Anthropic-Beta header only for local handling paths
filterAntropicBetaHeader(c)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c)
rewriter.Flush()
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, normalizedModel)
} else if len(providers) > 0 {
// Log: Using local provider (free)
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
// Filter Anthropic-Beta header only for local handling paths
filterAntropicBetaHeader(c)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c)
} else {
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c)
}
// Providers available or no proxy for fallback, restore body and use normal handler
// Filter Anthropic-Beta header to remove features requiring special subscription
// This is needed when using local providers (bypassing the Amp proxy)
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07")
if filtered != "" {
c.Request.Header.Set("Anthropic-Beta", filtered)
} else {
c.Request.Header.Del("Anthropic-Beta")
}
}
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c)
}
}
// rewriteModelInBody replaces the model name in a JSON request body
func rewriteModelInBody(body []byte, newModel string) []byte {
var payload map[string]interface{}
if err := json.Unmarshal(body, &payload); err != nil {
log.Warnf("amp model mapping: failed to parse body for rewrite: %v", err)
// filterAntropicBetaHeader filters Anthropic-Beta header to remove features requiring special subscription
// This is needed when using local providers (bypassing the Amp proxy)
func filterAntropicBetaHeader(c *gin.Context) {
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
if filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07"); filtered != "" {
c.Request.Header.Set("Anthropic-Beta", filtered)
} else {
c.Request.Header.Del("Anthropic-Beta")
}
}
}
// rewriteModelInRequest replaces the model name in a JSON request body
func rewriteModelInRequest(body []byte, newModel string) []byte {
if !gjson.GetBytes(body, "model").Exists() {
return body
}
if _, exists := payload["model"]; exists {
payload["model"] = newModel
newBody, err := json.Marshal(payload)
if err != nil {
log.Warnf("amp model mapping: failed to marshal rewritten body: %v", err)
return body
}
return newBody
result, err := sjson.SetBytes(body, "model", newModel)
if err != nil {
log.Warnf("amp model mapping: failed to rewrite model in request body: %v", err)
return body
}
return body
return result
}
// extractModelFromRequest attempts to extract the model name from various request formats
func extractModelFromRequest(body []byte, c *gin.Context) string {
// First try to parse from JSON body (OpenAI, Claude, etc.)
var payload map[string]interface{}
if err := json.Unmarshal(body, &payload); err == nil {
// Check common model field names
if model, ok := payload["model"].(string); ok {
return model
}
// Check common model field names
if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String {
return result.String()
}
// For Gemini requests, model is in the URL path

View File

@@ -4,7 +4,6 @@ import (
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
)
// createGeminiBridgeHandler creates a handler that bridges AMP CLI's non-standard Gemini paths
@@ -15,16 +14,31 @@ import (
//
// This extracts the model+method from the AMP path and sets it as the :action parameter
// so the standard Gemini handler can process it.
func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.HandlerFunc {
//
// The handler parameter should be a Gemini-compatible handler that expects the :action param.
func createGeminiBridgeHandler(handler gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
// Get the full path from the catch-all parameter
path := c.Param("path")
// Extract model:method from AMP CLI path format
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
if idx := strings.Index(path, "/models/"); idx >= 0 {
// Extract everything after "/models/"
actionPart := path[idx+8:] // Skip "/models/"
const modelsPrefix = "/models/"
if idx := strings.Index(path, modelsPrefix); idx >= 0 {
// Extract everything after modelsPrefix
actionPart := path[idx+len(modelsPrefix):]
// Check if model was mapped by FallbackHandler
if mappedModel, exists := c.Get(MappedModelContextKey); exists {
if strModel, ok := mappedModel.(string); ok && strModel != "" {
// Replace the model part in the action
// actionPart is like "model-name:method"
if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 {
method := actionPart[colonIdx:] // ":method"
actionPart = strModel + method
}
}
}
// Set this as the :action parameter that the Gemini handler expects
c.Params = append(c.Params, gin.Param{
@@ -32,8 +46,8 @@ func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.Handl
Value: actionPart,
})
// Call the standard Gemini handler
geminiHandler.GeminiHandler(c)
// Call the handler
handler(c)
return
}

View File

@@ -0,0 +1,93 @@
package amp
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
path string
mappedModel string // empty string means no mapping
expectedAction string
}{
{
name: "no_mapping_uses_url_model",
path: "/publishers/google/models/gemini-pro:generateContent",
mappedModel: "",
expectedAction: "gemini-pro:generateContent",
},
{
name: "mapped_model_replaces_url_model",
path: "/publishers/google/models/gemini-exp:generateContent",
mappedModel: "gemini-2.0-flash",
expectedAction: "gemini-2.0-flash:generateContent",
},
{
name: "mapping_preserves_method",
path: "/publishers/google/models/gemini-2.5-preview:streamGenerateContent",
mappedModel: "gemini-flash",
expectedAction: "gemini-flash:streamGenerateContent",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedAction string
mockGeminiHandler := func(c *gin.Context) {
capturedAction = c.Param("action")
c.JSON(http.StatusOK, gin.H{"captured": capturedAction})
}
// Use the actual createGeminiBridgeHandler function
bridgeHandler := createGeminiBridgeHandler(mockGeminiHandler)
r := gin.New()
if tt.mappedModel != "" {
r.Use(func(c *gin.Context) {
c.Set(MappedModelContextKey, tt.mappedModel)
c.Next()
})
}
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1"+tt.path, nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected status 200, got %d", w.Code)
}
if capturedAction != tt.expectedAction {
t.Errorf("Expected action '%s', got '%s'", tt.expectedAction, capturedAction)
}
})
}
}
func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) {
gin.SetMode(gin.TestMode)
mockHandler := func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
}
bridgeHandler := createGeminiBridgeHandler(mockHandler)
r := gin.New()
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/invalid/path", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid path, got %d", w.Code)
}
}

View File

@@ -66,7 +66,6 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
}
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
log.Debugf("amp model mapping: resolved %s -> %s", requestedModel, targetModel)
return targetModel
}

View File

@@ -0,0 +1,98 @@
package amp
import (
"bytes"
"net/http"
"strings"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
// It's used to rewrite model names in responses when model mapping is used
type ResponseRewriter struct {
gin.ResponseWriter
body *bytes.Buffer
originalModel string
isStreaming bool
}
// NewResponseRewriter creates a new response rewriter for model name substitution
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
return &ResponseRewriter{
ResponseWriter: w,
body: &bytes.Buffer{},
originalModel: originalModel,
}
}
// Write intercepts response writes and buffers them for model name replacement
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
// Detect streaming on first write
if rw.body.Len() == 0 && !rw.isStreaming {
contentType := rw.Header().Get("Content-Type")
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
strings.Contains(contentType, "stream")
}
if rw.isStreaming {
return rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
}
return rw.body.Write(data)
}
// Flush writes the buffered response with model names rewritten
func (rw *ResponseRewriter) Flush() {
if rw.isStreaming {
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
return
}
if rw.body.Len() > 0 {
if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
}
}
}
// modelFieldPaths lists all JSON paths where model name may appear
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
if rw.originalModel == "" {
return data
}
for _, path := range modelFieldPaths {
if gjson.GetBytes(data, path).Exists() {
data, _ = sjson.SetBytes(data, path, rw.originalModel)
}
}
return data
}
// rewriteStreamChunk rewrites model names in SSE stream chunks
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
if rw.originalModel == "" {
return chunk
}
// SSE format: "data: {json}\n\n"
lines := bytes.Split(chunk, []byte("\n"))
for i, line := range lines {
if bytes.HasPrefix(line, []byte("data: ")) {
jsonData := bytes.TrimPrefix(line, []byte("data: "))
if len(jsonData) > 0 && jsonData[0] == '{' {
// Rewrite JSON in the data line
rewritten := rw.rewriteModelInResponse(jsonData)
lines[i] = append([]byte("data: "), rewritten...)
}
}
}
return bytes.Join(lines, []byte("\n"))
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
@@ -157,6 +156,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...)
// Root-level auth routes for CLI login flow
// Amp uses multiple auth routes: /auth/cli-login, /auth/callback, /auth/sign-in, /auth/logout
@@ -169,30 +169,22 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
// We bridge these to our standard Gemini handler to enable local OAuth.
// If no local OAuth is available, falls back to ampcode.com proxy.
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
geminiBridge := createGeminiBridgeHandler(geminiHandlers)
geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy {
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.getProxy()
})
}, m.modelMapper, m.forceModelMappings)
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
// Route POST model calls through Gemini bridge when a local provider exists, otherwise proxy.
// Route POST model calls through Gemini bridge with FallbackHandler.
// FallbackHandler checks provider -> mapping -> proxy fallback automatically.
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
if c.Request.Method == "POST" {
// Attempt to extract the model name from the AMP-style path
if path := c.Param("path"); strings.Contains(path, "/models/") {
modelPart := path[strings.Index(path, "/models/")+len("/models/"):]
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
modelPart = modelPart[:colonIdx]
}
if modelPart != "" {
normalized, _ := util.NormalizeGeminiThinkingModel(modelPart)
// Only handle locally when we have a provider; otherwise fall back to proxy
if providers := util.GetProviderName(normalized); len(providers) > 0 {
geminiV1Beta1Handler(c)
return
}
}
// POST with /models/ path -> use Gemini bridge with fallback handler
// FallbackHandler will check provider/mapping and proxy if needed
geminiV1Beta1Handler(c)
return
}
}
// Non-POST or no local provider available -> proxy upstream
@@ -218,7 +210,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
// Also includes model mapping support for routing unavailable models to alternatives
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.getProxy()
}, m.modelMapper)
}, m.modelMapper, m.forceModelMappings)
// Provider-specific routes under /api/provider/:provider
ampProviders := engine.Group("/api/provider")

View File

@@ -300,7 +300,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
// Create HTTP server
s.server = &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Port),
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Handler: engine,
}
@@ -520,6 +520,26 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth)
mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth)
mgmt.GET("/ampcode", s.mgmt.GetAmpCode)
mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL)
mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL)
mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey)
mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey)
mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost)
mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings)
mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings)
mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings)
mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings)
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)
@@ -902,7 +922,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
for _, p := range cfg.OpenAICompatibility {
providerNames = append(providerNames, p.Name)
}
s.handlers.OpenAICompatProviders = providerNames
s.handlers.SetOpenAICompatProviders(providerNames)
s.handlers.UpdateClients(&cfg.SDKConfig)

View File

@@ -242,6 +242,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
platformURL = "https://console.anthropic.com/"
}
// Validate platformURL to prevent XSS - only allow http/https URLs
if !isValidURL(platformURL) {
platformURL = "https://console.anthropic.com/"
}
// Generate success page HTML with dynamic content
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
@@ -251,6 +256,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
}
}
// isValidURL checks if the URL is a valid http/https URL to prevent XSS
func isValidURL(urlStr string) bool {
urlStr = strings.TrimSpace(urlStr)
return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
}
// generateSuccessHTML creates the HTML content for the success page.
// It customizes the page based on whether additional setup is required
// and includes a link to the platform.

View File

@@ -239,6 +239,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
platformURL = "https://platform.openai.com"
}
// Validate platformURL to prevent XSS - only allow http/https URLs
if !isValidURL(platformURL) {
platformURL = "https://platform.openai.com"
}
// Generate success page HTML with dynamic content
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
@@ -248,6 +253,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
}
}
// isValidURL checks if the URL is a valid http/https URL to prevent XSS
func isValidURL(urlStr string) bool {
urlStr = strings.TrimSpace(urlStr)
return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
}
// generateSuccessHTML creates the HTML content for the success page.
// It customizes the page based on whether additional setup is required
// and includes a link to the platform.

View File

@@ -76,7 +76,8 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
auth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
if errSOCKS5 != nil {
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -238,7 +239,11 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Start the server in a goroutine.
go func() {
if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("ListenAndServe(): %v", err)
log.Errorf("ListenAndServe(): %v", err)
select {
case errChan <- err:
default:
}
}
}()

View File

@@ -9,6 +9,7 @@ import (
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
@@ -28,10 +29,21 @@ const (
iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey"
// Client credentials provided by iFlow for the Code Assist integration.
iFlowOAuthClientID = "10009311001"
iFlowOAuthClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
iFlowOAuthClientID = "10009311001"
// Default client secret (can be overridden via IFLOW_CLIENT_SECRET env var)
defaultIFlowClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
)
// getIFlowClientSecret returns the iFlow OAuth client secret.
// It first checks the IFLOW_CLIENT_SECRET environment variable,
// falling back to the default value if not set.
func getIFlowClientSecret() string {
if secret := os.Getenv("IFLOW_CLIENT_SECRET"); secret != "" {
return secret
}
return defaultIFlowClientSecret
}
// DefaultAPIBaseURL is the canonical chat completions endpoint.
const DefaultAPIBaseURL = "https://apis.iflow.cn/v1"
@@ -72,7 +84,7 @@ func (ia *IFlowAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectUR
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
form.Set("client_id", iFlowOAuthClientID)
form.Set("client_secret", iFlowOAuthClientSecret)
form.Set("client_secret", getIFlowClientSecret())
req, err := ia.newTokenRequest(ctx, form)
if err != nil {
@@ -88,7 +100,7 @@ func (ia *IFlowAuth) RefreshTokens(ctx context.Context, refreshToken string) (*I
form.Set("grant_type", "refresh_token")
form.Set("refresh_token", refreshToken)
form.Set("client_id", iFlowOAuthClientID)
form.Set("client_secret", iFlowOAuthClientSecret)
form.Set("client_secret", getIFlowClientSecret())
req, err := ia.newTokenRequest(ctx, form)
if err != nil {
@@ -104,7 +116,7 @@ func (ia *IFlowAuth) newTokenRequest(ctx context.Context, form url.Values) (*htt
return nil, fmt.Errorf("iflow token: create request failed: %w", err)
}
basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + iFlowOAuthClientSecret))
basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + getIFlowClientSecret()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Basic "+basic)
@@ -309,17 +321,23 @@ func (ia *IFlowAuth) AuthenticateWithCookie(ctx context.Context, cookie string)
return nil, fmt.Errorf("iflow cookie authentication: cookie is empty")
}
// First, get initial API key information using GET request
// First, get initial API key information using GET request to obtain the name
keyInfo, err := ia.fetchAPIKeyInfo(ctx, cookie)
if err != nil {
return nil, fmt.Errorf("iflow cookie authentication: fetch initial API key info failed: %w", err)
}
// Convert to token data format
// Refresh the API key using POST request
refreshedKeyInfo, err := ia.RefreshAPIKey(ctx, cookie, keyInfo.Name)
if err != nil {
return nil, fmt.Errorf("iflow cookie authentication: refresh API key failed: %w", err)
}
// Convert to token data format using refreshed key
data := &IFlowTokenData{
APIKey: keyInfo.APIKey,
Expire: keyInfo.ExpireTime,
Email: keyInfo.Name,
APIKey: refreshedKeyInfo.APIKey,
Expire: refreshedKeyInfo.ExpireTime,
Email: refreshedKeyInfo.Name,
Cookie: cookie,
}

301
internal/auth/kiro/aws.go Normal file
View File

@@ -0,0 +1,301 @@
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
// It includes interfaces and implementations for token storage and authentication methods.
package kiro
import (
"encoding/base64"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
)
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
type PKCECodes struct {
// CodeVerifier is the cryptographically random string used to correlate
// the authorization request to the token request
CodeVerifier string `json:"code_verifier"`
// CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
CodeChallenge string `json:"code_challenge"`
}
// KiroTokenData holds OAuth token information from AWS CodeWhisperer (Kiro)
type KiroTokenData struct {
// AccessToken is the OAuth2 access token for API access
AccessToken string `json:"accessToken"`
// RefreshToken is used to obtain new access tokens
RefreshToken string `json:"refreshToken"`
// ProfileArn is the AWS CodeWhisperer profile ARN
ProfileArn string `json:"profileArn"`
// ExpiresAt is the timestamp when the token expires
ExpiresAt string `json:"expiresAt"`
// AuthMethod indicates the authentication method used (e.g., "builder-id", "social")
AuthMethod string `json:"authMethod"`
// Provider indicates the OAuth provider (e.g., "AWS", "Google")
Provider string `json:"provider"`
// ClientID is the OIDC client ID (needed for token refresh)
ClientID string `json:"clientId,omitempty"`
// ClientSecret is the OIDC client secret (needed for token refresh)
ClientSecret string `json:"clientSecret,omitempty"`
// Email is the user's email address (used for file naming)
Email string `json:"email,omitempty"`
}
// KiroAuthBundle aggregates authentication data after OAuth flow completion
type KiroAuthBundle struct {
// TokenData contains the OAuth tokens from the authentication flow
TokenData KiroTokenData `json:"token_data"`
// LastRefresh is the timestamp of the last token refresh
LastRefresh string `json:"last_refresh"`
}
// KiroUsageInfo represents usage information from CodeWhisperer API
type KiroUsageInfo struct {
// SubscriptionTitle is the subscription plan name (e.g., "KIRO FREE")
SubscriptionTitle string `json:"subscription_title"`
// CurrentUsage is the current credit usage
CurrentUsage float64 `json:"current_usage"`
// UsageLimit is the maximum credit limit
UsageLimit float64 `json:"usage_limit"`
// NextReset is the timestamp of the next usage reset
NextReset string `json:"next_reset"`
}
// KiroModel represents a model available through the CodeWhisperer API
type KiroModel struct {
// ModelID is the unique identifier for the model
ModelID string `json:"modelId"`
// ModelName is the human-readable name
ModelName string `json:"modelName"`
// Description is the model description
Description string `json:"description"`
// RateMultiplier is the credit multiplier for this model
RateMultiplier float64 `json:"rateMultiplier"`
// RateUnit is the unit for rate calculation (e.g., "credit")
RateUnit string `json:"rateUnit"`
// MaxInputTokens is the maximum input token limit
MaxInputTokens int `json:"maxInputTokens,omitempty"`
}
// KiroIDETokenFile is the default path to Kiro IDE's token file
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
func LoadKiroIDEToken() (*KiroTokenData, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get home directory: %w", err)
}
tokenPath := filepath.Join(homeDir, KiroIDETokenFile)
data, err := os.ReadFile(tokenPath)
if err != nil {
return nil, fmt.Errorf("failed to read Kiro IDE token file (%s): %w", tokenPath, err)
}
var token KiroTokenData
if err := json.Unmarshal(data, &token); err != nil {
return nil, fmt.Errorf("failed to parse Kiro IDE token: %w", err)
}
if token.AccessToken == "" {
return nil, fmt.Errorf("access token is empty in Kiro IDE token file")
}
return &token, nil
}
// LoadKiroTokenFromPath loads token data from a custom path.
// This supports multiple accounts by allowing different token files.
func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
// Expand ~ to home directory
if len(tokenPath) > 0 && tokenPath[0] == '~' {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get home directory: %w", err)
}
tokenPath = filepath.Join(homeDir, tokenPath[1:])
}
data, err := os.ReadFile(tokenPath)
if err != nil {
return nil, fmt.Errorf("failed to read token file (%s): %w", tokenPath, err)
}
var token KiroTokenData
if err := json.Unmarshal(data, &token); err != nil {
return nil, fmt.Errorf("failed to parse token file: %w", err)
}
if token.AccessToken == "" {
return nil, fmt.Errorf("access token is empty in token file")
}
return &token, nil
}
// ListKiroTokenFiles lists all Kiro token files in the cache directory.
// This supports multiple accounts by finding all token files.
func ListKiroTokenFiles() ([]string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get home directory: %w", err)
}
cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache")
// Check if directory exists
if _, err := os.Stat(cacheDir); os.IsNotExist(err) {
return nil, nil // No token files
}
entries, err := os.ReadDir(cacheDir)
if err != nil {
return nil, fmt.Errorf("failed to read cache directory: %w", err)
}
var tokenFiles []string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
// Look for kiro token files only (avoid matching unrelated AWS SSO cache files)
if strings.HasSuffix(name, ".json") && strings.HasPrefix(name, "kiro") {
tokenFiles = append(tokenFiles, filepath.Join(cacheDir, name))
}
}
return tokenFiles, nil
}
// LoadAllKiroTokens loads all Kiro tokens from the cache directory.
// This supports multiple accounts.
func LoadAllKiroTokens() ([]*KiroTokenData, error) {
files, err := ListKiroTokenFiles()
if err != nil {
return nil, err
}
var tokens []*KiroTokenData
for _, file := range files {
token, err := LoadKiroTokenFromPath(file)
if err != nil {
// Skip invalid token files
continue
}
tokens = append(tokens, token)
}
return tokens, nil
}
// JWTClaims represents the claims we care about from a JWT token.
// JWT tokens from Kiro/AWS contain user information in the payload.
type JWTClaims struct {
Email string `json:"email,omitempty"`
Sub string `json:"sub,omitempty"`
PreferredUser string `json:"preferred_username,omitempty"`
Name string `json:"name,omitempty"`
Iss string `json:"iss,omitempty"`
}
// ExtractEmailFromJWT extracts the user's email from a JWT access token.
// JWT tokens typically have format: header.payload.signature
// The payload is base64url-encoded JSON containing user claims.
func ExtractEmailFromJWT(accessToken string) string {
if accessToken == "" {
return ""
}
// JWT format: header.payload.signature
parts := strings.Split(accessToken, ".")
if len(parts) != 3 {
return ""
}
// Decode the payload (second part)
payload := parts[1]
// Add padding if needed (base64url requires padding)
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
decoded, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
// Try RawURLEncoding (no padding)
decoded, err = base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return ""
}
}
var claims JWTClaims
if err := json.Unmarshal(decoded, &claims); err != nil {
return ""
}
// Return email if available
if claims.Email != "" {
return claims.Email
}
// Fallback to preferred_username (some providers use this)
if claims.PreferredUser != "" && strings.Contains(claims.PreferredUser, "@") {
return claims.PreferredUser
}
// Fallback to sub if it looks like an email
if claims.Sub != "" && strings.Contains(claims.Sub, "@") {
return claims.Sub
}
return ""
}
// SanitizeEmailForFilename sanitizes an email address for use in a filename.
// Replaces special characters with underscores and prevents path traversal attacks.
// Also handles URL-encoded characters to prevent encoded path traversal attempts.
func SanitizeEmailForFilename(email string) string {
if email == "" {
return ""
}
result := email
// First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.)
// This prevents encoded characters from bypassing the sanitization.
// Note: We replace % last to catch any remaining encodings including double-encoding (%252F)
result = strings.ReplaceAll(result, "%2F", "_") // /
result = strings.ReplaceAll(result, "%2f", "_")
result = strings.ReplaceAll(result, "%5C", "_") // \
result = strings.ReplaceAll(result, "%5c", "_")
result = strings.ReplaceAll(result, "%2E", "_") // .
result = strings.ReplaceAll(result, "%2e", "_")
result = strings.ReplaceAll(result, "%00", "_") // null byte
result = strings.ReplaceAll(result, "%", "_") // Catch remaining % to prevent double-encoding attacks
// Replace characters that are problematic in filenames
// Keep @ and . in middle but replace other special characters
for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} {
result = strings.ReplaceAll(result, char, "_")
}
// Prevent path traversal: replace leading dots in each path component
// This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd"
parts := strings.Split(result, "_")
for i, part := range parts {
for strings.HasPrefix(part, ".") {
part = "_" + part[1:]
}
parts[i] = part
}
result = strings.Join(parts, "_")
return result
}

View File

@@ -0,0 +1,314 @@
// Package kiro provides OAuth2 authentication functionality for AWS CodeWhisperer (Kiro) API.
// This package implements token loading, refresh, and API communication with CodeWhisperer.
package kiro
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
// awsKiroEndpoint is used for CodeWhisperer management APIs (GetUsageLimits, ListProfiles, etc.)
// Note: This is different from the Amazon Q streaming endpoint (q.us-east-1.amazonaws.com)
// used in kiro_executor.go for GenerateAssistantResponse. Both endpoints are correct
// for their respective API operations.
awsKiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com"
defaultTokenFile = "~/.aws/sso/cache/kiro-auth-token.json"
targetGetUsage = "AmazonCodeWhispererService.GetUsageLimits"
targetListModels = "AmazonCodeWhispererService.ListAvailableModels"
targetGenerateChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse"
)
// KiroAuth handles AWS CodeWhisperer authentication and API communication.
// It provides methods for loading tokens, refreshing expired tokens,
// and communicating with the CodeWhisperer API.
type KiroAuth struct {
httpClient *http.Client
endpoint string
}
// NewKiroAuth creates a new Kiro authentication service.
// It initializes the HTTP client with proxy settings from the configuration.
//
// Parameters:
// - cfg: The application configuration containing proxy settings
//
// Returns:
// - *KiroAuth: A new Kiro authentication service instance
func NewKiroAuth(cfg *config.Config) *KiroAuth {
return &KiroAuth{
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}),
endpoint: awsKiroEndpoint,
}
}
// LoadTokenFromFile loads token data from a file path.
// This method reads and parses the token file, expanding ~ to the home directory.
//
// Parameters:
// - tokenFile: Path to the token file (supports ~ expansion)
//
// Returns:
// - *KiroTokenData: The parsed token data
// - error: An error if file reading or parsing fails
func (k *KiroAuth) LoadTokenFromFile(tokenFile string) (*KiroTokenData, error) {
// Expand ~ to home directory
if strings.HasPrefix(tokenFile, "~") {
home, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get home directory: %w", err)
}
tokenFile = filepath.Join(home, tokenFile[1:])
}
data, err := os.ReadFile(tokenFile)
if err != nil {
return nil, fmt.Errorf("failed to read token file: %w", err)
}
var tokenData KiroTokenData
if err := json.Unmarshal(data, &tokenData); err != nil {
return nil, fmt.Errorf("failed to parse token file: %w", err)
}
return &tokenData, nil
}
// IsTokenExpired checks if the token has expired.
// This method parses the expiration timestamp and compares it with the current time.
//
// Parameters:
// - tokenData: The token data to check
//
// Returns:
// - bool: True if the token has expired, false otherwise
func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool {
if tokenData.ExpiresAt == "" {
return true
}
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
if err != nil {
// Try alternate format
expiresAt, err = time.Parse("2006-01-02T15:04:05.000Z", tokenData.ExpiresAt)
if err != nil {
return true
}
}
return time.Now().After(expiresAt)
}
// makeRequest sends a request to the CodeWhisperer API.
// This is an internal method for making authenticated API calls.
//
// Parameters:
// - ctx: The context for the request
// - target: The API target (e.g., "AmazonCodeWhispererService.GetUsageLimits")
// - accessToken: The OAuth access token
// - payload: The request payload
//
// Returns:
// - []byte: The response body
// - error: An error if the request fails
func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken string, payload interface{}) ([]byte, error) {
jsonBody, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, k.endpoint, strings.NewReader(string(jsonBody)))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
req.Header.Set("x-amz-target", target)
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
resp, err := k.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("failed to close response body: %v", errClose)
}
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
return body, nil
}
// GetUsageLimits retrieves usage information from the CodeWhisperer API.
// This method fetches the current usage statistics and subscription information.
//
// Parameters:
// - ctx: The context for the request
// - tokenData: The token data containing access token and profile ARN
//
// Returns:
// - *KiroUsageInfo: The usage information
// - error: An error if the request fails
func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) {
payload := map[string]interface{}{
"origin": "AI_EDITOR",
"profileArn": tokenData.ProfileArn,
"resourceType": "AGENTIC_REQUEST",
}
body, err := k.makeRequest(ctx, targetGetUsage, tokenData.AccessToken, payload)
if err != nil {
return nil, err
}
var result struct {
SubscriptionInfo struct {
SubscriptionTitle string `json:"subscriptionTitle"`
} `json:"subscriptionInfo"`
UsageBreakdownList []struct {
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
} `json:"usageBreakdownList"`
NextDateReset float64 `json:"nextDateReset"`
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse usage response: %w", err)
}
usage := &KiroUsageInfo{
SubscriptionTitle: result.SubscriptionInfo.SubscriptionTitle,
NextReset: fmt.Sprintf("%v", result.NextDateReset),
}
if len(result.UsageBreakdownList) > 0 {
usage.CurrentUsage = result.UsageBreakdownList[0].CurrentUsageWithPrecision
usage.UsageLimit = result.UsageBreakdownList[0].UsageLimitWithPrecision
}
return usage, nil
}
// ListAvailableModels retrieves available models from the CodeWhisperer API.
// This method fetches the list of AI models available for the authenticated user.
//
// Parameters:
// - ctx: The context for the request
// - tokenData: The token data containing access token and profile ARN
//
// Returns:
// - []*KiroModel: The list of available models
// - error: An error if the request fails
func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) {
payload := map[string]interface{}{
"origin": "AI_EDITOR",
"profileArn": tokenData.ProfileArn,
}
body, err := k.makeRequest(ctx, targetListModels, tokenData.AccessToken, payload)
if err != nil {
return nil, err
}
var result struct {
Models []struct {
ModelID string `json:"modelId"`
ModelName string `json:"modelName"`
Description string `json:"description"`
RateMultiplier float64 `json:"rateMultiplier"`
RateUnit string `json:"rateUnit"`
TokenLimits struct {
MaxInputTokens int `json:"maxInputTokens"`
} `json:"tokenLimits"`
} `json:"models"`
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse models response: %w", err)
}
models := make([]*KiroModel, 0, len(result.Models))
for _, m := range result.Models {
models = append(models, &KiroModel{
ModelID: m.ModelID,
ModelName: m.ModelName,
Description: m.Description,
RateMultiplier: m.RateMultiplier,
RateUnit: m.RateUnit,
MaxInputTokens: m.TokenLimits.MaxInputTokens,
})
}
return models, nil
}
// CreateTokenStorage creates a new KiroTokenStorage from token data.
// This method converts the token data into a storage structure suitable for persistence.
//
// Parameters:
// - tokenData: The token data to convert
//
// Returns:
// - *KiroTokenStorage: A new token storage instance
func (k *KiroAuth) CreateTokenStorage(tokenData *KiroTokenData) *KiroTokenStorage {
return &KiroTokenStorage{
AccessToken: tokenData.AccessToken,
RefreshToken: tokenData.RefreshToken,
ProfileArn: tokenData.ProfileArn,
ExpiresAt: tokenData.ExpiresAt,
AuthMethod: tokenData.AuthMethod,
Provider: tokenData.Provider,
LastRefresh: time.Now().Format(time.RFC3339),
}
}
// ValidateToken checks if the token is valid by making a test API call.
// This method verifies the token by attempting to fetch usage limits.
//
// Parameters:
// - ctx: The context for the request
// - tokenData: The token data to validate
//
// Returns:
// - error: An error if the token is invalid
func (k *KiroAuth) ValidateToken(ctx context.Context, tokenData *KiroTokenData) error {
_, err := k.GetUsageLimits(ctx, tokenData)
return err
}
// UpdateTokenStorage updates an existing token storage with new token data.
// This method refreshes the token storage with newly obtained access and refresh tokens.
//
// Parameters:
// - storage: The existing token storage to update
// - tokenData: The new token data to apply
func (k *KiroAuth) UpdateTokenStorage(storage *KiroTokenStorage, tokenData *KiroTokenData) {
storage.AccessToken = tokenData.AccessToken
storage.RefreshToken = tokenData.RefreshToken
storage.ProfileArn = tokenData.ProfileArn
storage.ExpiresAt = tokenData.ExpiresAt
storage.AuthMethod = tokenData.AuthMethod
storage.Provider = tokenData.Provider
storage.LastRefresh = time.Now().Format(time.RFC3339)
}

View File

@@ -0,0 +1,161 @@
package kiro
import (
"encoding/base64"
"encoding/json"
"testing"
)
func TestExtractEmailFromJWT(t *testing.T) {
tests := []struct {
name string
token string
expected string
}{
{
name: "Empty token",
token: "",
expected: "",
},
{
name: "Invalid token format",
token: "not.a.valid.jwt",
expected: "",
},
{
name: "Invalid token - not base64",
token: "xxx.yyy.zzz",
expected: "",
},
{
name: "Valid JWT with email",
token: createTestJWT(map[string]any{"email": "test@example.com", "sub": "user123"}),
expected: "test@example.com",
},
{
name: "JWT without email but with preferred_username",
token: createTestJWT(map[string]any{"preferred_username": "user@domain.com", "sub": "user123"}),
expected: "user@domain.com",
},
{
name: "JWT with email-like sub",
token: createTestJWT(map[string]any{"sub": "another@test.com"}),
expected: "another@test.com",
},
{
name: "JWT without any email fields",
token: createTestJWT(map[string]any{"sub": "user123", "name": "Test User"}),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ExtractEmailFromJWT(tt.token)
if result != tt.expected {
t.Errorf("ExtractEmailFromJWT() = %q, want %q", result, tt.expected)
}
})
}
}
func TestSanitizeEmailForFilename(t *testing.T) {
tests := []struct {
name string
email string
expected string
}{
{
name: "Empty email",
email: "",
expected: "",
},
{
name: "Simple email",
email: "user@example.com",
expected: "user@example.com",
},
{
name: "Email with space",
email: "user name@example.com",
expected: "user_name@example.com",
},
{
name: "Email with special chars",
email: "user:name@example.com",
expected: "user_name@example.com",
},
{
name: "Email with multiple special chars",
email: "user/name:test@example.com",
expected: "user_name_test@example.com",
},
{
name: "Path traversal attempt",
email: "../../../etc/passwd",
expected: "_.__.__._etc_passwd",
},
{
name: "Path traversal with backslash",
email: `..\..\..\..\windows\system32`,
expected: "_.__.__.__._windows_system32",
},
{
name: "Null byte injection attempt",
email: "user\x00@evil.com",
expected: "user_@evil.com",
},
// URL-encoded path traversal tests
{
name: "URL-encoded slash",
email: "user%2Fpath@example.com",
expected: "user_path@example.com",
},
{
name: "URL-encoded backslash",
email: "user%5Cpath@example.com",
expected: "user_path@example.com",
},
{
name: "URL-encoded dot",
email: "%2E%2E%2Fetc%2Fpasswd",
expected: "___etc_passwd",
},
{
name: "URL-encoded null",
email: "user%00@evil.com",
expected: "user_@evil.com",
},
{
name: "Double URL-encoding attack",
email: "%252F%252E%252E",
expected: "_252F_252E_252E", // % replaced with _, remaining chars preserved (safe)
},
{
name: "Mixed case URL-encoding",
email: "%2f%2F%5c%5C",
expected: "____",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeEmailForFilename(tt.email)
if result != tt.expected {
t.Errorf("SanitizeEmailForFilename() = %q, want %q", result, tt.expected)
}
})
}
}
// createTestJWT creates a test JWT token with the given claims
func createTestJWT(claims map[string]any) string {
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
payloadBytes, _ := json.Marshal(claims)
payload := base64.RawURLEncoding.EncodeToString(payloadBytes)
signature := base64.RawURLEncoding.EncodeToString([]byte("fake-signature"))
return header + "." + payload + "." + signature
}

296
internal/auth/kiro/oauth.go Normal file
View File

@@ -0,0 +1,296 @@
// Package kiro provides OAuth2 authentication for Kiro using native Google login.
package kiro
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"html"
"io"
"net"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
// Kiro auth endpoint
kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
// Default callback port
defaultCallbackPort = 9876
// Auth timeout
authTimeout = 10 * time.Minute
)
// KiroTokenResponse represents the response from Kiro token endpoint.
type KiroTokenResponse struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ProfileArn string `json:"profileArn"`
ExpiresIn int `json:"expiresIn"`
}
// KiroOAuth handles the OAuth flow for Kiro authentication.
type KiroOAuth struct {
httpClient *http.Client
cfg *config.Config
}
// NewKiroOAuth creates a new Kiro OAuth handler.
func NewKiroOAuth(cfg *config.Config) *KiroOAuth {
client := &http.Client{Timeout: 30 * time.Second}
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
return &KiroOAuth{
httpClient: client,
cfg: cfg,
}
}
// generateCodeVerifier generates a random code verifier for PKCE.
func generateCodeVerifier() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// generateCodeChallenge generates the code challenge from verifier.
func generateCodeChallenge(verifier string) string {
h := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(h[:])
}
// generateState generates a random state parameter.
func generateState() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// AuthResult contains the authorization code and state from callback.
type AuthResult struct {
Code string
State string
Error string
}
// startCallbackServer starts a local HTTP server to receive the OAuth callback.
func (o *KiroOAuth) startCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthResult, error) {
// Try to find an available port - use localhost like Kiro does
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", defaultCallbackPort))
if err != nil {
// Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
log.Warnf("kiro oauth: default port %d is busy, falling back to dynamic port", defaultCallbackPort)
listener, err = net.Listen("tcp", "localhost:0")
if err != nil {
return "", nil, fmt.Errorf("failed to start callback server: %w", err)
}
}
port := listener.Addr().(*net.TCPAddr).Port
// Use http scheme for local callback server
redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
resultChan := make(chan AuthResult, 1)
server := &http.Server{
ReadHeaderTimeout: 10 * time.Second,
}
mux := http.NewServeMux()
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
errParam := r.URL.Query().Get("error")
if errParam != "" {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, `<html><body><h1>Login Failed</h1><p>%s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
resultChan <- AuthResult{Error: errParam}
return
}
if state != expectedState {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, `<html><body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
resultChan <- AuthResult{Error: "state mismatch"}
return
}
w.Header().Set("Content-Type", "text/html")
fmt.Fprint(w, `<html><body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p></body></html>`)
resultChan <- AuthResult{Code: code, State: state}
})
server.Handler = mux
go func() {
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
log.Debugf("callback server error: %v", err)
}
}()
go func() {
select {
case <-ctx.Done():
case <-time.After(authTimeout):
case <-resultChan:
}
_ = server.Shutdown(context.Background())
}()
return redirectURI, resultChan, nil
}
// LoginWithBuilderID performs OAuth login with AWS Builder ID using device code flow.
func (o *KiroOAuth) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) {
ssoClient := NewSSOOIDCClient(o.cfg)
return ssoClient.LoginWithBuilderID(ctx)
}
// exchangeCodeForToken exchanges the authorization code for tokens.
func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier, redirectURI string) (*KiroTokenData, error) {
payload := map[string]string{
"code": code,
"code_verifier": codeVerifier,
"redirect_uri": redirectURI,
}
body, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
tokenURL := kiroAuthEndpoint + "/oauth/token"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body)))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
resp, err := o.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("token request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode)
}
var tokenResp KiroTokenResponse
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse token response: %w", err)
}
// Validate ExpiresIn - use default 1 hour if invalid
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: "", // Caller should preserve original provider
}, nil
}
// RefreshToken refreshes an expired access token.
func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
payload := map[string]string{
"refreshToken": refreshToken,
}
body, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
refreshURL := kiroAuthEndpoint + "/refreshToken"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body)))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
resp, err := o.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("refresh request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
}
var tokenResp KiroTokenResponse
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse token response: %w", err)
}
// Validate ExpiresIn - use default 1 hour if invalid
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: "", // Caller should preserve original provider
}, nil
}
// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
// This uses a custom protocol handler (kiro://) to receive the callback.
func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
socialClient := NewSocialAuthClient(o.cfg)
return socialClient.LoginWithGoogle(ctx)
}
// LoginWithGitHub performs OAuth login with GitHub using Kiro's social auth.
// This uses a custom protocol handler (kiro://) to receive the callback.
func (o *KiroOAuth) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) {
socialClient := NewSocialAuthClient(o.cfg)
return socialClient.LoginWithGitHub(ctx)
}

View File

@@ -0,0 +1,725 @@
// Package kiro provides custom protocol handler registration for Kiro OAuth.
// This enables the CLI to intercept kiro:// URIs for social authentication (Google/GitHub).
package kiro
import (
"context"
"fmt"
"html"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
// KiroProtocol is the custom URI scheme used by Kiro
KiroProtocol = "kiro"
// KiroAuthority is the URI authority for authentication callbacks
KiroAuthority = "kiro.kiroAgent"
// KiroAuthPath is the path for successful authentication
KiroAuthPath = "/authenticate-success"
// KiroRedirectURI is the full redirect URI for social auth
KiroRedirectURI = "kiro://kiro.kiroAgent/authenticate-success"
// DefaultHandlerPort is the default port for the local callback server
DefaultHandlerPort = 19876
// HandlerTimeout is how long to wait for the OAuth callback
HandlerTimeout = 10 * time.Minute
)
// ProtocolHandler manages the custom kiro:// protocol handler for OAuth callbacks.
type ProtocolHandler struct {
port int
server *http.Server
listener net.Listener
resultChan chan *AuthCallback
stopChan chan struct{}
mu sync.Mutex
running bool
}
// AuthCallback contains the OAuth callback parameters.
type AuthCallback struct {
Code string
State string
Error string
}
// NewProtocolHandler creates a new protocol handler.
func NewProtocolHandler() *ProtocolHandler {
return &ProtocolHandler{
port: DefaultHandlerPort,
resultChan: make(chan *AuthCallback, 1),
stopChan: make(chan struct{}),
}
}
// Start starts the local callback server that receives redirects from the protocol handler.
func (h *ProtocolHandler) Start(ctx context.Context) (int, error) {
h.mu.Lock()
defer h.mu.Unlock()
if h.running {
return h.port, nil
}
// Drain any stale results from previous runs
select {
case <-h.resultChan:
default:
}
// Reset stopChan for reuse - close old channel first to unblock any waiting goroutines
if h.stopChan != nil {
select {
case <-h.stopChan:
// Already closed
default:
close(h.stopChan)
}
}
h.stopChan = make(chan struct{})
// Try ports in known range (must match handler script port range)
var listener net.Listener
var err error
portRange := []int{DefaultHandlerPort, DefaultHandlerPort + 1, DefaultHandlerPort + 2, DefaultHandlerPort + 3, DefaultHandlerPort + 4}
for _, port := range portRange {
listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err == nil {
break
}
log.Debugf("kiro protocol handler: port %d busy, trying next", port)
}
if listener == nil {
return 0, fmt.Errorf("failed to start callback server: all ports %d-%d are busy", DefaultHandlerPort, DefaultHandlerPort+4)
}
h.listener = listener
h.port = listener.Addr().(*net.TCPAddr).Port
mux := http.NewServeMux()
mux.HandleFunc("/oauth/callback", h.handleCallback)
h.server = &http.Server{
Handler: mux,
ReadHeaderTimeout: 10 * time.Second,
}
go func() {
if err := h.server.Serve(listener); err != nil && err != http.ErrServerClosed {
log.Debugf("kiro protocol handler server error: %v", err)
}
}()
h.running = true
log.Debugf("kiro protocol handler started on port %d", h.port)
// Auto-shutdown after context done, timeout, or explicit stop
// Capture references to prevent race with new Start() calls
currentStopChan := h.stopChan
currentServer := h.server
currentListener := h.listener
go func() {
select {
case <-ctx.Done():
case <-time.After(HandlerTimeout):
case <-currentStopChan:
return // Already stopped, exit goroutine
}
// Only stop if this is still the current server/listener instance
h.mu.Lock()
if h.server == currentServer && h.listener == currentListener {
h.mu.Unlock()
h.Stop()
} else {
h.mu.Unlock()
}
}()
return h.port, nil
}
// Stop stops the callback server.
func (h *ProtocolHandler) Stop() {
h.mu.Lock()
defer h.mu.Unlock()
if !h.running {
return
}
// Signal the auto-shutdown goroutine to exit.
// This select pattern is safe because stopChan is only modified while holding h.mu,
// and we hold the lock here. The select prevents panic from double-close.
select {
case <-h.stopChan:
// Already closed
default:
close(h.stopChan)
}
if h.server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = h.server.Shutdown(ctx)
}
h.running = false
log.Debug("kiro protocol handler stopped")
}
// WaitForCallback waits for the OAuth callback and returns the result.
func (h *ProtocolHandler) WaitForCallback(ctx context.Context) (*AuthCallback, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(HandlerTimeout):
return nil, fmt.Errorf("timeout waiting for OAuth callback")
case result := <-h.resultChan:
return result, nil
}
}
// GetPort returns the port the handler is listening on.
func (h *ProtocolHandler) GetPort() int {
return h.port
}
// handleCallback processes the OAuth callback from the protocol handler script.
func (h *ProtocolHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
errParam := r.URL.Query().Get("error")
result := &AuthCallback{
Code: code,
State: state,
Error: errParam,
}
// Send result
select {
case h.resultChan <- result:
default:
// Channel full, ignore duplicate callbacks
}
// Send success response
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if errParam != "" {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, `<!DOCTYPE html>
<html>
<head><title>Login Failed</title></head>
<body>
<h1>Login Failed</h1>
<p>Error: %s</p>
<p>You can close this window.</p>
</body>
</html>`, html.EscapeString(errParam))
} else {
fmt.Fprint(w, `<!DOCTYPE html>
<html>
<head><title>Login Successful</title></head>
<body>
<h1>Login Successful!</h1>
<p>You can close this window and return to the terminal.</p>
<script>window.close();</script>
</body>
</html>`)
}
}
// IsProtocolHandlerInstalled checks if the kiro:// protocol handler is installed.
func IsProtocolHandlerInstalled() bool {
switch runtime.GOOS {
case "linux":
return isLinuxHandlerInstalled()
case "windows":
return isWindowsHandlerInstalled()
case "darwin":
return isDarwinHandlerInstalled()
default:
return false
}
}
// InstallProtocolHandler installs the kiro:// protocol handler for the current platform.
func InstallProtocolHandler(handlerPort int) error {
switch runtime.GOOS {
case "linux":
return installLinuxHandler(handlerPort)
case "windows":
return installWindowsHandler(handlerPort)
case "darwin":
return installDarwinHandler(handlerPort)
default:
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}
}
// UninstallProtocolHandler removes the kiro:// protocol handler.
func UninstallProtocolHandler() error {
switch runtime.GOOS {
case "linux":
return uninstallLinuxHandler()
case "windows":
return uninstallWindowsHandler()
case "darwin":
return uninstallDarwinHandler()
default:
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}
}
// --- Linux Implementation ---
func getLinuxDesktopPath() string {
homeDir, _ := os.UserHomeDir()
return filepath.Join(homeDir, ".local", "share", "applications", "kiro-oauth-handler.desktop")
}
func getLinuxHandlerScriptPath() string {
homeDir, _ := os.UserHomeDir()
return filepath.Join(homeDir, ".local", "bin", "kiro-oauth-handler")
}
func isLinuxHandlerInstalled() bool {
desktopPath := getLinuxDesktopPath()
_, err := os.Stat(desktopPath)
return err == nil
}
func installLinuxHandler(handlerPort int) error {
// Create directories
homeDir, err := os.UserHomeDir()
if err != nil {
return err
}
binDir := filepath.Join(homeDir, ".local", "bin")
appDir := filepath.Join(homeDir, ".local", "share", "applications")
if err := os.MkdirAll(binDir, 0755); err != nil {
return fmt.Errorf("failed to create bin directory: %w", err)
}
if err := os.MkdirAll(appDir, 0755); err != nil {
return fmt.Errorf("failed to create applications directory: %w", err)
}
// Create handler script - tries multiple ports to handle dynamic port allocation
scriptPath := getLinuxHandlerScriptPath()
scriptContent := fmt.Sprintf(`#!/bin/bash
# Kiro OAuth Protocol Handler
# Handles kiro:// URIs - tries CLI first, then forwards to Kiro IDE
URL="$1"
# Check curl availability
if ! command -v curl &> /dev/null; then
echo "Error: curl is required for Kiro OAuth handler" >&2
exit 1
fi
# Extract code and state from URL
[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}"
[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}"
[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}"
# Try CLI proxy on multiple possible ports (default + dynamic range)
CLI_OK=0
for PORT in %d %d %d %d %d; do
if [ -n "$ERROR" ]; then
curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && CLI_OK=1 && break
elif [ -n "$CODE" ] && [ -n "$STATE" ]; then
curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && CLI_OK=1 && break
fi
done
# If CLI not available, forward to Kiro IDE
if [ $CLI_OK -eq 0 ] && [ -x "/usr/share/kiro/kiro" ]; then
/usr/share/kiro/kiro --open-url "$URL" &
fi
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
return fmt.Errorf("failed to write handler script: %w", err)
}
// Create .desktop file
desktopPath := getLinuxDesktopPath()
desktopContent := fmt.Sprintf(`[Desktop Entry]
Name=Kiro OAuth Handler
Comment=Handle kiro:// protocol for CLI Proxy API authentication
Exec=%s %%u
Type=Application
Terminal=false
NoDisplay=true
MimeType=x-scheme-handler/kiro;
Categories=Utility;
`, scriptPath)
if err := os.WriteFile(desktopPath, []byte(desktopContent), 0644); err != nil {
return fmt.Errorf("failed to write desktop file: %w", err)
}
// Register handler with xdg-mime
cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
if err := cmd.Run(); err != nil {
log.Warnf("xdg-mime registration failed (may need manual setup): %v", err)
}
// Update desktop database
cmd = exec.Command("update-desktop-database", appDir)
_ = cmd.Run() // Ignore errors, not critical
log.Info("Kiro protocol handler installed for Linux")
return nil
}
func uninstallLinuxHandler() error {
desktopPath := getLinuxDesktopPath()
scriptPath := getLinuxHandlerScriptPath()
if err := os.Remove(desktopPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove desktop file: %w", err)
}
if err := os.Remove(scriptPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove handler script: %w", err)
}
log.Info("Kiro protocol handler uninstalled")
return nil
}
// --- Windows Implementation ---
func isWindowsHandlerInstalled() bool {
// Check registry key existence
cmd := exec.Command("reg", "query", `HKCU\Software\Classes\kiro`, "/ve")
return cmd.Run() == nil
}
func installWindowsHandler(handlerPort int) error {
homeDir, err := os.UserHomeDir()
if err != nil {
return err
}
// Create handler script (PowerShell)
scriptDir := filepath.Join(homeDir, ".cliproxyapi")
if err := os.MkdirAll(scriptDir, 0755); err != nil {
return fmt.Errorf("failed to create script directory: %w", err)
}
scriptPath := filepath.Join(scriptDir, "kiro-oauth-handler.ps1")
scriptContent := fmt.Sprintf(`# Kiro OAuth Protocol Handler for Windows
param([string]$url)
# Load required assembly for HttpUtility
Add-Type -AssemblyName System.Web
# Parse URL parameters
$uri = [System.Uri]$url
$query = [System.Web.HttpUtility]::ParseQueryString($uri.Query)
$code = $query["code"]
$state = $query["state"]
$errorParam = $query["error"]
# Try multiple ports (default + dynamic range)
$ports = @(%d, %d, %d, %d, %d)
$success = $false
foreach ($port in $ports) {
if ($success) { break }
$callbackUrl = "http://127.0.0.1:$port/oauth/callback"
try {
if ($errorParam) {
$fullUrl = $callbackUrl + "?error=" + $errorParam
Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null
$success = $true
} elseif ($code -and $state) {
$fullUrl = $callbackUrl + "?code=" + $code + "&state=" + $state
Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null
$success = $true
}
} catch {
# Try next port
}
}
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil {
return fmt.Errorf("failed to write handler script: %w", err)
}
// Create batch wrapper
batchPath := filepath.Join(scriptDir, "kiro-oauth-handler.bat")
batchContent := fmt.Sprintf("@echo off\npowershell -ExecutionPolicy Bypass -File \"%s\" \"%%1\"\n", scriptPath)
if err := os.WriteFile(batchPath, []byte(batchContent), 0644); err != nil {
return fmt.Errorf("failed to write batch wrapper: %w", err)
}
// Register in Windows registry
commands := [][]string{
{"reg", "add", `HKCU\Software\Classes\kiro`, "/ve", "/d", "URL:Kiro Protocol", "/f"},
{"reg", "add", `HKCU\Software\Classes\kiro`, "/v", "URL Protocol", "/d", "", "/f"},
{"reg", "add", `HKCU\Software\Classes\kiro\shell`, "/f"},
{"reg", "add", `HKCU\Software\Classes\kiro\shell\open`, "/f"},
{"reg", "add", `HKCU\Software\Classes\kiro\shell\open\command`, "/ve", "/d", fmt.Sprintf("\"%s\" \"%%1\"", batchPath), "/f"},
}
for _, args := range commands {
cmd := exec.Command(args[0], args[1:]...)
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to run registry command: %w", err)
}
}
log.Info("Kiro protocol handler installed for Windows")
return nil
}
func uninstallWindowsHandler() error {
// Remove registry keys
cmd := exec.Command("reg", "delete", `HKCU\Software\Classes\kiro`, "/f")
if err := cmd.Run(); err != nil {
log.Warnf("failed to remove registry key: %v", err)
}
// Remove scripts
homeDir, _ := os.UserHomeDir()
scriptDir := filepath.Join(homeDir, ".cliproxyapi")
_ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.ps1"))
_ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.bat"))
log.Info("Kiro protocol handler uninstalled")
return nil
}
// --- macOS Implementation ---
func getDarwinAppPath() string {
homeDir, _ := os.UserHomeDir()
return filepath.Join(homeDir, "Applications", "KiroOAuthHandler.app")
}
func isDarwinHandlerInstalled() bool {
appPath := getDarwinAppPath()
_, err := os.Stat(appPath)
return err == nil
}
func installDarwinHandler(handlerPort int) error {
// Create app bundle structure
appPath := getDarwinAppPath()
contentsPath := filepath.Join(appPath, "Contents")
macOSPath := filepath.Join(contentsPath, "MacOS")
if err := os.MkdirAll(macOSPath, 0755); err != nil {
return fmt.Errorf("failed to create app bundle: %w", err)
}
// Create Info.plist
plistPath := filepath.Join(contentsPath, "Info.plist")
plistContent := `<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CFBundleIdentifier</key>
<string>com.cliproxyapi.kiro-oauth-handler</string>
<key>CFBundleName</key>
<string>KiroOAuthHandler</string>
<key>CFBundleExecutable</key>
<string>kiro-oauth-handler</string>
<key>CFBundleVersion</key>
<string>1.0</string>
<key>CFBundleURLTypes</key>
<array>
<dict>
<key>CFBundleURLName</key>
<string>Kiro Protocol</string>
<key>CFBundleURLSchemes</key>
<array>
<string>kiro</string>
</array>
</dict>
</array>
<key>LSBackgroundOnly</key>
<true/>
</dict>
</plist>`
if err := os.WriteFile(plistPath, []byte(plistContent), 0644); err != nil {
return fmt.Errorf("failed to write Info.plist: %w", err)
}
// Create executable script - tries multiple ports to handle dynamic port allocation
execPath := filepath.Join(macOSPath, "kiro-oauth-handler")
execContent := fmt.Sprintf(`#!/bin/bash
# Kiro OAuth Protocol Handler for macOS
URL="$1"
# Check curl availability (should always exist on macOS)
if [ ! -x /usr/bin/curl ]; then
echo "Error: curl is required for Kiro OAuth handler" >&2
exit 1
fi
# Extract code and state from URL
[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}"
[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}"
[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}"
# Try multiple ports (default + dynamic range)
for PORT in %d %d %d %d %d; do
if [ -n "$ERROR" ]; then
/usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && exit 0
elif [ -n "$CODE" ] && [ -n "$STATE" ]; then
/usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && exit 0
fi
done
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
if err := os.WriteFile(execPath, []byte(execContent), 0755); err != nil {
return fmt.Errorf("failed to write executable: %w", err)
}
// Register the app with Launch Services
cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister",
"-f", appPath)
if err := cmd.Run(); err != nil {
log.Warnf("lsregister failed (handler may still work): %v", err)
}
log.Info("Kiro protocol handler installed for macOS")
return nil
}
func uninstallDarwinHandler() error {
appPath := getDarwinAppPath()
// Unregister from Launch Services
cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister",
"-u", appPath)
_ = cmd.Run()
// Remove app bundle
if err := os.RemoveAll(appPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove app bundle: %w", err)
}
log.Info("Kiro protocol handler uninstalled")
return nil
}
// ParseKiroURI parses a kiro:// URI and extracts the callback parameters.
func ParseKiroURI(rawURI string) (*AuthCallback, error) {
u, err := url.Parse(rawURI)
if err != nil {
return nil, fmt.Errorf("invalid URI: %w", err)
}
if u.Scheme != KiroProtocol {
return nil, fmt.Errorf("invalid scheme: expected %s, got %s", KiroProtocol, u.Scheme)
}
if u.Host != KiroAuthority {
return nil, fmt.Errorf("invalid authority: expected %s, got %s", KiroAuthority, u.Host)
}
query := u.Query()
return &AuthCallback{
Code: query.Get("code"),
State: query.Get("state"),
Error: query.Get("error"),
}, nil
}
// GetHandlerInstructions returns platform-specific instructions for manual handler setup.
func GetHandlerInstructions() string {
switch runtime.GOOS {
case "linux":
return `To manually set up the Kiro protocol handler on Linux:
1. Create ~/.local/share/applications/kiro-oauth-handler.desktop:
[Desktop Entry]
Name=Kiro OAuth Handler
Exec=~/.local/bin/kiro-oauth-handler %u
Type=Application
Terminal=false
MimeType=x-scheme-handler/kiro;
2. Create ~/.local/bin/kiro-oauth-handler (make it executable):
#!/bin/bash
URL="$1"
# ... (see generated script for full content)
3. Run: xdg-mime default kiro-oauth-handler.desktop x-scheme-handler/kiro`
case "windows":
return `To manually set up the Kiro protocol handler on Windows:
1. Open Registry Editor (regedit.exe)
2. Create key: HKEY_CURRENT_USER\Software\Classes\kiro
3. Set default value to: URL:Kiro Protocol
4. Create string value "URL Protocol" with empty data
5. Create subkey: shell\open\command
6. Set default value to: "C:\path\to\handler.bat" "%1"`
case "darwin":
return `To manually set up the Kiro protocol handler on macOS:
1. Create ~/Applications/KiroOAuthHandler.app bundle
2. Add Info.plist with CFBundleURLTypes containing "kiro" scheme
3. Create executable in Contents/MacOS/
4. Run: /System/Library/.../lsregister -f ~/Applications/KiroOAuthHandler.app`
default:
return "Protocol handler setup is not supported on this platform."
}
}
// SetupProtocolHandlerIfNeeded checks and installs the protocol handler if needed.
func SetupProtocolHandlerIfNeeded(handlerPort int) error {
if IsProtocolHandlerInstalled() {
log.Debug("Kiro protocol handler already installed")
return nil
}
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
fmt.Println("║ Kiro Protocol Handler Setup Required ║")
fmt.Println("╚══════════════════════════════════════════════════════════╝")
fmt.Println("\nTo enable Google/GitHub login, we need to install a protocol handler.")
fmt.Println("This allows your browser to redirect back to the CLI after authentication.")
fmt.Println("\nInstalling protocol handler...")
if err := InstallProtocolHandler(handlerPort); err != nil {
fmt.Printf("\n⚠ Automatic installation failed: %v\n", err)
fmt.Println("\nManual setup instructions:")
fmt.Println(strings.Repeat("-", 60))
fmt.Println(GetHandlerInstructions())
return err
}
fmt.Println("\n✓ Protocol handler installed successfully!")
return nil
}

View File

@@ -0,0 +1,403 @@
// Package kiro provides social authentication (Google/GitHub) for Kiro via AuthServiceClient.
package kiro
import (
"bufio"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"runtime"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"golang.org/x/term"
)
const (
// Kiro AuthService endpoint
kiroAuthServiceEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
// OAuth timeout
socialAuthTimeout = 10 * time.Minute
)
// SocialProvider represents the social login provider.
type SocialProvider string
const (
// ProviderGoogle is Google OAuth provider
ProviderGoogle SocialProvider = "Google"
// ProviderGitHub is GitHub OAuth provider
ProviderGitHub SocialProvider = "Github"
// Note: AWS Builder ID is NOT supported by Kiro's auth service.
// It only supports: Google, Github, Cognito
// AWS Builder ID must use device code flow via SSO OIDC.
)
// CreateTokenRequest is sent to Kiro's /oauth/token endpoint.
type CreateTokenRequest struct {
Code string `json:"code"`
CodeVerifier string `json:"code_verifier"`
RedirectURI string `json:"redirect_uri"`
InvitationCode string `json:"invitation_code,omitempty"`
}
// SocialTokenResponse from Kiro's /oauth/token endpoint for social auth.
type SocialTokenResponse struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ProfileArn string `json:"profileArn"`
ExpiresIn int `json:"expiresIn"`
}
// RefreshTokenRequest is sent to Kiro's /refreshToken endpoint.
type RefreshTokenRequest struct {
RefreshToken string `json:"refreshToken"`
}
// SocialAuthClient handles social authentication with Kiro.
type SocialAuthClient struct {
httpClient *http.Client
cfg *config.Config
protocolHandler *ProtocolHandler
}
// NewSocialAuthClient creates a new social auth client.
func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
client := &http.Client{Timeout: 30 * time.Second}
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
return &SocialAuthClient{
httpClient: client,
cfg: cfg,
protocolHandler: NewProtocolHandler(),
}
}
// generatePKCE generates PKCE code verifier and challenge.
func generatePKCE() (verifier, challenge string, err error) {
// Generate 32 bytes of random data for verifier
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
}
verifier = base64.RawURLEncoding.EncodeToString(b)
// Generate SHA256 hash of verifier for challenge
h := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(h[:])
return verifier, challenge, nil
}
// generateState generates a random state parameter.
func generateStateParam() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// buildLoginURL constructs the Kiro OAuth login URL.
// The login endpoint expects a GET request with query parameters.
// Format: /login?idp=Google&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...&prompt=select_account
// The prompt=select_account parameter forces the account selection screen even if already logged in.
func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, state string) string {
return fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
kiroAuthServiceEndpoint,
provider,
url.QueryEscape(redirectURI),
codeChallenge,
state,
)
}
// createToken exchanges the authorization code for tokens.
func (c *SocialAuthClient) createToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) {
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal token request: %w", err)
}
tokenURL := kiroAuthServiceEndpoint + "/oauth/token"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body)))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("token request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read token response: %w", err)
}
if resp.StatusCode != http.StatusOK {
log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode)
}
var tokenResp SocialTokenResponse
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse token response: %w", err)
}
return &tokenResp, nil
}
// RefreshSocialToken refreshes an expired social auth token.
func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
body, err := json.Marshal(&RefreshTokenRequest{RefreshToken: refreshToken})
if err != nil {
return nil, fmt.Errorf("failed to marshal refresh request: %w", err)
}
refreshURL := kiroAuthServiceEndpoint + "/refreshToken"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body)))
if err != nil {
return nil, fmt.Errorf("failed to create refresh request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("refresh request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read refresh response: %w", err)
}
if resp.StatusCode != http.StatusOK {
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
}
var tokenResp SocialTokenResponse
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse refresh response: %w", err)
}
// Validate ExpiresIn - use default 1 hour if invalid
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600 // Default 1 hour
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: "", // Caller should preserve original provider
}, nil
}
// LoginWithSocial performs OAuth login with Google.
func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) {
providerName := string(provider)
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName)
fmt.Println("╚══════════════════════════════════════════════════════════╝")
// Step 1: Setup protocol handler
fmt.Println("\nSetting up authentication...")
// Start the local callback server
handlerPort, err := c.protocolHandler.Start(ctx)
if err != nil {
return nil, fmt.Errorf("failed to start callback server: %w", err)
}
defer c.protocolHandler.Stop()
// Ensure protocol handler is installed and set as default
if err := SetupProtocolHandlerIfNeeded(handlerPort); err != nil {
fmt.Println("\n⚠ Protocol handler setup failed. Trying alternative method...")
fmt.Println(" If you see a browser 'Open with' dialog, select your default browser.")
fmt.Println(" For manual setup instructions, run: cliproxy kiro --help-protocol")
log.Debugf("kiro: protocol handler setup error: %v", err)
// Continue anyway - user might have set it up manually or select browser manually
} else {
// Force set our handler as default (prevents "Open with" dialog)
forceDefaultProtocolHandler()
}
// Step 2: Generate PKCE codes
codeVerifier, codeChallenge, err := generatePKCE()
if err != nil {
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
}
// Step 3: Generate state
state, err := generateStateParam()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
}
// Step 4: Build the login URL (Kiro uses GET request with query params)
authURL := c.buildLoginURL(providerName, KiroRedirectURI, codeChallenge, state)
// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
// Incognito mode enables multi-account support by bypassing cached sessions
if c.cfg != nil {
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
if !c.cfg.IncognitoBrowser {
log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
} else {
log.Debug("kiro: using incognito mode for multi-account support")
}
} else {
browser.SetIncognitoMode(true) // Default to incognito if no config
log.Debug("kiro: using incognito mode for multi-account support (default)")
}
// Step 5: Open browser for user authentication
fmt.Println("\n════════════════════════════════════════════════════════════")
fmt.Printf(" Opening browser for %s authentication...\n", providerName)
fmt.Println("════════════════════════════════════════════════════════════")
fmt.Printf("\n URL: %s\n\n", authURL)
if err := browser.OpenURL(authURL); err != nil {
log.Warnf("Could not open browser automatically: %v", err)
fmt.Println(" ⚠ Could not open browser automatically.")
fmt.Println(" Please open the URL above in your browser manually.")
} else {
fmt.Println(" (Browser opened automatically)")
}
fmt.Println("\n Waiting for authentication callback...")
// Step 6: Wait for callback
callback, err := c.protocolHandler.WaitForCallback(ctx)
if err != nil {
return nil, fmt.Errorf("failed to receive callback: %w", err)
}
if callback.Error != "" {
return nil, fmt.Errorf("authentication error: %s", callback.Error)
}
if callback.State != state {
// Log state values for debugging, but don't expose in user-facing error
log.Debugf("kiro: OAuth state mismatch - expected %s, got %s", state, callback.State)
return nil, fmt.Errorf("OAuth state validation failed - please try again")
}
if callback.Code == "" {
return nil, fmt.Errorf("no authorization code received")
}
fmt.Println("\n✓ Authorization received!")
// Step 7: Exchange code for tokens
fmt.Println("Exchanging code for tokens...")
tokenReq := &CreateTokenRequest{
Code: callback.Code,
CodeVerifier: codeVerifier,
RedirectURI: KiroRedirectURI,
}
tokenResp, err := c.createToken(ctx, tokenReq)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
}
fmt.Println("\n✓ Authentication successful!")
// Close the browser window
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser: %v", err)
}
// Validate ExpiresIn - use default 1 hour if invalid
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
// Try to extract email from JWT access token first
email := ExtractEmailFromJWT(tokenResp.AccessToken)
// If no email in JWT, ask user for account label (only in interactive mode)
if email == "" && isInteractiveTerminal() {
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
reader := bufio.NewReader(os.Stdin)
var err error
email, err = reader.ReadString('\n')
if err != nil {
log.Debugf("Failed to read account label: %v", err)
}
email = strings.TrimSpace(email)
}
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: providerName,
Email: email, // JWT email or user-provided label
}, nil
}
// LoginWithGoogle performs OAuth login with Google.
func (c *SocialAuthClient) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
return c.LoginWithSocial(ctx, ProviderGoogle)
}
// LoginWithGitHub performs OAuth login with GitHub.
func (c *SocialAuthClient) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) {
return c.LoginWithSocial(ctx, ProviderGitHub)
}
// forceDefaultProtocolHandler sets our protocol handler as the default for kiro:// URLs.
// This prevents the "Open with" dialog from appearing on Linux.
// On non-Linux platforms, this is a no-op as they use different mechanisms.
func forceDefaultProtocolHandler() {
if runtime.GOOS != "linux" {
return // Non-Linux platforms use different handler mechanisms
}
// Set our handler as default using xdg-mime
cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
if err := cmd.Run(); err != nil {
log.Warnf("Failed to set default protocol handler: %v. You may see a handler selection dialog.", err)
}
}
// isInteractiveTerminal checks if stdin is connected to an interactive terminal.
// Returns false in CI/automated environments or when stdin is piped.
func isInteractiveTerminal() bool {
return term.IsTerminal(int(os.Stdin.Fd()))
}

View File

@@ -0,0 +1,527 @@
// Package kiro provides AWS SSO OIDC authentication for Kiro.
package kiro
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
// AWS SSO OIDC endpoints
ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com"
// Kiro's start URL for Builder ID
builderIDStartURL = "https://view.awsapps.com/start"
// Polling interval
pollInterval = 5 * time.Second
)
// SSOOIDCClient handles AWS SSO OIDC authentication.
type SSOOIDCClient struct {
httpClient *http.Client
cfg *config.Config
}
// NewSSOOIDCClient creates a new SSO OIDC client.
func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient {
client := &http.Client{Timeout: 30 * time.Second}
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
return &SSOOIDCClient{
httpClient: client,
cfg: cfg,
}
}
// RegisterClientResponse from AWS SSO OIDC.
type RegisterClientResponse struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
ClientIDIssuedAt int64 `json:"clientIdIssuedAt"`
ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"`
}
// StartDeviceAuthResponse from AWS SSO OIDC.
type StartDeviceAuthResponse struct {
DeviceCode string `json:"deviceCode"`
UserCode string `json:"userCode"`
VerificationURI string `json:"verificationUri"`
VerificationURIComplete string `json:"verificationUriComplete"`
ExpiresIn int `json:"expiresIn"`
Interval int `json:"interval"`
}
// CreateTokenResponse from AWS SSO OIDC.
type CreateTokenResponse struct {
AccessToken string `json:"accessToken"`
TokenType string `json:"tokenType"`
ExpiresIn int `json:"expiresIn"`
RefreshToken string `json:"refreshToken"`
}
// RegisterClient registers a new OIDC client with AWS.
func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) {
// Generate unique client name for each registration to support multiple accounts
clientName := fmt.Sprintf("CLI-Proxy-API-%d", time.Now().UnixNano())
payload := map[string]interface{}{
"clientName": clientName,
"clientType": "public",
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations"},
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
}
var result RegisterClientResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
return &result, nil
}
// StartDeviceAuthorization starts the device authorization flow.
func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID, clientSecret string) (*StartDeviceAuthResponse, error) {
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"startUrl": builderIDStartURL,
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/device_authorization", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode)
}
var result StartDeviceAuthResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
return &result, nil
}
// CreateToken polls for the access token after user authorization.
func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, deviceCode string) (*CreateTokenResponse, error) {
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"deviceCode": deviceCode,
"grantType": "urn:ietf:params:oauth:grant-type:device_code",
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// Check for pending authorization
if resp.StatusCode == http.StatusBadRequest {
var errResp struct {
Error string `json:"error"`
}
if json.Unmarshal(respBody, &errResp) == nil {
if errResp.Error == "authorization_pending" {
return nil, fmt.Errorf("authorization_pending")
}
if errResp.Error == "slow_down" {
return nil, fmt.Errorf("slow_down")
}
}
log.Debugf("create token failed: %s", string(respBody))
return nil, fmt.Errorf("create token failed")
}
if resp.StatusCode != http.StatusOK {
log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
}
var result CreateTokenResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
return &result, nil
}
// RefreshToken refreshes an access token using the refresh token.
func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (*KiroTokenData, error) {
payload := map[string]string{
"clientId": clientID,
"clientSecret": clientSecret,
"refreshToken": refreshToken,
"grantType": "refresh_token",
}
body, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body)))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
}
var result CreateTokenResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, err
}
expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second)
return &KiroTokenData{
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "builder-id",
Provider: "AWS",
ClientID: clientID,
ClientSecret: clientSecret,
}, nil
}
// LoginWithBuilderID performs the full device code flow for AWS Builder ID.
func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) {
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
fmt.Println("║ Kiro Authentication (AWS Builder ID) ║")
fmt.Println("╚══════════════════════════════════════════════════════════╝")
// Step 1: Register client
fmt.Println("\nRegistering client...")
regResp, err := c.RegisterClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to register client: %w", err)
}
log.Debugf("Client registered: %s", regResp.ClientID)
// Step 2: Start device authorization
fmt.Println("Starting device authorization...")
authResp, err := c.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
if err != nil {
return nil, fmt.Errorf("failed to start device auth: %w", err)
}
// Step 3: Show user the verification URL
fmt.Printf("\n")
fmt.Println("════════════════════════════════════════════════════════════")
fmt.Printf(" Open this URL in your browser:\n")
fmt.Printf(" %s\n", authResp.VerificationURIComplete)
fmt.Println("════════════════════════════════════════════════════════════")
fmt.Printf("\n Or go to: %s\n", authResp.VerificationURI)
fmt.Printf(" And enter code: %s\n\n", authResp.UserCode)
// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
// Incognito mode enables multi-account support by bypassing cached sessions
if c.cfg != nil {
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
if !c.cfg.IncognitoBrowser {
log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
} else {
log.Debug("kiro: using incognito mode for multi-account support")
}
} else {
browser.SetIncognitoMode(true) // Default to incognito if no config
log.Debug("kiro: using incognito mode for multi-account support (default)")
}
// Open browser using cross-platform browser package
if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil {
log.Warnf("Could not open browser automatically: %v", err)
fmt.Println(" Please open the URL manually in your browser.")
} else {
fmt.Println(" (Browser opened automatically)")
}
// Step 4: Poll for token
fmt.Println("Waiting for authorization...")
interval := pollInterval
if authResp.Interval > 0 {
interval = time.Duration(authResp.Interval) * time.Second
}
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
for time.Now().Before(deadline) {
select {
case <-ctx.Done():
browser.CloseBrowser() // Cleanup on cancel
return nil, ctx.Err()
case <-time.After(interval):
tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
if err != nil {
errStr := err.Error()
if strings.Contains(errStr, "authorization_pending") {
fmt.Print(".")
continue
}
if strings.Contains(errStr, "slow_down") {
interval += 5 * time.Second
continue
}
// Close browser on error before returning
browser.CloseBrowser()
return nil, fmt.Errorf("token creation failed: %w", err)
}
fmt.Println("\n\n✓ Authorization successful!")
// Close the browser window
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser: %v", err)
}
// Step 5: Get profile ARN from CodeWhisperer API
fmt.Println("Fetching profile information...")
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
// Extract email from JWT access token
email := ExtractEmailFromJWT(tokenResp.AccessToken)
if email != "" {
fmt.Printf(" Logged in as: %s\n", email)
}
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: profileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "builder-id",
Provider: "AWS",
ClientID: regResp.ClientID,
ClientSecret: regResp.ClientSecret,
Email: email,
}, nil
}
}
// Close browser on timeout for better UX
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser on timeout: %v", err)
}
return nil, fmt.Errorf("authorization timed out")
}
// fetchProfileArn retrieves the profile ARN from CodeWhisperer API.
// This is needed for file naming since AWS SSO OIDC doesn't return profile info.
func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string {
// Try ListProfiles API first
profileArn := c.tryListProfiles(ctx, accessToken)
if profileArn != "" {
return profileArn
}
// Fallback: Try ListAvailableCustomizations
return c.tryListCustomizations(ctx, accessToken)
}
func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) string {
payload := map[string]interface{}{
"origin": "AI_EDITOR",
}
body, err := json.Marshal(payload)
if err != nil {
return ""
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
if err != nil {
return ""
}
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListProfiles")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
log.Debugf("ListProfiles failed (status %d): %s", resp.StatusCode, string(respBody))
return ""
}
log.Debugf("ListProfiles response: %s", string(respBody))
var result struct {
Profiles []struct {
Arn string `json:"arn"`
} `json:"profiles"`
ProfileArn string `json:"profileArn"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return ""
}
if result.ProfileArn != "" {
return result.ProfileArn
}
if len(result.Profiles) > 0 {
return result.Profiles[0].Arn
}
return ""
}
func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken string) string {
payload := map[string]interface{}{
"origin": "AI_EDITOR",
}
body, err := json.Marshal(payload)
if err != nil {
return ""
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
if err != nil {
return ""
}
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListAvailableCustomizations")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return ""
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
log.Debugf("ListAvailableCustomizations failed (status %d): %s", resp.StatusCode, string(respBody))
return ""
}
log.Debugf("ListAvailableCustomizations response: %s", string(respBody))
var result struct {
Customizations []struct {
Arn string `json:"arn"`
} `json:"customizations"`
ProfileArn string `json:"profileArn"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return ""
}
if result.ProfileArn != "" {
return result.ProfileArn
}
if len(result.Customizations) > 0 {
return result.Customizations[0].Arn
}
return ""
}

View File

@@ -0,0 +1,72 @@
package kiro
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
)
// KiroTokenStorage holds the persistent token data for Kiro authentication.
type KiroTokenStorage struct {
// AccessToken is the OAuth2 access token for API access
AccessToken string `json:"access_token"`
// RefreshToken is used to obtain new access tokens
RefreshToken string `json:"refresh_token"`
// ProfileArn is the AWS CodeWhisperer profile ARN
ProfileArn string `json:"profile_arn"`
// ExpiresAt is the timestamp when the token expires
ExpiresAt string `json:"expires_at"`
// AuthMethod indicates the authentication method used
AuthMethod string `json:"auth_method"`
// Provider indicates the OAuth provider
Provider string `json:"provider"`
// LastRefresh is the timestamp of the last token refresh
LastRefresh string `json:"last_refresh"`
}
// SaveTokenToFile persists the token storage to the specified file path.
func (s *KiroTokenStorage) SaveTokenToFile(authFilePath string) error {
dir := filepath.Dir(authFilePath)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
data, err := json.MarshalIndent(s, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal token storage: %w", err)
}
if err := os.WriteFile(authFilePath, data, 0600); err != nil {
return fmt.Errorf("failed to write token file: %w", err)
}
return nil
}
// LoadFromFile loads token storage from the specified file path.
func LoadFromFile(authFilePath string) (*KiroTokenStorage, error) {
data, err := os.ReadFile(authFilePath)
if err != nil {
return nil, fmt.Errorf("failed to read token file: %w", err)
}
var storage KiroTokenStorage
if err := json.Unmarshal(data, &storage); err != nil {
return nil, fmt.Errorf("failed to parse token file: %w", err)
}
return &storage, nil
}
// ToTokenData converts storage to KiroTokenData for API use.
func (s *KiroTokenStorage) ToTokenData() *KiroTokenData {
return &KiroTokenData{
AccessToken: s.AccessToken,
RefreshToken: s.RefreshToken,
ProfileArn: s.ProfileArn,
ExpiresAt: s.ExpiresAt,
AuthMethod: s.AuthMethod,
Provider: s.Provider,
}
}

View File

@@ -6,14 +6,49 @@ import (
"fmt"
"os/exec"
"runtime"
"strings"
"sync"
pkgbrowser "github.com/pkg/browser"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
)
// incognitoMode controls whether to open URLs in incognito/private mode.
// This is useful for OAuth flows where you want to use a different account.
var incognitoMode bool
// lastBrowserProcess stores the last opened browser process for cleanup
var lastBrowserProcess *exec.Cmd
var browserMutex sync.Mutex
// SetIncognitoMode enables or disables incognito/private browsing mode.
func SetIncognitoMode(enabled bool) {
incognitoMode = enabled
}
// IsIncognitoMode returns whether incognito mode is enabled.
func IsIncognitoMode() bool {
return incognitoMode
}
// CloseBrowser closes the last opened browser process.
func CloseBrowser() error {
browserMutex.Lock()
defer browserMutex.Unlock()
if lastBrowserProcess == nil || lastBrowserProcess.Process == nil {
return nil
}
err := lastBrowserProcess.Process.Kill()
lastBrowserProcess = nil
return err
}
// OpenURL opens the specified URL in the default web browser.
// It first attempts to use a platform-agnostic library and falls back to
// platform-specific commands if that fails.
// It uses the pkg/browser library which provides robust cross-platform support
// for Windows, macOS, and Linux.
// If incognito mode is enabled, it will open in a private/incognito window.
//
// Parameters:
// - url: The URL to open.
@@ -21,16 +56,22 @@ import (
// Returns:
// - An error if the URL cannot be opened, otherwise nil.
func OpenURL(url string) error {
fmt.Printf("Attempting to open URL in browser: %s\n", url)
log.Debugf("Opening URL in browser: %s (incognito=%v)", url, incognitoMode)
// Try using the open-golang library first
err := open.Run(url)
// If incognito mode is enabled, use platform-specific incognito commands
if incognitoMode {
log.Debug("Using incognito mode")
return openURLIncognito(url)
}
// Use pkg/browser for cross-platform support
err := pkgbrowser.OpenURL(url)
if err == nil {
log.Debug("Successfully opened URL using open-golang library")
log.Debug("Successfully opened URL using pkg/browser library")
return nil
}
log.Debugf("open-golang failed: %v, trying platform-specific commands", err)
log.Debugf("pkg/browser failed: %v, trying platform-specific commands", err)
// Fallback to platform-specific commands
return openURLPlatformSpecific(url)
@@ -78,18 +119,379 @@ func openURLPlatformSpecific(url string) error {
return nil
}
// openURLIncognito opens a URL in incognito/private browsing mode.
// It first tries to detect the default browser and use its incognito flag.
// Falls back to a chain of known browsers if detection fails.
//
// Parameters:
// - url: The URL to open.
//
// Returns:
// - An error if the URL cannot be opened, otherwise nil.
func openURLIncognito(url string) error {
// First, try to detect and use the default browser
if cmd := tryDefaultBrowserIncognito(url); cmd != nil {
log.Debugf("Using detected default browser: %s %v", cmd.Path, cmd.Args[1:])
if err := cmd.Start(); err == nil {
storeBrowserProcess(cmd)
log.Debug("Successfully opened URL in default browser's incognito mode")
return nil
}
log.Debugf("Failed to start default browser, trying fallback chain")
}
// Fallback to known browser chain
cmd := tryFallbackBrowsersIncognito(url)
if cmd == nil {
log.Warn("No browser with incognito support found, falling back to normal mode")
return openURLPlatformSpecific(url)
}
log.Debugf("Running incognito command: %s %v", cmd.Path, cmd.Args[1:])
err := cmd.Start()
if err != nil {
log.Warnf("Failed to open incognito browser: %v, falling back to normal mode", err)
return openURLPlatformSpecific(url)
}
storeBrowserProcess(cmd)
log.Debug("Successfully opened URL in incognito/private mode")
return nil
}
// storeBrowserProcess safely stores the browser process for later cleanup.
func storeBrowserProcess(cmd *exec.Cmd) {
browserMutex.Lock()
lastBrowserProcess = cmd
browserMutex.Unlock()
}
// tryDefaultBrowserIncognito attempts to detect the default browser and return
// an exec.Cmd configured with the appropriate incognito flag.
func tryDefaultBrowserIncognito(url string) *exec.Cmd {
switch runtime.GOOS {
case "darwin":
return tryDefaultBrowserMacOS(url)
case "windows":
return tryDefaultBrowserWindows(url)
case "linux":
return tryDefaultBrowserLinux(url)
}
return nil
}
// tryDefaultBrowserMacOS detects the default browser on macOS.
func tryDefaultBrowserMacOS(url string) *exec.Cmd {
// Try to get default browser from Launch Services
out, err := exec.Command("defaults", "read", "com.apple.LaunchServices/com.apple.launchservices.secure", "LSHandlers").Output()
if err != nil {
return nil
}
output := string(out)
var browserName string
// Parse the output to find the http/https handler
if containsBrowserID(output, "com.google.chrome") {
browserName = "chrome"
} else if containsBrowserID(output, "org.mozilla.firefox") {
browserName = "firefox"
} else if containsBrowserID(output, "com.apple.safari") {
browserName = "safari"
} else if containsBrowserID(output, "com.brave.browser") {
browserName = "brave"
} else if containsBrowserID(output, "com.microsoft.edgemac") {
browserName = "edge"
}
return createMacOSIncognitoCmd(browserName, url)
}
// containsBrowserID checks if the LaunchServices output contains a browser ID.
func containsBrowserID(output, bundleID string) bool {
return strings.Contains(output, bundleID)
}
// createMacOSIncognitoCmd creates the appropriate incognito command for macOS browsers.
func createMacOSIncognitoCmd(browserName, url string) *exec.Cmd {
switch browserName {
case "chrome":
// Try direct path first
chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
if _, err := exec.LookPath(chromePath); err == nil {
return exec.Command(chromePath, "--incognito", url)
}
return exec.Command("open", "-na", "Google Chrome", "--args", "--incognito", url)
case "firefox":
return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url)
case "safari":
// Safari doesn't have CLI incognito, try AppleScript
return tryAppleScriptSafariPrivate(url)
case "brave":
return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url)
case "edge":
return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url)
}
return nil
}
// tryAppleScriptSafariPrivate attempts to open Safari in private browsing mode using AppleScript.
func tryAppleScriptSafariPrivate(url string) *exec.Cmd {
// AppleScript to open a new private window in Safari
script := fmt.Sprintf(`
tell application "Safari"
activate
tell application "System Events"
keystroke "n" using {command down, shift down}
delay 0.5
end tell
set URL of document 1 to "%s"
end tell
`, url)
cmd := exec.Command("osascript", "-e", script)
// Test if this approach works by checking if Safari is available
if _, err := exec.LookPath("/Applications/Safari.app/Contents/MacOS/Safari"); err != nil {
log.Debug("Safari not found, AppleScript private window not available")
return nil
}
log.Debug("Attempting Safari private window via AppleScript")
return cmd
}
// tryDefaultBrowserWindows detects the default browser on Windows via registry.
func tryDefaultBrowserWindows(url string) *exec.Cmd {
// Query registry for default browser
out, err := exec.Command("reg", "query",
`HKEY_CURRENT_USER\Software\Microsoft\Windows\Shell\Associations\UrlAssociations\http\UserChoice`,
"/v", "ProgId").Output()
if err != nil {
return nil
}
output := string(out)
var browserName string
// Map ProgId to browser name
if strings.Contains(output, "ChromeHTML") {
browserName = "chrome"
} else if strings.Contains(output, "FirefoxURL") {
browserName = "firefox"
} else if strings.Contains(output, "MSEdgeHTM") {
browserName = "edge"
} else if strings.Contains(output, "BraveHTML") {
browserName = "brave"
}
return createWindowsIncognitoCmd(browserName, url)
}
// createWindowsIncognitoCmd creates the appropriate incognito command for Windows browsers.
func createWindowsIncognitoCmd(browserName, url string) *exec.Cmd {
switch browserName {
case "chrome":
paths := []string{
"chrome",
`C:\Program Files\Google\Chrome\Application\chrome.exe`,
`C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`,
}
for _, p := range paths {
if _, err := exec.LookPath(p); err == nil {
return exec.Command(p, "--incognito", url)
}
}
case "firefox":
if path, err := exec.LookPath("firefox"); err == nil {
return exec.Command(path, "--private-window", url)
}
case "edge":
paths := []string{
"msedge",
`C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`,
`C:\Program Files\Microsoft\Edge\Application\msedge.exe`,
}
for _, p := range paths {
if _, err := exec.LookPath(p); err == nil {
return exec.Command(p, "--inprivate", url)
}
}
case "brave":
paths := []string{
`C:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe`,
`C:\Program Files (x86)\BraveSoftware\Brave-Browser\Application\brave.exe`,
}
for _, p := range paths {
if _, err := exec.LookPath(p); err == nil {
return exec.Command(p, "--incognito", url)
}
}
}
return nil
}
// tryDefaultBrowserLinux detects the default browser on Linux using xdg-settings.
func tryDefaultBrowserLinux(url string) *exec.Cmd {
out, err := exec.Command("xdg-settings", "get", "default-web-browser").Output()
if err != nil {
return nil
}
desktop := string(out)
var browserName string
// Map .desktop file to browser name
if strings.Contains(desktop, "google-chrome") || strings.Contains(desktop, "chrome") {
browserName = "chrome"
} else if strings.Contains(desktop, "firefox") {
browserName = "firefox"
} else if strings.Contains(desktop, "chromium") {
browserName = "chromium"
} else if strings.Contains(desktop, "brave") {
browserName = "brave"
} else if strings.Contains(desktop, "microsoft-edge") || strings.Contains(desktop, "msedge") {
browserName = "edge"
}
return createLinuxIncognitoCmd(browserName, url)
}
// createLinuxIncognitoCmd creates the appropriate incognito command for Linux browsers.
func createLinuxIncognitoCmd(browserName, url string) *exec.Cmd {
switch browserName {
case "chrome":
paths := []string{"google-chrome", "google-chrome-stable"}
for _, p := range paths {
if path, err := exec.LookPath(p); err == nil {
return exec.Command(path, "--incognito", url)
}
}
case "firefox":
paths := []string{"firefox", "firefox-esr"}
for _, p := range paths {
if path, err := exec.LookPath(p); err == nil {
return exec.Command(path, "--private-window", url)
}
}
case "chromium":
paths := []string{"chromium", "chromium-browser"}
for _, p := range paths {
if path, err := exec.LookPath(p); err == nil {
return exec.Command(path, "--incognito", url)
}
}
case "brave":
if path, err := exec.LookPath("brave-browser"); err == nil {
return exec.Command(path, "--incognito", url)
}
case "edge":
if path, err := exec.LookPath("microsoft-edge"); err == nil {
return exec.Command(path, "--inprivate", url)
}
}
return nil
}
// tryFallbackBrowsersIncognito tries a chain of known browsers as fallback.
func tryFallbackBrowsersIncognito(url string) *exec.Cmd {
switch runtime.GOOS {
case "darwin":
return tryFallbackBrowsersMacOS(url)
case "windows":
return tryFallbackBrowsersWindows(url)
case "linux":
return tryFallbackBrowsersLinuxChain(url)
}
return nil
}
// tryFallbackBrowsersMacOS tries known browsers on macOS.
func tryFallbackBrowsersMacOS(url string) *exec.Cmd {
// Try Chrome
chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
if _, err := exec.LookPath(chromePath); err == nil {
return exec.Command(chromePath, "--incognito", url)
}
// Try Firefox
if _, err := exec.LookPath("/Applications/Firefox.app/Contents/MacOS/firefox"); err == nil {
return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url)
}
// Try Brave
if _, err := exec.LookPath("/Applications/Brave Browser.app/Contents/MacOS/Brave Browser"); err == nil {
return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url)
}
// Try Edge
if _, err := exec.LookPath("/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge"); err == nil {
return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url)
}
// Last resort: try Safari with AppleScript
if cmd := tryAppleScriptSafariPrivate(url); cmd != nil {
log.Info("Using Safari with AppleScript for private browsing (may require accessibility permissions)")
return cmd
}
return nil
}
// tryFallbackBrowsersWindows tries known browsers on Windows.
func tryFallbackBrowsersWindows(url string) *exec.Cmd {
// Chrome
chromePaths := []string{
"chrome",
`C:\Program Files\Google\Chrome\Application\chrome.exe`,
`C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`,
}
for _, p := range chromePaths {
if _, err := exec.LookPath(p); err == nil {
return exec.Command(p, "--incognito", url)
}
}
// Firefox
if path, err := exec.LookPath("firefox"); err == nil {
return exec.Command(path, "--private-window", url)
}
// Edge (usually available on Windows 10+)
edgePaths := []string{
"msedge",
`C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`,
`C:\Program Files\Microsoft\Edge\Application\msedge.exe`,
}
for _, p := range edgePaths {
if _, err := exec.LookPath(p); err == nil {
return exec.Command(p, "--inprivate", url)
}
}
return nil
}
// tryFallbackBrowsersLinuxChain tries known browsers on Linux.
func tryFallbackBrowsersLinuxChain(url string) *exec.Cmd {
type browserConfig struct {
name string
flag string
}
browsers := []browserConfig{
{"google-chrome", "--incognito"},
{"google-chrome-stable", "--incognito"},
{"chromium", "--incognito"},
{"chromium-browser", "--incognito"},
{"firefox", "--private-window"},
{"firefox-esr", "--private-window"},
{"brave-browser", "--incognito"},
{"microsoft-edge", "--inprivate"},
}
for _, b := range browsers {
if path, err := exec.LookPath(b.name); err == nil {
return exec.Command(path, b.flag, url)
}
}
return nil
}
// IsAvailable checks if the system has a command available to open a web browser.
// It verifies the presence of necessary commands for the current operating system.
//
// Returns:
// - true if a browser can be opened, false otherwise.
func IsAvailable() bool {
// First check if open-golang can work
testErr := open.Run("about:blank")
if testErr == nil {
return true
}
// Check platform-specific commands
switch runtime.GOOS {
case "darwin":

View File

@@ -19,6 +19,7 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewQwenAuthenticator(),
sdkAuth.NewIFlowAuthenticator(),
sdkAuth.NewAntigravityAuthenticator(),
sdkAuth.NewKiroAuthenticator(),
sdkAuth.NewGitHubCopilotAuthenticator(),
)
return manager

160
internal/cmd/kiro_login.go Normal file
View File

@@ -0,0 +1,160 @@
package cmd
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoKiroLogin triggers the Kiro authentication flow with Google OAuth.
// This is the default login method (same as --kiro-google-login).
//
// Parameters:
// - cfg: The application configuration
// - options: Login options including Prompt field
func DoKiroLogin(cfg *config.Config, options *LoginOptions) {
// Use Google login as default
DoKiroGoogleLogin(cfg, options)
}
// DoKiroGoogleLogin triggers Kiro authentication with Google OAuth.
// This uses a custom protocol handler (kiro://) to receive the callback.
//
// Parameters:
// - cfg: The application configuration
// - options: Login options including prompts
func DoKiroGoogleLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
// Note: Kiro defaults to incognito mode for multi-account support.
// Users can override with --no-incognito if they want to use existing browser sessions.
manager := newAuthManager()
// Use KiroAuthenticator with Google login
authenticator := sdkAuth.NewKiroAuthenticator()
record, err := authenticator.LoginWithGoogle(context.Background(), cfg, &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: options.Prompt,
})
if err != nil {
log.Errorf("Kiro Google authentication failed: %v", err)
fmt.Println("\nTroubleshooting:")
fmt.Println("1. Make sure the protocol handler is installed")
fmt.Println("2. Complete the Google login in the browser")
fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)")
return
}
// Save the auth record
savedPath, err := manager.SaveAuth(record, cfg)
if err != nil {
log.Errorf("Failed to save auth: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("Kiro Google authentication successful!")
}
// DoKiroAWSLogin triggers Kiro authentication with AWS Builder ID.
// This uses the device code flow for AWS SSO OIDC authentication.
//
// Parameters:
// - cfg: The application configuration
// - options: Login options including prompts
func DoKiroAWSLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
// Note: Kiro defaults to incognito mode for multi-account support.
// Users can override with --no-incognito if they want to use existing browser sessions.
manager := newAuthManager()
// Use KiroAuthenticator with AWS Builder ID login (device code flow)
authenticator := sdkAuth.NewKiroAuthenticator()
record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: options.Prompt,
})
if err != nil {
log.Errorf("Kiro AWS authentication failed: %v", err)
fmt.Println("\nTroubleshooting:")
fmt.Println("1. Make sure you have an AWS Builder ID")
fmt.Println("2. Complete the authorization in the browser")
fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)")
return
}
// Save the auth record
savedPath, err := manager.SaveAuth(record, cfg)
if err != nil {
log.Errorf("Failed to save auth: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("Kiro AWS authentication successful!")
}
// DoKiroImport imports Kiro token from Kiro IDE's token file.
// This is useful for users who have already logged in via Kiro IDE
// and want to use the same credentials in CLI Proxy API.
//
// Parameters:
// - cfg: The application configuration
// - options: Login options (currently unused for import)
func DoKiroImport(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
manager := newAuthManager()
// Use ImportFromKiroIDE instead of Login
authenticator := sdkAuth.NewKiroAuthenticator()
record, err := authenticator.ImportFromKiroIDE(context.Background(), cfg)
if err != nil {
log.Errorf("Kiro token import failed: %v", err)
fmt.Println("\nMake sure you have logged in to Kiro IDE first:")
fmt.Println("1. Open Kiro IDE")
fmt.Println("2. Click 'Sign in with Google' (or GitHub)")
fmt.Println("3. Complete the login process")
fmt.Println("4. Run this command again")
return
}
// Save the imported auth record
savedPath, err := manager.SaveAuth(record, cfg)
if err != nil {
log.Errorf("Failed to save auth: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Imported as %s\n", record.Label)
}
fmt.Println("Kiro token import successful!")
}

View File

@@ -65,20 +65,20 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
authenticator := sdkAuth.NewGeminiAuthenticator()
record, errLogin := authenticator.Login(ctx, cfg, loginOpts)
if errLogin != nil {
log.Fatalf("Gemini authentication failed: %v", errLogin)
log.Errorf("Gemini authentication failed: %v", errLogin)
return
}
storage, okStorage := record.Storage.(*gemini.GeminiTokenStorage)
if !okStorage || storage == nil {
log.Fatal("Gemini authentication failed: unsupported token storage")
log.Error("Gemini authentication failed: unsupported token storage")
return
}
geminiAuth := gemini.NewGeminiAuth()
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser)
if errClient != nil {
log.Fatalf("Gemini authentication failed: %v", errClient)
log.Errorf("Gemini authentication failed: %v", errClient)
return
}
@@ -86,7 +86,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
projects, errProjects := fetchGCPProjects(ctx, httpClient)
if errProjects != nil {
log.Fatalf("Failed to get project list: %v", errProjects)
log.Errorf("Failed to get project list: %v", errProjects)
return
}
@@ -98,11 +98,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn)
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
if errSelection != nil {
log.Fatalf("Invalid project selection: %v", errSelection)
log.Errorf("Invalid project selection: %v", errSelection)
return
}
if len(projectSelections) == 0 {
log.Fatal("No project selected; aborting login.")
log.Error("No project selected; aborting login.")
return
}
@@ -116,7 +116,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
showProjectSelectionHelp(storage.Email, projects)
return
}
log.Fatalf("Failed to complete user setup: %v", errSetup)
log.Errorf("Failed to complete user setup: %v", errSetup)
return
}
finalID := strings.TrimSpace(storage.ProjectID)
@@ -133,11 +133,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
for _, pid := range activatedProjects {
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, pid)
if errCheck != nil {
log.Fatalf("Failed to check if Cloud AI API is enabled for %s: %v", pid, errCheck)
log.Errorf("Failed to check if Cloud AI API is enabled for %s: %v", pid, errCheck)
return
}
if !isChecked {
log.Fatalf("Failed to check if Cloud AI API is enabled for project %s. If you encounter an error message, please create an issue.", pid)
log.Errorf("Failed to check if Cloud AI API is enabled for project %s. If you encounter an error message, please create an issue.", pid)
return
}
}
@@ -153,7 +153,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
savedPath, errSave := store.Save(ctx, record)
if errSave != nil {
log.Fatalf("Failed to save token to file: %v", errSave)
log.Errorf("Failed to save token to file: %v", errSave)
return
}
@@ -555,6 +555,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
continue
}
}
_ = resp.Body.Close()
return false, fmt.Errorf("project activation required: %s", errMessage)
}
return true, nil

View File

@@ -45,12 +45,13 @@ func StartService(cfg *config.Config, configPath string, localPassword string) {
service, err := builder.Build()
if err != nil {
log.Fatalf("failed to build proxy service: %v", err)
log.Errorf("failed to build proxy service: %v", err)
return
}
err = service.Run(runCtx)
if err != nil && !errors.Is(err, context.Canceled) {
log.Fatalf("proxy service exited with error: %v", err)
log.Errorf("proxy service exited with error: %v", err)
}
}

View File

@@ -29,30 +29,30 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
}
rawPath := strings.TrimSpace(keyPath)
if rawPath == "" {
log.Fatalf("vertex-import: missing service account key path")
log.Errorf("vertex-import: missing service account key path")
return
}
data, errRead := os.ReadFile(rawPath)
if errRead != nil {
log.Fatalf("vertex-import: read file failed: %v", errRead)
log.Errorf("vertex-import: read file failed: %v", errRead)
return
}
var sa map[string]any
if errUnmarshal := json.Unmarshal(data, &sa); errUnmarshal != nil {
log.Fatalf("vertex-import: invalid service account json: %v", errUnmarshal)
log.Errorf("vertex-import: invalid service account json: %v", errUnmarshal)
return
}
// Validate and normalize private_key before saving
normalizedSA, errFix := vertex.NormalizeServiceAccountMap(sa)
if errFix != nil {
log.Fatalf("vertex-import: %v", errFix)
log.Errorf("vertex-import: %v", errFix)
return
}
sa = normalizedSA
email, _ := sa["client_email"].(string)
projectID, _ := sa["project_id"].(string)
if strings.TrimSpace(projectID) == "" {
log.Fatalf("vertex-import: project_id missing in service account json")
log.Errorf("vertex-import: project_id missing in service account json")
return
}
if strings.TrimSpace(email) == "" {
@@ -92,7 +92,7 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
}
path, errSave := store.Save(context.Background(), record)
if errSave != nil {
log.Fatalf("vertex-import: save credential failed: %v", errSave)
log.Errorf("vertex-import: save credential failed: %v", errSave)
return
}
fmt.Printf("Vertex credentials imported: %s\n", path)

View File

@@ -20,6 +20,9 @@ import (
// Config represents the application's configuration, loaded from a YAML file.
type Config struct {
config.SDKConfig `yaml:",inline"`
// Host is the network host/interface on which the API server will bind.
// Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access.
Host string `yaml:"host" json:"-"`
// Port is the network port on which the API server will listen.
Port int `yaml:"port" json:"-"`
@@ -58,6 +61,9 @@ type Config struct {
// GeminiKey defines Gemini API key configurations with optional routing overrides.
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
// KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations.
KiroKey []KiroKey `yaml:"kiro" json:"kiro"`
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
@@ -80,6 +86,11 @@ type Config struct {
// Payload defines default and override rules for provider payload parameters.
Payload PayloadConfig `yaml:"payload" json:"payload"`
// IncognitoBrowser enables opening OAuth URLs in incognito/private browsing mode.
// This is useful when you want to login with a different account without logging out
// from your current session. Default: false.
IncognitoBrowser bool `yaml:"incognito-browser" json:"incognito-browser"`
legacyMigrationPending bool `yaml:"-" json:"-"`
}
@@ -143,6 +154,10 @@ type AmpCode struct {
// When Amp requests a model that isn't available locally, these mappings
// allow routing to an alternative model that IS available.
ModelMappings []AmpModelMapping `yaml:"model-mappings" json:"model-mappings"`
// ForceModelMappings when true, model mappings take precedence over local API keys.
// When false (default), local API keys are used first if available.
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
}
// PayloadConfig defines default and override parameter rules applied to provider payloads.
@@ -240,6 +255,31 @@ type GeminiKey struct {
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}
// KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication.
type KiroKey struct {
// TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json)
TokenFile string `yaml:"token-file,omitempty" json:"token-file,omitempty"`
// AccessToken is the OAuth access token for direct configuration.
AccessToken string `yaml:"access-token,omitempty" json:"access-token,omitempty"`
// RefreshToken is the OAuth refresh token for token renewal.
RefreshToken string `yaml:"refresh-token,omitempty" json:"refresh-token,omitempty"`
// ProfileArn is the AWS CodeWhisperer profile ARN.
ProfileArn string `yaml:"profile-arn,omitempty" json:"profile-arn,omitempty"`
// Region is the AWS region (default: us-east-1).
Region string `yaml:"region,omitempty" json:"region,omitempty"`
// ProxyURL optionally overrides the global proxy for this configuration.
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
// AgentTaskType sets the Kiro API task type. Known values: "vibe", "dev", "chat".
// Leave empty to let API use defaults. Different values may inject different system prompts.
AgentTaskType string `yaml:"agent-task-type,omitempty" json:"agent-task-type,omitempty"`
}
// OpenAICompatibility represents the configuration for OpenAI API compatibility
// with external providers, allowing model aliases to be routed through OpenAI API format.
type OpenAICompatibility struct {
@@ -316,10 +356,12 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Unmarshal the YAML data into the Config struct.
var cfg Config
// Set defaults before unmarshal so that absent keys keep defaults.
cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
cfg.LoggingToFile = false
cfg.UsageStatisticsEnabled = false
cfg.DisableCooling = false
cfg.AmpCode.RestrictManagementToLocalhost = true // Default to secure: only localhost access
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
if err = yaml.Unmarshal(data, &cfg); err != nil {
if optional {
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
@@ -370,6 +412,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Sanitize Claude key headers
cfg.SanitizeClaudeKeys()
// Sanitize Kiro keys: trim whitespace from credential fields
cfg.SanitizeKiroKeys()
// Sanitize OpenAI compatibility providers: drop entries without base-url
cfg.SanitizeOpenAICompatibility()
@@ -446,6 +491,22 @@ func (cfg *Config) SanitizeClaudeKeys() {
}
}
// SanitizeKiroKeys trims whitespace from Kiro credential fields.
func (cfg *Config) SanitizeKiroKeys() {
if cfg == nil || len(cfg.KiroKey) == 0 {
return
}
for i := range cfg.KiroKey {
entry := &cfg.KiroKey[i]
entry.TokenFile = strings.TrimSpace(entry.TokenFile)
entry.AccessToken = strings.TrimSpace(entry.AccessToken)
entry.RefreshToken = strings.TrimSpace(entry.RefreshToken)
entry.ProfileArn = strings.TrimSpace(entry.ProfileArn)
entry.Region = strings.TrimSpace(entry.Region)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
}
}
// SanitizeGeminiKeys deduplicates and normalizes Gemini credentials.
func (cfg *Config) SanitizeGeminiKeys() {
if cfg == nil {

View File

@@ -24,4 +24,7 @@ const (
// Antigravity represents the Antigravity response format identifier.
Antigravity = "antigravity"
// Kiro represents the AWS CodeWhisperer (Kiro) provider identifier.
Kiro = "kiro"
)

View File

@@ -56,6 +56,8 @@ type Content struct {
// Part represents a distinct piece of content within a message.
// A part can be text, inline data (like an image), a function call, or a function response.
type Part struct {
Thought bool `json:"thought,omitempty"`
// Text contains plain text content.
Text string `json:"text,omitempty"`

View File

@@ -38,13 +38,16 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
timestamp := entry.Time.Format("2006-01-02 15:04:05")
message := strings.TrimRight(entry.Message, "\r\n")
var formatted string
// Handle nil Caller (can happen with some log entries)
callerFile := "unknown"
callerLine := 0
if entry.Caller != nil {
formatted = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
} else {
formatted = fmt.Sprintf("[%s] [%s] %s\n", timestamp, entry.Level, message)
callerFile = filepath.Base(entry.Caller.File)
callerLine = entry.Caller.Line
}
formatted := fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, callerFile, callerLine, message)
buffer.WriteString(formatted)
return buffer.Bytes(), nil
@@ -55,6 +58,7 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
func SetupBaseLogger() {
setupOnce.Do(func() {
log.SetOutput(os.Stdout)
log.SetLevel(log.InfoLevel)
log.SetReportCaller(true)
log.SetFormatter(&LogFormatter{})

View File

@@ -20,6 +20,7 @@ import (
"github.com/klauspost/compress/zstd"
log "github.com/sirupsen/logrus"
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
@@ -83,6 +84,26 @@ type StreamingLogWriter interface {
// - error: An error if writing fails, nil otherwise
WriteStatus(status int, headers map[string][]string) error
// WriteAPIRequest writes the upstream API request details to the log.
// This should be called before WriteStatus to maintain proper log ordering.
//
// Parameters:
// - apiRequest: The API request data (typically includes URL, headers, body sent upstream)
//
// Returns:
// - error: An error if writing fails, nil otherwise
WriteAPIRequest(apiRequest []byte) error
// WriteAPIResponse writes the upstream API response details to the log.
// This should be called after the streaming response is complete.
//
// Parameters:
// - apiResponse: The API response data
//
// Returns:
// - error: An error if writing fails, nil otherwise
WriteAPIResponse(apiResponse []byte) error
// Close finalizes the log file and cleans up resources.
//
// Returns:
@@ -247,10 +268,11 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
// Create streaming writer
writer := &FileStreamingLogWriter{
file: file,
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
closeChan: make(chan struct{}),
errorChan: make(chan error, 1),
file: file,
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
closeChan: make(chan struct{}),
errorChan: make(chan error, 1),
bufferedChunks: &bytes.Buffer{},
}
// Start async writer goroutine
@@ -603,6 +625,7 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
var content strings.Builder
content.WriteString("=== REQUEST INFO ===\n")
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
content.WriteString(fmt.Sprintf("URL: %s\n", url))
content.WriteString(fmt.Sprintf("Method: %s\n", method))
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
@@ -626,11 +649,12 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
// FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs.
// It handles asynchronous writing of streaming response chunks to a file.
// All data is buffered and written in the correct order when Close is called.
type FileStreamingLogWriter struct {
// file is the file where log data is written.
file *os.File
// chunkChan is a channel for receiving response chunks to write.
// chunkChan is a channel for receiving response chunks to buffer.
chunkChan chan []byte
// closeChan is a channel for signaling when the writer is closed.
@@ -639,8 +663,23 @@ type FileStreamingLogWriter struct {
// errorChan is a channel for reporting errors during writing.
errorChan chan error
// statusWritten indicates whether the response status has been written.
// bufferedChunks stores the response chunks in order.
bufferedChunks *bytes.Buffer
// responseStatus stores the HTTP status code.
responseStatus int
// statusWritten indicates whether a non-zero status was recorded.
statusWritten bool
// responseHeaders stores the response headers.
responseHeaders map[string][]string
// apiRequest stores the upstream API request data.
apiRequest []byte
// apiResponse stores the upstream API response data.
apiResponse []byte
}
// WriteChunkAsync writes a response chunk asynchronously (non-blocking).
@@ -664,39 +703,65 @@ func (w *FileStreamingLogWriter) WriteChunkAsync(chunk []byte) {
}
}
// WriteStatus writes the response status and headers to the log.
// WriteStatus buffers the response status and headers for later writing.
//
// Parameters:
// - status: The response status code
// - headers: The response headers
//
// Returns:
// - error: An error if writing fails, nil otherwise
// - error: Always returns nil (buffering cannot fail)
func (w *FileStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error {
if w.file == nil || w.statusWritten {
if status == 0 {
return nil
}
var content strings.Builder
content.WriteString("========================================\n")
content.WriteString("=== RESPONSE ===\n")
content.WriteString(fmt.Sprintf("Status: %d\n", status))
for key, values := range headers {
for _, value := range values {
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
w.responseStatus = status
if headers != nil {
w.responseHeaders = make(map[string][]string, len(headers))
for key, values := range headers {
headerValues := make([]string, len(values))
copy(headerValues, values)
w.responseHeaders[key] = headerValues
}
}
content.WriteString("\n")
w.statusWritten = true
return nil
}
_, err := w.file.WriteString(content.String())
if err == nil {
w.statusWritten = true
// WriteAPIRequest buffers the upstream API request details for later writing.
//
// Parameters:
// - apiRequest: The API request data (typically includes URL, headers, body sent upstream)
//
// Returns:
// - error: Always returns nil (buffering cannot fail)
func (w *FileStreamingLogWriter) WriteAPIRequest(apiRequest []byte) error {
if len(apiRequest) == 0 {
return nil
}
return err
w.apiRequest = bytes.Clone(apiRequest)
return nil
}
// WriteAPIResponse buffers the upstream API response details for later writing.
//
// Parameters:
// - apiResponse: The API response data
//
// Returns:
// - error: Always returns nil (buffering cannot fail)
func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
if len(apiResponse) == 0 {
return nil
}
w.apiResponse = bytes.Clone(apiResponse)
return nil
}
// Close finalizes the log file and cleans up resources.
// It writes all buffered data to the file in the correct order:
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
//
// Returns:
// - error: An error if closing fails, nil otherwise
@@ -705,27 +770,84 @@ func (w *FileStreamingLogWriter) Close() error {
close(w.chunkChan)
}
// Wait for async writer to finish
// Wait for async writer to finish buffering chunks
if w.closeChan != nil {
<-w.closeChan
w.chunkChan = nil
}
if w.file != nil {
return w.file.Close()
if w.file == nil {
return nil
}
return nil
// Write all content in the correct order
var content strings.Builder
// 1. Write API REQUEST section
if len(w.apiRequest) > 0 {
if bytes.HasPrefix(w.apiRequest, []byte("=== API REQUEST")) {
content.Write(w.apiRequest)
if !bytes.HasSuffix(w.apiRequest, []byte("\n")) {
content.WriteString("\n")
}
} else {
content.WriteString("=== API REQUEST ===\n")
content.Write(w.apiRequest)
content.WriteString("\n")
}
content.WriteString("\n")
}
// 2. Write API RESPONSE section
if len(w.apiResponse) > 0 {
if bytes.HasPrefix(w.apiResponse, []byte("=== API RESPONSE")) {
content.Write(w.apiResponse)
if !bytes.HasSuffix(w.apiResponse, []byte("\n")) {
content.WriteString("\n")
}
} else {
content.WriteString("=== API RESPONSE ===\n")
content.Write(w.apiResponse)
content.WriteString("\n")
}
content.WriteString("\n")
}
// 3. Write RESPONSE section (status, headers, buffered chunks)
content.WriteString("=== RESPONSE ===\n")
if w.statusWritten {
content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus))
}
for key, values := range w.responseHeaders {
for _, value := range values {
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
}
}
content.WriteString("\n")
// Write buffered response body chunks
if w.bufferedChunks != nil && w.bufferedChunks.Len() > 0 {
content.Write(w.bufferedChunks.Bytes())
}
// Write the complete content to file
if _, err := w.file.WriteString(content.String()); err != nil {
_ = w.file.Close()
return err
}
return w.file.Close()
}
// asyncWriter runs in a goroutine to handle async chunk writing.
// It continuously reads chunks from the channel and writes them to the file.
// asyncWriter runs in a goroutine to buffer chunks from the channel.
// It continuously reads chunks from the channel and buffers them for later writing.
func (w *FileStreamingLogWriter) asyncWriter() {
defer close(w.closeChan)
for chunk := range w.chunkChan {
if w.file != nil {
_, _ = w.file.Write(chunk)
if w.bufferedChunks != nil {
w.bufferedChunks.Write(chunk)
}
}
}
@@ -752,6 +874,28 @@ func (w *NoOpStreamingLogWriter) WriteStatus(_ int, _ map[string][]string) error
return nil
}
// WriteAPIRequest is a no-op implementation that does nothing and always returns nil.
//
// Parameters:
// - apiRequest: The API request data (ignored)
//
// Returns:
// - error: Always returns nil
func (w *NoOpStreamingLogWriter) WriteAPIRequest(_ []byte) error {
return nil
}
// WriteAPIResponse is a no-op implementation that does nothing and always returns nil.
//
// Parameters:
// - apiResponse: The API response data (ignored)
//
// Returns:
// - error: Always returns nil
func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
return nil
}
// Close is a no-op implementation that does nothing and always returns nil.
//
// Returns:

View File

@@ -693,8 +693,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-2025-11-12",
DisplayName: "GPT 5 Low",
Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
DisplayName: "GPT 5.1 Nothink",
Description: "Stable version of GPT 5.1, The best model for coding and agentic tasks across domains.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -719,8 +719,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-2025-11-12",
DisplayName: "GPT 5 Medium",
Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
DisplayName: "GPT 5.1 Medium",
Description: "Stable version of GPT 5.1, The best model for coding and agentic tasks across domains.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -732,8 +732,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-2025-11-12",
DisplayName: "GPT 5 High",
Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
DisplayName: "GPT 5.1 High",
Description: "Stable version of GPT 5.1, The best model for coding and agentic tasks across domains.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -745,8 +745,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-2025-11-12",
DisplayName: "GPT 5 Codex",
Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
DisplayName: "GPT 5.1 Codex",
Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -758,8 +758,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-2025-11-12",
DisplayName: "GPT 5 Codex Low",
Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
DisplayName: "GPT 5.1 Codex Low",
Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -771,8 +771,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-2025-11-12",
DisplayName: "GPT 5 Codex Medium",
Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
DisplayName: "GPT 5.1 Codex Medium",
Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -784,8 +784,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-2025-11-12",
DisplayName: "GPT 5 Codex High",
Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
DisplayName: "GPT 5.1 Codex High",
Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -797,8 +797,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-2025-11-12",
DisplayName: "GPT 5 Codex Mini",
Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
DisplayName: "GPT 5.1 Codex Mini",
Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -810,8 +810,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-2025-11-12",
DisplayName: "GPT 5 Codex Mini Medium",
Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
DisplayName: "GPT 5.1 Codex Mini Medium",
Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -823,8 +823,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-2025-11-12",
DisplayName: "GPT 5 Codex Mini High",
Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
DisplayName: "GPT 5.1 Codex Mini High",
Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -837,8 +837,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-max",
DisplayName: "GPT 5 Codex Max",
Description: "Stable version of GPT 5 Codex Max",
DisplayName: "GPT 5.1 Codex Max",
Description: "Stable version of GPT 5.1 Codex Max",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -850,8 +850,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-max",
DisplayName: "GPT 5 Codex Max Low",
Description: "Stable version of GPT 5 Codex Max Low",
DisplayName: "GPT 5.1 Codex Max Low",
Description: "Stable version of GPT 5.1 Codex Max Low",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -863,8 +863,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-max",
DisplayName: "GPT 5 Codex Max Medium",
Description: "Stable version of GPT 5 Codex Max Medium",
DisplayName: "GPT 5.1 Codex Max Medium",
Description: "Stable version of GPT 5.1 Codex Max Medium",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -876,8 +876,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-max",
DisplayName: "GPT 5 Codex Max High",
Description: "Stable version of GPT 5 Codex Max High",
DisplayName: "GPT 5.1 Codex Max High",
Description: "Stable version of GPT 5.1 Codex Max High",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -889,8 +889,8 @@ func GetOpenAIModels() []*ModelInfo {
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.1-max",
DisplayName: "GPT 5 Codex Max XHigh",
Description: "Stable version of GPT 5 Codex Max XHigh",
DisplayName: "GPT 5.1 Codex Max XHigh",
Description: "Stable version of GPT 5.1 Codex Max XHigh",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
@@ -944,7 +944,6 @@ func GetQwenModels() []*ModelInfo {
}
// GetIFlowModels returns supported models for iFlow OAuth accounts.
func GetIFlowModels() []*ModelInfo {
entries := []struct {
ID string
@@ -987,6 +986,28 @@ func GetIFlowModels() []*ModelInfo {
return models
}
// AntigravityModelConfig captures static antigravity model overrides, including
// Thinking budget limits and provider max completion tokens.
type AntigravityModelConfig struct {
Thinking *ThinkingSupport
MaxCompletionTokens int
Name string
}
// GetAntigravityModelConfig returns static configuration for antigravity models.
// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup.
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
return map[string]*AntigravityModelConfig{
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
"gemini-2.5-computer-use-preview-10-2025": {Name: "models/gemini-2.5-computer-use-preview-10-2025"},
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-preview"},
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-image-preview"},
"gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
}
}
// GetGitHubCopilotModels returns the available models for GitHub Copilot.
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
func GetGitHubCopilotModels() []*ModelInfo {
@@ -1170,3 +1191,150 @@ func GetGitHubCopilotModels() []*ModelInfo {
},
}
}
// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
func GetKiroModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "kiro-claude-opus-4.5",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Opus 4.5",
Description: "Claude Opus 4.5 via Kiro (2.2x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "kiro-claude-sonnet-4.5",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4.5",
Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "kiro-claude-sonnet-4",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4",
Description: "Claude Sonnet 4 via Kiro (1.3x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "kiro-claude-haiku-4.5",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Haiku 4.5",
Description: "Claude Haiku 4.5 via Kiro (0.4x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
// --- Chat Variant (No tool calling, for pure conversation) ---
{
ID: "kiro-claude-opus-4.5-chat",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Opus 4.5 (Chat)",
Description: "Claude Opus 4.5 for chat only (no tool calling)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
{
ID: "kiro-claude-opus-4.5-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Opus 4.5 (Agentic)",
Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "kiro-claude-sonnet-4.5-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4.5 (Agentic)",
Description: "Claude Sonnet 4.5 optimized for coding agents (chunked writes)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
}
}
// GetAmazonQModels returns the Amazon Q (AWS CodeWhisperer) model definitions.
// These models use the same API as Kiro and share the same executor.
func GetAmazonQModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "amazonq-auto",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro", // Uses Kiro executor - same API
DisplayName: "Amazon Q Auto",
Description: "Automatic model selection by Amazon Q",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "amazonq-claude-opus-4.5",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Amazon Q Claude Opus 4.5",
Description: "Claude Opus 4.5 via Amazon Q (2.2x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "amazonq-claude-sonnet-4.5",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Amazon Q Claude Sonnet 4.5",
Description: "Claude Sonnet 4.5 via Amazon Q (1.3x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "amazonq-claude-sonnet-4",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Amazon Q Claude Sonnet 4",
Description: "Claude Sonnet 4 via Amazon Q (1.3x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "amazonq-claude-haiku-4.5",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Amazon Q Claude Haiku 4.5",
Description: "Claude Haiku 4.5 via Amazon Q (0.4x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
}
}

View File

@@ -309,7 +309,9 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
to := sdktranslator.FromString("gemini")
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
payload = applyThinkingMetadata(payload, req.Metadata, req.Model)
payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload)
payload = util.ConvertThinkingLevelToBudget(payload)
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload)
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiImageAspectRatio(req.Model, payload)
payload = applyPayloadConfig(e.cfg, req.Model, payload)

View File

@@ -12,6 +12,7 @@ import (
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
@@ -27,21 +28,24 @@ import (
)
const (
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com"
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
antigravityStreamPath = "/v1internal:streamGenerateContent"
antigravityGeneratePath = "/v1internal:generateContent"
antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second
streamScannerBuffer int = 20_971_520
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
// antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com"
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
antigravityStreamPath = "/v1internal:streamGenerateContent"
antigravityGeneratePath = "/v1internal:generateContent"
antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second
streamScannerBuffer int = 20_971_520
)
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
var (
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex
)
// AntigravityExecutor proxies requests to the antigravity upstream.
type AntigravityExecutor struct {
@@ -77,6 +81,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -170,6 +176,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -366,28 +374,34 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
}
now := time.Now().Unix()
modelConfig := registry.GetAntigravityModelConfig()
models := make([]*registry.ModelInfo, 0, len(result.Map()))
for id := range result.Map() {
id = modelName2Alias(id)
if id != "" {
for originalName := range result.Map() {
aliasName := modelName2Alias(originalName)
if aliasName != "" {
cfg := modelConfig[aliasName]
modelName := aliasName
if cfg != nil && cfg.Name != "" {
modelName = cfg.Name
}
modelInfo := &registry.ModelInfo{
ID: id,
Name: id,
Description: id,
DisplayName: id,
Version: id,
ID: aliasName,
Name: modelName,
Description: aliasName,
DisplayName: aliasName,
Version: aliasName,
Object: "model",
Created: now,
OwnedBy: antigravityAuthType,
Type: antigravityAuthType,
}
// Add Thinking support for thinking models
if strings.HasSuffix(id, "-thinking") || strings.Contains(id, "-thinking-") {
modelInfo.Thinking = &registry.ThinkingSupport{
Min: 1024,
Max: 100000,
ZeroAllowed: false,
DynamicAllowed: true,
// Look up Thinking support from static config using alias name
if cfg != nil {
if cfg.Thinking != nil {
modelInfo.Thinking = cfg.Thinking
}
if cfg.MaxCompletionTokens > 0 {
modelInfo.MaxCompletionTokens = cfg.MaxCompletionTokens
}
}
models = append(models, modelInfo)
@@ -509,7 +523,14 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
requestURL.WriteString(url.QueryEscape(alt))
}
payload = geminiToAntigravity(modelName, payload)
// Extract project_id from auth metadata if available
projectID := ""
if auth != nil && auth.Metadata != nil {
if pid, ok := auth.Metadata["project_id"].(string); ok {
projectID = strings.TrimSpace(pid)
}
}
payload = geminiToAntigravity(modelName, payload, projectID)
payload, _ = sjson.SetBytes(payload, "model", alias2ModelName(modelName))
if strings.Contains(modelName, "claude") {
@@ -526,6 +547,9 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
strJSON = util.DeleteKey(strJSON, "minLength")
strJSON = util.DeleteKey(strJSON, "maxLength")
strJSON = util.DeleteKey(strJSON, "exclusiveMinimum")
strJSON = util.DeleteKey(strJSON, "exclusiveMaximum")
strJSON = util.DeleteKey(strJSON, "$ref")
strJSON = util.DeleteKey(strJSON, "$defs")
paths = make([]string, 0)
util.Walk(gjson.Parse(strJSON), "", "anyOf", &paths)
@@ -641,7 +665,7 @@ func buildBaseURL(auth *cliproxyauth.Auth) string {
if baseURLs := antigravityBaseURLFallbackOrder(auth); len(baseURLs) > 0 {
return baseURLs[0]
}
return antigravityBaseURLAutopush
return antigravityBaseURLDaily
}
func resolveHost(base string) string {
@@ -677,8 +701,8 @@ func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
}
return []string{
antigravityBaseURLDaily,
antigravityBaseURLAutopush,
// antigravityBaseURLProd,
// antigravityBaseURLAutopush,
antigravityBaseURLProd,
}
}
@@ -702,16 +726,22 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
return ""
}
func geminiToAntigravity(modelName string, payload []byte) []byte {
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
template, _ := sjson.Set(string(payload), "model", modelName)
template, _ = sjson.Set(template, "userAgent", "antigravity")
template, _ = sjson.Set(template, "project", generateProjectID())
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
if projectID != "" {
template, _ = sjson.Set(template, "project", projectID)
} else {
template, _ = sjson.Set(template, "project", generateProjectID())
}
template, _ = sjson.Set(template, "requestId", generateRequestID())
template, _ = sjson.Set(template, "request.sessionId", generateSessionID())
template, _ = sjson.Delete(template, "request.safetySettings")
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
if !strings.HasPrefix(modelName, "gemini-3-") {
if thinkingLevel := gjson.Get(template, "request.generationConfig.thinkingConfig.thinkingLevel"); thinkingLevel.Exists() {
template, _ = sjson.Delete(template, "request.generationConfig.thinkingConfig.thinkingLevel")
@@ -719,7 +749,7 @@ func geminiToAntigravity(modelName string, payload []byte) []byte {
}
}
if strings.HasPrefix(modelName, "claude-sonnet-") {
if strings.Contains(modelName, "claude") {
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
if funcDecl.Get("parametersJsonSchema").Exists() {
@@ -731,6 +761,8 @@ func geminiToAntigravity(modelName string, payload []byte) []byte {
})
return true
})
} else {
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
}
return []byte(template)
@@ -741,15 +773,19 @@ func generateRequestID() string {
}
func generateSessionID() string {
randSourceMutex.Lock()
n := randSource.Int63n(9_000_000_000_000_000_000)
randSourceMutex.Unlock()
return "-" + strconv.FormatInt(n, 10)
}
func generateProjectID() string {
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
randSourceMutex.Lock()
adj := adjectives[randSource.Intn(len(adjectives))]
noun := nouns[randSource.Intn(len(nouns))]
randSourceMutex.Unlock()
randomPart := strings.ToLower(uuid.NewString())[:5]
return adj + "-" + noun + "-" + randomPart
}
@@ -793,3 +829,65 @@ func alias2ModelName(modelName string) string {
return modelName
}
}
// normalizeAntigravityThinking clamps or removes thinking config based on model support.
// For Claude models, it additionally ensures thinking budget < max_tokens.
func normalizeAntigravityThinking(model string, payload []byte) []byte {
payload = util.StripThinkingConfigIfUnsupported(model, payload)
if !util.ModelSupportsThinking(model) {
return payload
}
budget := gjson.GetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget")
if !budget.Exists() {
return payload
}
raw := int(budget.Int())
normalized := util.NormalizeThinkingBudget(model, raw)
isClaude := strings.Contains(strings.ToLower(model), "claude")
if isClaude {
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
if effectiveMax > 0 && normalized >= effectiveMax {
normalized = effectiveMax - 1
}
minBudget := antigravityMinThinkingBudget(model)
if minBudget > 0 && normalized >= 0 && normalized < minBudget {
// Budget is below minimum, remove thinking config entirely
payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.thinkingConfig")
return payload
}
if setDefaultMax {
if res, errSet := sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", effectiveMax); errSet == nil {
payload = res
}
}
}
updated, err := sjson.SetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
if err != nil {
return payload
}
return updated
}
// antigravityEffectiveMaxTokens returns the max tokens to cap thinking:
// prefer request-provided maxOutputTokens; otherwise fall back to model default.
// The boolean indicates whether the value came from the model default (and thus should be written back).
func antigravityEffectiveMaxTokens(model string, payload []byte) (max int, fromModel bool) {
if maxTok := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens"); maxTok.Exists() && maxTok.Int() > 0 {
return int(maxTok.Int()), false
}
if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.MaxCompletionTokens > 0 {
return modelInfo.MaxCompletionTokens, true
}
return 0, false
}
// antigravityMinThinkingBudget returns the minimum thinking budget for a model.
// Falls back to -1 if no model info is found.
func antigravityMinThinkingBudget(model string) int {
if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.Thinking != nil {
return modelInfo.Thinking.Min
}
return -1
}

View File

@@ -1,10 +1,38 @@
package executor
import "time"
import (
"sync"
"time"
)
type codexCache struct {
ID string
Expire time.Time
}
var codexCacheMap = map[string]codexCache{}
var (
codexCacheMap = map[string]codexCache{}
codexCacheMutex sync.RWMutex
)
// getCodexCache safely retrieves a cache entry
func getCodexCache(key string) (codexCache, bool) {
codexCacheMutex.RLock()
defer codexCacheMutex.RUnlock()
cache, ok := codexCacheMap[key]
return cache, ok
}
// setCodexCache safely sets a cache entry
func setCodexCache(key string, cache codexCache) {
codexCacheMutex.Lock()
defer codexCacheMutex.Unlock()
codexCacheMap[key] = cache
}
// deleteCodexCache safely deletes a cache entry
func deleteCodexCache(key string) {
codexCacheMutex.Lock()
defer codexCacheMutex.Unlock()
delete(codexCacheMap, key)
}

View File

@@ -506,12 +506,12 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
if userIDResult.Exists() {
var hasKey bool
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
if cache, hasKey = codexCacheMap[key]; !hasKey || cache.Expire.Before(time.Now()) {
if cache, hasKey = getCodexCache(key); !hasKey || cache.Expire.Before(time.Now()) {
cache = codexCache{
ID: uuid.New().String(),
Expire: time.Now().Add(1 * time.Hour),
}
codexCacheMap[key] = cache
setCodexCache(key, cache)
}
}
} else if from == "openai-response" {

View File

@@ -64,6 +64,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
to := sdktranslator.FromString("gemini-cli")
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
@@ -199,6 +201,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
to := sdktranslator.FromString("gemini-cli")
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)

View File

@@ -80,6 +80,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = applyThinkingMetadata(body, req.Metadata, req.Model)
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
@@ -169,6 +171,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = applyThinkingMetadata(body, req.Metadata, req.Model)
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)

View File

@@ -296,6 +296,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
@@ -391,6 +393,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
@@ -487,6 +491,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)
@@ -599,6 +605,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body)

File diff suppressed because it is too large Load Diff

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -14,11 +15,19 @@ import (
"golang.org/x/net/proxy"
)
// httpClientCache caches HTTP clients by proxy URL to enable connection reuse
var (
httpClientCache = make(map[string]*http.Client)
httpClientCacheMutex sync.RWMutex
)
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
// 1. Use auth.ProxyURL if configured (highest priority)
// 2. Use cfg.ProxyURL if auth proxy is not configured
// 3. Use RoundTripper from context if neither are configured
//
// This function caches HTTP clients by proxy URL to enable TCP/TLS connection reuse.
//
// Parameters:
// - ctx: The context containing optional RoundTripper
// - cfg: The application configuration
@@ -28,11 +37,6 @@ import (
// Returns:
// - *http.Client: An HTTP client with configured proxy or transport
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
httpClient := &http.Client{}
if timeout > 0 {
httpClient.Timeout = timeout
}
// Priority 1: Use auth.ProxyURL if configured
var proxyURL string
if auth != nil {
@@ -44,11 +48,39 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
proxyURL = strings.TrimSpace(cfg.ProxyURL)
}
// Build cache key from proxy URL (empty string for no proxy)
cacheKey := proxyURL
// Check cache first
httpClientCacheMutex.RLock()
if cachedClient, ok := httpClientCache[cacheKey]; ok {
httpClientCacheMutex.RUnlock()
// Return a wrapper with the requested timeout but shared transport
if timeout > 0 {
return &http.Client{
Transport: cachedClient.Transport,
Timeout: timeout,
}
}
return cachedClient
}
httpClientCacheMutex.RUnlock()
// Create new client
httpClient := &http.Client{}
if timeout > 0 {
httpClient.Timeout = timeout
}
// If we have a proxy URL configured, set up the transport
if proxyURL != "" {
transport := buildProxyTransport(proxyURL)
if transport != nil {
httpClient.Transport = transport
// Cache the client
httpClientCacheMutex.Lock()
httpClientCache[cacheKey] = httpClient
httpClientCacheMutex.Unlock()
return httpClient
}
// If proxy setup failed, log and fall through to context RoundTripper
@@ -60,6 +92,13 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
httpClient.Transport = rt
}
// Cache the client for no-proxy case
if proxyURL == "" {
httpClientCacheMutex.Lock()
httpClientCache[cacheKey] = httpClient
httpClientCacheMutex.Unlock()
}
return httpClient
}

View File

@@ -83,7 +83,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
for j := 0; j < len(contentResults); j++ {
contentResult := contentResults[j]
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
prompt := contentResult.Get("thinking").String()
signatureResult := contentResult.Get("signature")
signature := geminiCLIClaudeThoughtSignature
if signatureResult.Exists() {
signature = signatureResult.String()
}
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt, Thought: true, ThoughtSignature: signature})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String()
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
@@ -92,10 +100,16 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
functionID := contentResult.Get("id").String()
var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
ThoughtSignature: geminiCLIClaudeThoughtSignature,
})
if strings.Contains(modelName, "claude") {
clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
})
} else {
clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
ThoughtSignature: geminiCLIClaudeThoughtSignature,
})
}
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
toolCallID := contentResult.Get("tool_use_id").String()
@@ -109,6 +123,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
functionResponse := client.FunctionResponse{ID: toolCallID, Name: funcName, Response: map[string]interface{}{"result": responseData}}
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" {
sourceResult := contentResult.Get("source")
if sourceResult.Get("type").String() == "base64" {
inlineData := &client.InlineData{
MimeType: sourceResult.Get("media_type").String(),
Data: sourceResult.Get("data").String(),
}
clientContent.Parts = append(clientContent.Parts, client.Part{InlineData: inlineData})
}
}
}
contents = append(contents, clientContent)
@@ -166,7 +189,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if t.Get("type").String() == "enabled" {
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := int(b.Int())
budget = util.NormalizeThinkingBudget(modelName, budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
}
@@ -181,6 +203,9 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num)
}
if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.maxOutputTokens", v.Num)
}
outBytes := []byte(out)
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")

View File

@@ -12,6 +12,7 @@ import (
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/tidwall/gjson"
@@ -36,6 +37,9 @@ type Params struct {
HasToolUse bool // Indicates if tool use was observed in the stream
}
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
var toolUseIDCounter uint64
// ConvertAntigravityResponseToClaude performs sophisticated streaming response format conversion.
// This function implements a complex state machine that translates backend client responses
// into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types
@@ -111,8 +115,11 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
if partTextResult.Exists() {
// Process thinking content (internal reasoning)
if partResult.Get("thought").Bool() {
// Continue existing thinking block if already in thinking state
if params.ResponseType == 2 {
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
@@ -163,15 +170,16 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
output = output + "\n\n\n"
params.ResponseIndex++
}
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.ResponseType = 1 // Set state to content
if partTextResult.String() != "" {
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.ResponseType = 1 // Set state to content
}
}
}
}
@@ -212,7 +220,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)

View File

@@ -48,13 +48,13 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
case "low":
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
case "medium":
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
case "high":
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
default:
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
@@ -66,15 +66,15 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
var setBudget bool
var normalized int
var budget int
if v := tc.Get("thinkingBudget"); v.Exists() {
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
budget = int(v.Int())
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
setBudget = true
} else if v := tc.Get("thinking_budget"); v.Exists() {
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
budget = int(v.Int())
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
setBudget = true
}
@@ -82,22 +82,27 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
} else if v := tc.Get("include_thoughts"); v.Exists() {
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
} else if setBudget && normalized != 0 {
} else if setBudget && budget != 0 {
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
}
}
}
// For gemini-3-pro-preview, always send default thinkingConfig when none specified.
// This matches the official Gemini CLI behavior which always sends:
// { thinkingBudget: -1, includeThoughts: true }
// See: ai-gemini-cli/packages/core/src/config/defaultModelConfigs.ts
if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && modelName == "gemini-3-pro-preview" {
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
// Claude/Anthropic API format: thinking.type == "enabled" with budget_tokens
// This allows Claude Code and other Claude API clients to pass thinking configuration
if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && util.ModelSupportsThinking(modelName) {
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
if t.Get("type").String() == "enabled" {
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := int(b.Int())
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
}
}
}
}
// Temperature/top_p/top_k
// Temperature/top_p/top_k/max_tokens
if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number {
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num)
}
@@ -107,6 +112,9 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number {
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num)
}
if maxTok := gjson.GetBytes(rawJSON, "max_tokens"); maxTok.Exists() && maxTok.Type == gjson.Number {
out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num)
}
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
@@ -263,7 +271,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
// Append a single tool content combining name + response per function
toolNode := []byte(`{"role":"tool","parts":[]}`)
toolNode := []byte(`{"role":"user","parts":[]}`)
pp := 0
for _, fid := range fIDs {
if name, ok := tcID2Name[fid]; ok {

View File

@@ -10,6 +10,8 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
@@ -23,6 +25,9 @@ type convertCliResponseToOpenAIChatParams struct {
FunctionIndex int
}
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
var functionCallIDCounter uint64
// ConvertAntigravityResponseToOpenAI translates a single chunk of a streaming response from the
// Gemini CLI API format to the OpenAI Chat Completions streaming format.
// It processes various Gemini CLI event types and transforms them into OpenAI-compatible JSON responses.
@@ -75,8 +80,8 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
}
// Extract and set usage metadata (token counts).
@@ -145,7 +150,7 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {

View File

@@ -331,8 +331,9 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
streamingEvents := make([][]byte, 0)
scanner := bufio.NewScanner(bytes.NewReader(rawJSON))
buffer := make([]byte, 20_971_520)
scanner.Buffer(buffer, 20_971_520)
// Use a smaller initial buffer (64KB) that can grow up to 20MB if needed
// This prevents allocating 20MB for every request regardless of size
scanner.Buffer(make([]byte, 64*1024), 20_971_520)
for scanner.Scan() {
line := scanner.Bytes()
// log.Debug(string(line))

View File

@@ -50,6 +50,10 @@ type ToolCallAccumulator struct {
// Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
var localParam any
if param == nil {
param = &localParam
}
if *param == nil {
*param = &ConvertAnthropicResponseToOpenAIParams{
CreatedAt: 0,

View File

@@ -165,7 +165,6 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
if t.Get("type").String() == "enabled" {
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := int(b.Int())
budget = util.NormalizeThinkingBudget(modelName, budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
}

View File

@@ -12,6 +12,7 @@ import (
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/tidwall/gjson"
@@ -27,6 +28,9 @@ type Params struct {
ResponseIndex int // Index counter for content blocks in the streaming response
}
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
var toolUseIDCounter uint64
// ConvertGeminiCLIResponseToClaude performs sophisticated streaming response format conversion.
// This function implements a complex state machine that translates backend client responses
// into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types
@@ -60,12 +64,12 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Track whether tools are being used in this response chunk
usedTool := false
output := ""
var sb strings.Builder
// Initialize the streaming session with a message_start event
// This is only sent for the very first response chunk to establish the streaming session
if !(*param).(*Params).HasFirstResponse {
output = "event: message_start\n"
sb.WriteString("event: message_start\n")
// Create the initial message structure with default values according to Claude Code API specification
// This follows the Claude Code API specification for streaming message initialization
@@ -78,7 +82,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
}
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
sb.WriteString(fmt.Sprintf("data: %s\n\n\n", messageStartTemplate))
(*param).(*Params).HasFirstResponse = true
}
@@ -101,62 +105,52 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
if partResult.Get("thought").Bool() {
// Continue existing thinking block if already in thinking state
if (*param).(*Params).ResponseType == 2 {
output = output + "event: content_block_delta\n"
sb.WriteString("event: content_block_delta\n")
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data))
} else {
// Transition from another state to thinking
// First, close any existing content block
if (*param).(*Params).ResponseType != 0 {
if (*param).(*Params).ResponseType == 2 {
// output = output + "event: content_block_delta\n"
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
sb.WriteString("event: content_block_stop\n")
sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
sb.WriteString("\n\n\n")
(*param).(*Params).ResponseIndex++
}
// Start a new thinking content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
sb.WriteString("event: content_block_start\n")
sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex))
sb.WriteString("\n\n\n")
sb.WriteString("event: content_block_delta\n")
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data))
(*param).(*Params).ResponseType = 2 // Set state to thinking
}
} else {
// Process regular text content (user-visible output)
// Continue existing text block if already in content state
if (*param).(*Params).ResponseType == 1 {
output = output + "event: content_block_delta\n"
sb.WriteString("event: content_block_delta\n")
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data))
} else {
// Transition from another state to text content
// First, close any existing content block
if (*param).(*Params).ResponseType != 0 {
if (*param).(*Params).ResponseType == 2 {
// output = output + "event: content_block_delta\n"
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
sb.WriteString("event: content_block_stop\n")
sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
sb.WriteString("\n\n\n")
(*param).(*Params).ResponseIndex++
}
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
sb.WriteString("event: content_block_start\n")
sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex))
sb.WriteString("\n\n\n")
sb.WriteString("event: content_block_delta\n")
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data))
(*param).(*Params).ResponseType = 1 // Set state to content
}
}
@@ -169,42 +163,35 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Handle state transitions when switching to function calls
// Close any existing function call block first
if (*param).(*Params).ResponseType == 3 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
sb.WriteString("event: content_block_stop\n")
sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
sb.WriteString("\n\n\n")
(*param).(*Params).ResponseIndex++
(*param).(*Params).ResponseType = 0
}
// Special handling for thinking state transition
if (*param).(*Params).ResponseType == 2 {
// output = output + "event: content_block_delta\n"
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex)
// output = output + "\n\n\n"
}
// Close any other existing content block
if (*param).(*Params).ResponseType != 0 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
sb.WriteString("event: content_block_stop\n")
sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
sb.WriteString("\n\n\n")
(*param).(*Params).ResponseIndex++
}
// Start a new tool use content block
// This creates the structure for a function call in Claude Code format
output = output + "event: content_block_start\n"
sb.WriteString("event: content_block_start\n")
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data))
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n"
sb.WriteString("event: content_block_delta\n")
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data))
}
(*param).(*Params).ResponseType = 3
}
@@ -216,13 +203,13 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
if usageResult.Exists() && bytes.Contains(rawJSON, []byte(`"finishReason"`)) {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
// Close the final content block
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
output = output + "\n\n\n"
sb.WriteString("event: content_block_stop\n")
sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
sb.WriteString("\n\n\n")
// Send the final message delta with usage information and stop reason
output = output + "event: message_delta\n"
output = output + `data: `
sb.WriteString("event: message_delta\n")
sb.WriteString(`data: `)
// Create the message delta template with appropriate stop reason
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
@@ -236,11 +223,11 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
output = output + template + "\n\n\n"
sb.WriteString(template + "\n\n\n")
}
}
return []string{output}
return []string{sb.String()}
}
// ConvertGeminiCLIResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response.

View File

@@ -48,13 +48,13 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
case "low":
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
case "medium":
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
case "high":
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
default:
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
@@ -66,15 +66,15 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
var setBudget bool
var normalized int
var budget int
if v := tc.Get("thinkingBudget"); v.Exists() {
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
budget = int(v.Int())
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
setBudget = true
} else if v := tc.Get("thinking_budget"); v.Exists() {
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
budget = int(v.Int())
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
setBudget = true
}
@@ -82,21 +82,12 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
} else if v := tc.Get("include_thoughts"); v.Exists() {
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
} else if setBudget && normalized != 0 {
} else if setBudget && budget != 0 {
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
}
}
}
// For gemini-3-pro-preview, always send default thinkingConfig when none specified.
// This matches the official Gemini CLI behavior which always sends:
// { thinkingBudget: -1, includeThoughts: true }
// See: ai-gemini-cli/packages/core/src/config/defaultModelConfigs.ts
if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && modelName == "gemini-3-pro-preview" {
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
}
// Temperature/top_p/top_k
if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number {
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num)

View File

@@ -10,6 +10,8 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
@@ -23,6 +25,9 @@ type convertCliResponseToOpenAIChatParams struct {
FunctionIndex int
}
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
var functionCallIDCounter uint64
// ConvertCliResponseToOpenAI translates a single chunk of a streaming response from the
// Gemini CLI API format to the OpenAI Chat Completions streaming format.
// It processes various Gemini CLI event types and transforms them into OpenAI-compatible JSON responses.
@@ -75,8 +80,8 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
}
// Extract and set usage metadata (token counts).
@@ -145,7 +150,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {

View File

@@ -158,7 +158,6 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
if t.Get("type").String() == "enabled" {
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := int(b.Int())
budget = util.NormalizeThinkingBudget(modelName, budget)
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
}

View File

@@ -12,6 +12,7 @@ import (
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/tidwall/gjson"
@@ -26,6 +27,9 @@ type Params struct {
ResponseIndex int
}
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
var toolUseIDCounter uint64
// ConvertGeminiResponseToClaude performs sophisticated streaming response format conversion.
// This function implements a complex state machine that translates backend client responses
// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types
@@ -197,7 +201,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)

View File

@@ -48,13 +48,13 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
case "low":
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024))
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 1024)
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
case "medium":
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192))
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 8192)
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
case "high":
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768))
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 32768)
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
default:
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
@@ -66,15 +66,15 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
var setBudget bool
var normalized int
var budget int
if v := tc.Get("thinkingBudget"); v.Exists() {
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", normalized)
budget = int(v.Int())
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
setBudget = true
} else if v := tc.Get("thinking_budget"); v.Exists() {
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", normalized)
budget = int(v.Int())
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
setBudget = true
}
@@ -82,7 +82,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool())
} else if v := tc.Get("include_thoughts"); v.Exists() {
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool())
} else if setBudget && normalized != 0 {
} else if setBudget && budget != 0 {
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
}
}

View File

@@ -10,6 +10,8 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/tidwall/gjson"
@@ -22,6 +24,9 @@ type convertGeminiResponseToOpenAIChatParams struct {
FunctionIndex int
}
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
var functionCallIDCounter uint64
// ConvertGeminiResponseToOpenAI translates a single chunk of a streaming response from the
// Gemini API format to the OpenAI Chat Completions streaming format.
// It processes various Gemini event types and transforms them into OpenAI-compatible JSON responses.
@@ -78,8 +83,8 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
}
// Extract and set usage metadata (token counts).
@@ -147,7 +152,7 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
@@ -230,8 +235,8 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
}
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
}
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
@@ -280,7 +285,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
}
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)

View File

@@ -400,16 +400,16 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
case "minimal":
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024))
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 1024)
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
case "low":
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 4096))
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 4096)
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
case "medium":
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192))
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 8192)
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
case "high":
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768))
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 32768)
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
default:
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
@@ -421,32 +421,22 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
if tc := root.Get("extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
var setBudget bool
var normalized int
var budget int
if v := tc.Get("thinking_budget"); v.Exists() {
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", normalized)
budget = int(v.Int())
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
setBudget = true
}
if v := tc.Get("include_thoughts"); v.Exists() {
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool())
} else if setBudget {
if normalized != 0 {
if budget != 0 {
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
}
}
}
}
// For gemini-3-pro-preview, always send default thinkingConfig when none specified.
// This matches the official Gemini CLI behavior which always sends:
// { thinkingBudget: -1, includeThoughts: true }
// See: ai-gemini-cli/packages/core/src/config/defaultModelConfigs.ts
if !gjson.Get(out, "generationConfig.thinkingConfig").Exists() && modelName == "gemini-3-pro-preview" {
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
// log.Debugf("Applied default thinkingConfig for gemini-3-pro-preview (matches Gemini CLI): thinkingBudget=-1, include_thoughts=true")
}
result := []byte(out)
result = common.AttachDefaultSafetySettings(result, "safetySettings")
return result

View File

@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/tidwall/gjson"
@@ -37,6 +38,12 @@ type geminiToResponsesState struct {
FuncCallIDs map[int]string
}
// responseIDCounter provides a process-wide unique counter for synthesized response identifiers.
var responseIDCounter uint64
// funcCallIDCounter provides a process-wide unique counter for function call identifiers.
var funcCallIDCounter uint64
func emitEvent(event string, payload string) string {
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
}
@@ -205,7 +212,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
st.FuncArgsBuf[idx] = &strings.Builder{}
}
if st.FuncCallIDs[idx] == "" {
st.FuncCallIDs[idx] = fmt.Sprintf("call_%d", time.Now().UnixNano())
st.FuncCallIDs[idx] = fmt.Sprintf("call_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1))
}
st.FuncNames[idx] = name
@@ -464,7 +471,7 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
// id: prefer provider responseId, otherwise synthesize
id := root.Get("responseId").String()
if id == "" {
id = fmt.Sprintf("resp_%x", time.Now().UnixNano())
id = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1))
}
// Normalize to response-style id (prefix resp_ if missing)
if !strings.HasPrefix(id, "resp_") {
@@ -575,7 +582,7 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
if fc := p.Get("functionCall"); fc.Exists() {
name := fc.Get("name").String()
args := fc.Get("args")
callID := fmt.Sprintf("call_%x", time.Now().UnixNano())
callID := fmt.Sprintf("call_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1))
outputs = append(outputs, map[string]interface{}{
"id": fmt.Sprintf("fc_%s", callID),
"type": "function_call",

View File

@@ -33,4 +33,7 @@ import (
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/chat-completions"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai/chat-completions"
)

View File

@@ -0,0 +1,19 @@
package claude
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
Claude,
Kiro,
ConvertClaudeRequestToKiro,
interfaces.TranslateResponse{
Stream: ConvertKiroResponseToClaude,
NonStream: ConvertKiroResponseToClaudeNonStream,
},
)
}

View File

@@ -0,0 +1,27 @@
// Package claude provides translation between Kiro and Claude formats.
// Since Kiro executor generates Claude-compatible SSE format internally (with event: prefix),
// translations are pass-through.
package claude
import (
"bytes"
"context"
)
// ConvertClaudeRequestToKiro converts Claude request to Kiro format.
// Since Kiro uses Claude format internally, this is mostly a pass-through.
func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
return bytes.Clone(inputRawJSON)
}
// ConvertKiroResponseToClaude converts Kiro streaming response to Claude format.
// Kiro executor already generates complete SSE format with "event:" prefix,
// so this is a simple pass-through.
func ConvertKiroResponseToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
return []string{string(rawResponse)}
}
// ConvertKiroResponseToClaudeNonStream converts Kiro non-streaming response to Claude format.
func ConvertKiroResponseToClaudeNonStream(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
return string(rawResponse)
}

View File

@@ -0,0 +1,19 @@
package chat_completions
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenAI,
Kiro,
ConvertOpenAIRequestToKiro,
interfaces.TranslateResponse{
Stream: ConvertKiroResponseToOpenAI,
NonStream: ConvertKiroResponseToOpenAINonStream,
},
)
}

View File

@@ -0,0 +1,319 @@
// Package chat_completions provides request translation from OpenAI to Kiro format.
package chat_completions
import (
"bytes"
"encoding/json"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertOpenAIRequestToKiro transforms an OpenAI Chat Completions API request into Kiro (Claude) format.
// Kiro uses Claude-compatible format internally, so we primarily pass through to Claude format.
// Supports tool calling: OpenAI tools -> Claude tools, tool_calls -> tool_use, tool messages -> tool_result.
func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
rawJSON := bytes.Clone(inputRawJSON)
root := gjson.ParseBytes(rawJSON)
// Build Claude-compatible request
out := `{"model":"","max_tokens":32000,"messages":[]}`
// Set model
out, _ = sjson.Set(out, "model", modelName)
// Copy max_tokens if present
if v := root.Get("max_tokens"); v.Exists() {
out, _ = sjson.Set(out, "max_tokens", v.Int())
}
// Copy temperature if present
if v := root.Get("temperature"); v.Exists() {
out, _ = sjson.Set(out, "temperature", v.Float())
}
// Copy top_p if present
if v := root.Get("top_p"); v.Exists() {
out, _ = sjson.Set(out, "top_p", v.Float())
}
// Convert OpenAI tools to Claude tools format
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
claudeTools := make([]interface{}, 0)
for _, tool := range tools.Array() {
if tool.Get("type").String() == "function" {
fn := tool.Get("function")
claudeTool := map[string]interface{}{
"name": fn.Get("name").String(),
"description": fn.Get("description").String(),
}
// Convert parameters to input_schema
if params := fn.Get("parameters"); params.Exists() {
claudeTool["input_schema"] = params.Value()
} else {
claudeTool["input_schema"] = map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
}
}
claudeTools = append(claudeTools, claudeTool)
}
}
if len(claudeTools) > 0 {
out, _ = sjson.Set(out, "tools", claudeTools)
}
}
// Process messages
messages := root.Get("messages")
if messages.Exists() && messages.IsArray() {
claudeMessages := make([]interface{}, 0)
var systemPrompt string
// Track pending tool results to merge with next user message
var pendingToolResults []map[string]interface{}
for _, msg := range messages.Array() {
role := msg.Get("role").String()
content := msg.Get("content")
if role == "system" {
// Extract system message
if content.IsArray() {
for _, part := range content.Array() {
if part.Get("type").String() == "text" {
systemPrompt += part.Get("text").String() + "\n"
}
}
} else {
systemPrompt = content.String()
}
continue
}
if role == "tool" {
// OpenAI tool message -> Claude tool_result content block
toolCallID := msg.Get("tool_call_id").String()
toolContent := content.String()
toolResult := map[string]interface{}{
"type": "tool_result",
"tool_use_id": toolCallID,
}
// Handle content - can be string or structured
if content.IsArray() {
contentParts := make([]interface{}, 0)
for _, part := range content.Array() {
if part.Get("type").String() == "text" {
contentParts = append(contentParts, map[string]interface{}{
"type": "text",
"text": part.Get("text").String(),
})
}
}
toolResult["content"] = contentParts
} else {
toolResult["content"] = toolContent
}
pendingToolResults = append(pendingToolResults, toolResult)
continue
}
claudeMsg := map[string]interface{}{
"role": role,
}
// Handle assistant messages with tool_calls
if role == "assistant" && msg.Get("tool_calls").Exists() {
contentParts := make([]interface{}, 0)
// Add text content if present
if content.Exists() && content.String() != "" {
contentParts = append(contentParts, map[string]interface{}{
"type": "text",
"text": content.String(),
})
}
// Convert tool_calls to tool_use blocks
for _, toolCall := range msg.Get("tool_calls").Array() {
toolUseID := toolCall.Get("id").String()
fnName := toolCall.Get("function.name").String()
fnArgs := toolCall.Get("function.arguments").String()
// Parse arguments JSON
var argsMap map[string]interface{}
if err := json.Unmarshal([]byte(fnArgs), &argsMap); err != nil {
argsMap = map[string]interface{}{"raw": fnArgs}
}
contentParts = append(contentParts, map[string]interface{}{
"type": "tool_use",
"id": toolUseID,
"name": fnName,
"input": argsMap,
})
}
claudeMsg["content"] = contentParts
claudeMessages = append(claudeMessages, claudeMsg)
continue
}
// Handle user messages - may need to include pending tool results
if role == "user" && len(pendingToolResults) > 0 {
contentParts := make([]interface{}, 0)
// Add pending tool results first
for _, tr := range pendingToolResults {
contentParts = append(contentParts, tr)
}
pendingToolResults = nil
// Add user content
if content.IsArray() {
for _, part := range content.Array() {
partType := part.Get("type").String()
if partType == "text" {
contentParts = append(contentParts, map[string]interface{}{
"type": "text",
"text": part.Get("text").String(),
})
} else if partType == "image_url" {
imageURL := part.Get("image_url.url").String()
// Check if it's base64 format (data:image/png;base64,xxxxx)
if strings.HasPrefix(imageURL, "data:") {
// Parse data URL format
// Format: data:image/png;base64,xxxxx
commaIdx := strings.Index(imageURL, ",")
if commaIdx != -1 {
// Extract media_type (e.g., "image/png")
header := imageURL[5:commaIdx] // Remove "data:" prefix
mediaType := header
if semiIdx := strings.Index(header, ";"); semiIdx != -1 {
mediaType = header[:semiIdx]
}
// Extract base64 data
base64Data := imageURL[commaIdx+1:]
contentParts = append(contentParts, map[string]interface{}{
"type": "image",
"source": map[string]interface{}{
"type": "base64",
"media_type": mediaType,
"data": base64Data,
},
})
}
} else {
// Regular URL format - keep original logic
contentParts = append(contentParts, map[string]interface{}{
"type": "image",
"source": map[string]interface{}{
"type": "url",
"url": imageURL,
},
})
}
}
}
} else if content.String() != "" {
contentParts = append(contentParts, map[string]interface{}{
"type": "text",
"text": content.String(),
})
}
claudeMsg["content"] = contentParts
claudeMessages = append(claudeMessages, claudeMsg)
continue
}
// Handle regular content
if content.IsArray() {
contentParts := make([]interface{}, 0)
for _, part := range content.Array() {
partType := part.Get("type").String()
if partType == "text" {
contentParts = append(contentParts, map[string]interface{}{
"type": "text",
"text": part.Get("text").String(),
})
} else if partType == "image_url" {
imageURL := part.Get("image_url.url").String()
// Check if it's base64 format (data:image/png;base64,xxxxx)
if strings.HasPrefix(imageURL, "data:") {
// Parse data URL format
// Format: data:image/png;base64,xxxxx
commaIdx := strings.Index(imageURL, ",")
if commaIdx != -1 {
// Extract media_type (e.g., "image/png")
header := imageURL[5:commaIdx] // Remove "data:" prefix
mediaType := header
if semiIdx := strings.Index(header, ";"); semiIdx != -1 {
mediaType = header[:semiIdx]
}
// Extract base64 data
base64Data := imageURL[commaIdx+1:]
contentParts = append(contentParts, map[string]interface{}{
"type": "image",
"source": map[string]interface{}{
"type": "base64",
"media_type": mediaType,
"data": base64Data,
},
})
}
} else {
// Regular URL format - keep original logic
contentParts = append(contentParts, map[string]interface{}{
"type": "image",
"source": map[string]interface{}{
"type": "url",
"url": imageURL,
},
})
}
}
}
claudeMsg["content"] = contentParts
} else {
claudeMsg["content"] = content.String()
}
claudeMessages = append(claudeMessages, claudeMsg)
}
// If there are pending tool results without a following user message,
// create a user message with just the tool results
if len(pendingToolResults) > 0 {
contentParts := make([]interface{}, 0)
for _, tr := range pendingToolResults {
contentParts = append(contentParts, tr)
}
claudeMessages = append(claudeMessages, map[string]interface{}{
"role": "user",
"content": contentParts,
})
}
out, _ = sjson.Set(out, "messages", claudeMessages)
if systemPrompt != "" {
out, _ = sjson.Set(out, "system", systemPrompt)
}
}
// Set stream
out, _ = sjson.Set(out, "stream", stream)
return []byte(out)
}

View File

@@ -0,0 +1,360 @@
// Package chat_completions provides response translation from Kiro to OpenAI format.
package chat_completions
import (
"context"
"encoding/json"
"strings"
"time"
"github.com/google/uuid"
"github.com/tidwall/gjson"
)
// ConvertKiroResponseToOpenAI converts Kiro streaming response to OpenAI SSE format.
// Handles Claude SSE events: content_block_start, content_block_delta, input_json_delta,
// content_block_stop, message_delta, and message_stop.
// Input may be in SSE format: "event: xxx\ndata: {...}" or raw JSON.
func ConvertKiroResponseToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
raw := string(rawResponse)
var results []string
// Handle SSE format: extract JSON from "data: " lines
// Input format: "event: message_start\ndata: {...}"
lines := strings.Split(raw, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "data: ") {
jsonPart := strings.TrimPrefix(line, "data: ")
chunks := convertClaudeEventToOpenAI(jsonPart, model)
results = append(results, chunks...)
} else if strings.HasPrefix(line, "{") {
// Raw JSON (backward compatibility)
chunks := convertClaudeEventToOpenAI(line, model)
results = append(results, chunks...)
}
}
return results
}
// convertClaudeEventToOpenAI converts a single Claude JSON event to OpenAI format
func convertClaudeEventToOpenAI(jsonStr string, model string) []string {
root := gjson.Parse(jsonStr)
var results []string
eventType := root.Get("type").String()
switch eventType {
case "message_start":
// Initial message event - emit initial chunk with role
response := map[string]interface{}{
"id": "chatcmpl-" + uuid.New().String()[:24],
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]interface{}{
{
"index": 0,
"delta": map[string]interface{}{
"role": "assistant",
"content": "",
},
"finish_reason": nil,
},
},
}
result, _ := json.Marshal(response)
results = append(results, string(result))
return results
case "content_block_start":
// Start of a content block (text or tool_use)
blockType := root.Get("content_block.type").String()
index := int(root.Get("index").Int())
if blockType == "tool_use" {
// Start of tool_use block
toolUseID := root.Get("content_block.id").String()
toolName := root.Get("content_block.name").String()
toolCall := map[string]interface{}{
"index": index,
"id": toolUseID,
"type": "function",
"function": map[string]interface{}{
"name": toolName,
"arguments": "",
},
}
response := map[string]interface{}{
"id": "chatcmpl-" + uuid.New().String()[:24],
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]interface{}{
{
"index": 0,
"delta": map[string]interface{}{
"tool_calls": []map[string]interface{}{toolCall},
},
"finish_reason": nil,
},
},
}
result, _ := json.Marshal(response)
results = append(results, string(result))
}
return results
case "content_block_delta":
index := int(root.Get("index").Int())
deltaType := root.Get("delta.type").String()
if deltaType == "text_delta" {
// Text content delta
contentDelta := root.Get("delta.text").String()
if contentDelta != "" {
response := map[string]interface{}{
"id": "chatcmpl-" + uuid.New().String()[:24],
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]interface{}{
{
"index": 0,
"delta": map[string]interface{}{
"content": contentDelta,
},
"finish_reason": nil,
},
},
}
result, _ := json.Marshal(response)
results = append(results, string(result))
}
} else if deltaType == "input_json_delta" {
// Tool input delta (streaming arguments)
partialJSON := root.Get("delta.partial_json").String()
if partialJSON != "" {
toolCall := map[string]interface{}{
"index": index,
"function": map[string]interface{}{
"arguments": partialJSON,
},
}
response := map[string]interface{}{
"id": "chatcmpl-" + uuid.New().String()[:24],
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]interface{}{
{
"index": 0,
"delta": map[string]interface{}{
"tool_calls": []map[string]interface{}{toolCall},
},
"finish_reason": nil,
},
},
}
result, _ := json.Marshal(response)
results = append(results, string(result))
}
}
return results
case "content_block_stop":
// End of content block - no output needed for OpenAI format
return results
case "message_delta":
// Final message delta with stop_reason
stopReason := root.Get("delta.stop_reason").String()
if stopReason != "" {
finishReason := "stop"
if stopReason == "tool_use" {
finishReason = "tool_calls"
} else if stopReason == "end_turn" {
finishReason = "stop"
} else if stopReason == "max_tokens" {
finishReason = "length"
}
response := map[string]interface{}{
"id": "chatcmpl-" + uuid.New().String()[:24],
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]interface{}{
{
"index": 0,
"delta": map[string]interface{}{},
"finish_reason": finishReason,
},
},
}
result, _ := json.Marshal(response)
results = append(results, string(result))
}
return results
case "message_stop":
// End of message - could emit [DONE] marker
return results
}
// Fallback: handle raw content for backward compatibility
var contentDelta string
if delta := root.Get("delta.text"); delta.Exists() {
contentDelta = delta.String()
} else if content := root.Get("content"); content.Exists() && root.Get("type").String() == "" {
contentDelta = content.String()
}
if contentDelta != "" {
response := map[string]interface{}{
"id": "chatcmpl-" + uuid.New().String()[:24],
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]interface{}{
{
"index": 0,
"delta": map[string]interface{}{
"content": contentDelta,
},
"finish_reason": nil,
},
},
}
result, _ := json.Marshal(response)
results = append(results, string(result))
}
// Handle tool_use content blocks (Claude format) - fallback
toolUses := root.Get("delta.tool_use")
if !toolUses.Exists() {
toolUses = root.Get("tool_use")
}
if toolUses.Exists() && toolUses.IsObject() {
inputJSON := toolUses.Get("input").String()
if inputJSON == "" {
if inputObj := toolUses.Get("input"); inputObj.Exists() {
inputBytes, _ := json.Marshal(inputObj.Value())
inputJSON = string(inputBytes)
}
}
toolCall := map[string]interface{}{
"index": 0,
"id": toolUses.Get("id").String(),
"type": "function",
"function": map[string]interface{}{
"name": toolUses.Get("name").String(),
"arguments": inputJSON,
},
}
response := map[string]interface{}{
"id": "chatcmpl-" + uuid.New().String()[:24],
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]interface{}{
{
"index": 0,
"delta": map[string]interface{}{
"tool_calls": []map[string]interface{}{toolCall},
},
"finish_reason": nil,
},
},
}
result, _ := json.Marshal(response)
results = append(results, string(result))
}
return results
}
// ConvertKiroResponseToOpenAINonStream converts Kiro non-streaming response to OpenAI format.
func ConvertKiroResponseToOpenAINonStream(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
root := gjson.ParseBytes(rawResponse)
var content string
var toolCalls []map[string]interface{}
contentArray := root.Get("content")
if contentArray.IsArray() {
for _, item := range contentArray.Array() {
itemType := item.Get("type").String()
if itemType == "text" {
content += item.Get("text").String()
} else if itemType == "tool_use" {
// Convert Claude tool_use to OpenAI tool_calls format
inputJSON := item.Get("input").String()
if inputJSON == "" {
// If input is an object, marshal it
if inputObj := item.Get("input"); inputObj.Exists() {
inputBytes, _ := json.Marshal(inputObj.Value())
inputJSON = string(inputBytes)
}
}
toolCall := map[string]interface{}{
"id": item.Get("id").String(),
"type": "function",
"function": map[string]interface{}{
"name": item.Get("name").String(),
"arguments": inputJSON,
},
}
toolCalls = append(toolCalls, toolCall)
}
}
} else {
content = root.Get("content").String()
}
inputTokens := root.Get("usage.input_tokens").Int()
outputTokens := root.Get("usage.output_tokens").Int()
message := map[string]interface{}{
"role": "assistant",
"content": content,
}
// Add tool_calls if present
if len(toolCalls) > 0 {
message["tool_calls"] = toolCalls
}
finishReason := "stop"
if len(toolCalls) > 0 {
finishReason = "tool_calls"
}
response := map[string]interface{}{
"id": "chatcmpl-" + uuid.New().String()[:24],
"object": "chat.completion",
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]interface{}{
{
"index": 0,
"message": message,
"finish_reason": finishReason,
},
},
"usage": map[string]interface{}{
"prompt_tokens": inputTokens,
"completion_tokens": outputTokens,
"total_tokens": inputTokens + outputTokens,
},
}
result, _ := json.Marshal(response)
return string(result)
}

View File

@@ -8,6 +8,7 @@ package claude
import (
"bytes"
"encoding/json"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -242,11 +243,12 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
switch partType {
case "text":
if !part.Get("text").Exists() {
text := part.Get("text").String()
if strings.TrimSpace(text) == "" {
return "", false
}
textContent := `{"type":"text","text":""}`
textContent, _ = sjson.Set(textContent, "text", part.Get("text").String())
textContent, _ = sjson.Set(textContent, "text", text)
return textContent, true
case "image":

View File

@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/tidwall/gjson"
@@ -41,6 +42,9 @@ type oaiToResponsesState struct {
UsageSeen bool
}
// responseIDCounter provides a process-wide unique counter for synthesized response identifiers.
var responseIDCounter uint64
func emitRespEvent(event string, payload string) string {
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
}
@@ -590,7 +594,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
// id: use provider id if present, otherwise synthesize
id := root.Get("id").String()
if id == "" {
id = fmt.Sprintf("resp_%x", time.Now().UnixNano())
id = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1))
}
resp, _ = sjson.Set(resp, "id", id)

View File

@@ -207,6 +207,47 @@ func GeminiThinkingFromMetadata(metadata map[string]any) (*int, *bool, bool) {
return budgetPtr, includePtr, matched
}
// modelsWithDefaultThinking lists models that should have thinking enabled by default
// when no explicit thinkingConfig is provided.
var modelsWithDefaultThinking = map[string]bool{
"gemini-3-pro-preview": true,
}
// ModelHasDefaultThinking returns true if the model should have thinking enabled by default.
func ModelHasDefaultThinking(model string) bool {
return modelsWithDefaultThinking[model]
}
// ApplyDefaultThinkingIfNeeded injects default thinkingConfig for models that require it.
// For standard Gemini API format (generationConfig.thinkingConfig path).
// Returns the modified body if thinkingConfig was added, otherwise returns the original.
func ApplyDefaultThinkingIfNeeded(model string, body []byte) []byte {
if !ModelHasDefaultThinking(model) {
return body
}
if gjson.GetBytes(body, "generationConfig.thinkingConfig").Exists() {
return body
}
updated, _ := sjson.SetBytes(body, "generationConfig.thinkingConfig.thinkingBudget", -1)
updated, _ = sjson.SetBytes(updated, "generationConfig.thinkingConfig.include_thoughts", true)
return updated
}
// ApplyDefaultThinkingIfNeededCLI injects default thinkingConfig for models that require it.
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
// Returns the modified body if thinkingConfig was added, otherwise returns the original.
func ApplyDefaultThinkingIfNeededCLI(model string, body []byte) []byte {
if !ModelHasDefaultThinking(model) {
return body
}
if gjson.GetBytes(body, "request.generationConfig.thinkingConfig").Exists() {
return body
}
updated, _ := sjson.SetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
updated, _ = sjson.SetBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts", true)
return updated
}
// StripThinkingConfigIfUnsupported removes thinkingConfig from the request body
// when the target model does not advertise Thinking capability. It cleans both
// standard Gemini and Gemini CLI JSON envelopes. This acts as a final safety net
@@ -223,6 +264,32 @@ func StripThinkingConfigIfUnsupported(model string, body []byte) []byte {
return updated
}
// NormalizeGeminiThinkingBudget normalizes the thinkingBudget value in a standard Gemini
// request body (generationConfig.thinkingConfig.thinkingBudget path).
func NormalizeGeminiThinkingBudget(model string, body []byte) []byte {
const budgetPath = "generationConfig.thinkingConfig.thinkingBudget"
budget := gjson.GetBytes(body, budgetPath)
if !budget.Exists() {
return body
}
normalized := NormalizeThinkingBudget(model, int(budget.Int()))
updated, _ := sjson.SetBytes(body, budgetPath, normalized)
return updated
}
// NormalizeGeminiCLIThinkingBudget normalizes the thinkingBudget value in a Gemini CLI
// request body (request.generationConfig.thinkingConfig.thinkingBudget path).
func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte {
const budgetPath = "request.generationConfig.thinkingConfig.thinkingBudget"
budget := gjson.GetBytes(body, budgetPath)
if !budget.Exists() {
return body
}
normalized := NormalizeThinkingBudget(model, int(budget.Int()))
updated, _ := sjson.SetBytes(body, budgetPath, normalized)
return updated
}
// ConvertThinkingLevelToBudget checks for "generationConfig.thinkingConfig.thinkingLevel"
// and converts it to "thinkingBudget".
// "high" -> 32768

View File

@@ -21,6 +21,7 @@ import (
"github.com/fsnotify/fsnotify"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
"gopkg.in/yaml.v3"
@@ -176,6 +177,9 @@ func (w *Watcher) Start(ctx context.Context) error {
}
log.Debugf("watching auth directory: %s", w.authDir)
// Watch Kiro IDE token file directory for automatic token updates
w.watchKiroIDETokenFile()
// Start the event processing goroutine
go w.processEvents(ctx)
@@ -184,6 +188,31 @@ func (w *Watcher) Start(ctx context.Context) error {
return nil
}
// watchKiroIDETokenFile adds the Kiro IDE token file directory to the watcher.
// This enables automatic detection of token updates from Kiro IDE.
func (w *Watcher) watchKiroIDETokenFile() {
homeDir, err := os.UserHomeDir()
if err != nil {
log.Debugf("failed to get home directory for Kiro IDE token watch: %v", err)
return
}
// Kiro IDE stores tokens in ~/.aws/sso/cache/
kiroTokenDir := filepath.Join(homeDir, ".aws", "sso", "cache")
// Check if directory exists
if _, statErr := os.Stat(kiroTokenDir); os.IsNotExist(statErr) {
log.Debugf("Kiro IDE token directory does not exist: %s", kiroTokenDir)
return
}
if errAdd := w.watcher.Add(kiroTokenDir); errAdd != nil {
log.Debugf("failed to watch Kiro IDE token directory %s: %v", kiroTokenDir, errAdd)
return
}
log.Debugf("watching Kiro IDE token directory: %s", kiroTokenDir)
}
// Stop stops the file watcher
func (w *Watcher) Stop() error {
w.stopDispatch()
@@ -744,10 +773,20 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
isConfigEvent := event.Name == w.configPath && event.Op&configOps != 0
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
isAuthJSON := strings.HasPrefix(event.Name, w.authDir) && strings.HasSuffix(event.Name, ".json") && event.Op&authOps != 0
if !isConfigEvent && !isAuthJSON {
// Check for Kiro IDE token file changes
isKiroIDEToken := w.isKiroIDETokenFile(event.Name) && event.Op&authOps != 0
if !isConfigEvent && !isAuthJSON && !isKiroIDEToken {
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
return
}
// Handle Kiro IDE token file changes
if isKiroIDEToken {
w.handleKiroIDETokenChange(event)
return
}
now := time.Now()
log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
@@ -805,6 +844,51 @@ func (w *Watcher) scheduleConfigReload() {
})
}
// isKiroIDETokenFile checks if the given path is the Kiro IDE token file.
func (w *Watcher) isKiroIDETokenFile(path string) bool {
// Check if it's the kiro-auth-token.json file in ~/.aws/sso/cache/
// Use filepath.ToSlash to ensure consistent separators across platforms (Windows uses backslashes)
normalized := filepath.ToSlash(path)
return strings.HasSuffix(normalized, "kiro-auth-token.json") && strings.Contains(normalized, ".aws/sso/cache")
}
// handleKiroIDETokenChange processes changes to the Kiro IDE token file.
// When the token file is updated by Kiro IDE, this triggers a reload of Kiro auth.
func (w *Watcher) handleKiroIDETokenChange(event fsnotify.Event) {
log.Debugf("Kiro IDE token file event detected: %s %s", event.Op.String(), event.Name)
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
// Token file removed - wait briefly for potential atomic replace
time.Sleep(replaceCheckDelay)
if _, statErr := os.Stat(event.Name); statErr != nil {
log.Debugf("Kiro IDE token file removed: %s", event.Name)
return
}
}
// Try to load the updated token
tokenData, err := kiroauth.LoadKiroIDEToken()
if err != nil {
log.Debugf("failed to load Kiro IDE token after change: %v", err)
return
}
log.Infof("Kiro IDE token file updated, access token refreshed (provider: %s)", tokenData.Provider)
// Trigger auth state refresh to pick up the new token
w.refreshAuthState()
// Notify callback if set
w.clientsMutex.RLock()
cfg := w.config
w.clientsMutex.RUnlock()
if w.reloadCallback != nil && cfg != nil {
log.Debugf("triggering server update callback after Kiro IDE token change")
w.reloadCallback(cfg)
}
}
func (w *Watcher) reloadConfigIfChanged() {
data, err := os.ReadFile(w.configPath)
if err != nil {
@@ -1181,6 +1265,82 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
applyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
out = append(out, a)
}
// Kiro (AWS CodeWhisperer) -> synthesize auths
var kAuth *kiroauth.KiroAuth
if len(cfg.KiroKey) > 0 {
kAuth = kiroauth.NewKiroAuth(cfg)
}
for i := range cfg.KiroKey {
kk := cfg.KiroKey[i]
var accessToken, profileArn, refreshToken string
// Try to load from token file first
if kk.TokenFile != "" && kAuth != nil {
tokenData, err := kAuth.LoadTokenFromFile(kk.TokenFile)
if err != nil {
log.Warnf("failed to load kiro token file %s: %v", kk.TokenFile, err)
} else {
accessToken = tokenData.AccessToken
profileArn = tokenData.ProfileArn
refreshToken = tokenData.RefreshToken
}
}
// Override with direct config values if provided
if kk.AccessToken != "" {
accessToken = kk.AccessToken
}
if kk.ProfileArn != "" {
profileArn = kk.ProfileArn
}
if kk.RefreshToken != "" {
refreshToken = kk.RefreshToken
}
if accessToken == "" {
log.Warnf("kiro config[%d] missing access_token, skipping", i)
continue
}
// profileArn is optional for AWS Builder ID users
id, token := idGen.next("kiro:token", accessToken, profileArn)
attrs := map[string]string{
"source": fmt.Sprintf("config:kiro[%s]", token),
"access_token": accessToken,
}
if profileArn != "" {
attrs["profile_arn"] = profileArn
}
if kk.Region != "" {
attrs["region"] = kk.Region
}
if kk.AgentTaskType != "" {
attrs["agent_task_type"] = kk.AgentTaskType
}
if refreshToken != "" {
attrs["refresh_token"] = refreshToken
}
proxyURL := strings.TrimSpace(kk.ProxyURL)
a := &coreauth.Auth{
ID: id,
Provider: "kiro",
Label: "kiro-token",
Status: coreauth.StatusActive,
ProxyURL: proxyURL,
Attributes: attrs,
CreatedAt: now,
UpdatedAt: now,
}
if refreshToken != "" {
if a.Metadata == nil {
a.Metadata = make(map[string]any)
}
a.Metadata["refresh_token"] = refreshToken
}
out = append(out, a)
}
for i := range cfg.OpenAICompatibility {
compat := &cfg.OpenAICompatibility[i]
providerName := strings.ToLower(strings.TrimSpace(compat.Name))
@@ -1287,7 +1447,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
}
// Also synthesize auth entries directly from auth files (for OAuth/file-backed providers)
entries, _ := os.ReadDir(w.authDir)
log.Debugf("SnapshotCoreAuths: scanning auth directory: %s", w.authDir)
entries, readErr := os.ReadDir(w.authDir)
if readErr != nil {
log.Errorf("SnapshotCoreAuths: failed to read auth directory %s: %v", w.authDir, readErr)
}
log.Debugf("SnapshotCoreAuths: found %d entries in auth directory", len(entries))
for _, e := range entries {
if e.IsDir() {
continue
@@ -1306,9 +1471,20 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
continue
}
t, _ := metadata["type"].(string)
// Detect Kiro auth files by auth_method field (they don't have "type" field)
if t == "" {
if authMethod, _ := metadata["auth_method"].(string); authMethod == "builder-id" || authMethod == "social" {
t = "kiro"
log.Debugf("SnapshotCoreAuths: detected Kiro auth by auth_method: %s", name)
}
}
if t == "" {
log.Debugf("SnapshotCoreAuths: skipping file without type: %s", name)
continue
}
log.Debugf("SnapshotCoreAuths: processing auth file: %s (type=%s)", name, t)
provider := strings.ToLower(t)
if provider == "gemini" {
provider = "gemini-cli"
@@ -1317,6 +1493,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
if email, _ := metadata["email"].(string); email != "" {
label = email
}
// For Kiro, use provider field as label if available
if provider == "kiro" {
if kiroProvider, _ := metadata["provider"].(string); kiroProvider != "" {
label = fmt.Sprintf("kiro-%s", strings.ToLower(kiroProvider))
}
}
// Use relative path under authDir as ID to stay consistent with the file-based token store
id := full
if rel, errRel := filepath.Rel(w.authDir, full); errRel == nil && rel != "" {
@@ -1342,6 +1524,16 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
CreatedAt: now,
UpdatedAt: now,
}
// Set NextRefreshAfter for Kiro auth based on expires_at
if provider == "kiro" {
if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" {
if expiresAt, parseErr := time.Parse(time.RFC3339, expiresAtStr); parseErr == nil {
// Refresh 30 minutes before expiry
a.NextRefreshAfter = expiresAt.Add(-30 * time.Minute)
}
}
}
applyAuthExcludedModelsMeta(a, cfg, nil, "oauth")
if provider == "gemini-cli" {
if virtuals := synthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {

View File

@@ -48,8 +48,24 @@ func (h *GeminiAPIHandler) Models() []map[string]any {
// GeminiModels handles the Gemini models listing endpoint.
// It returns a JSON response containing available Gemini models and their specifications.
func (h *GeminiAPIHandler) GeminiModels(c *gin.Context) {
rawModels := h.Models()
normalizedModels := make([]map[string]any, 0, len(rawModels))
defaultMethods := []string{"generateContent"}
for _, model := range rawModels {
normalizedModel := make(map[string]any, len(model))
for k, v := range model {
normalizedModel[k] = v
}
if name, ok := normalizedModel["name"].(string); ok && name != "" && !strings.HasPrefix(name, "models/") {
normalizedModel["name"] = "models/" + name
}
if _, ok := normalizedModel["supportedGenerationMethods"]; !ok {
normalizedModel["supportedGenerationMethods"] = defaultMethods
}
normalizedModels = append(normalizedModels, normalizedModel)
}
c.JSON(http.StatusOK, gin.H{
"models": h.Models(),
"models": normalizedModels,
})
}

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"net/http"
"strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
@@ -50,7 +51,25 @@ type BaseAPIHandler struct {
Cfg *config.SDKConfig
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
OpenAICompatProviders []string
openAICompatProviders []string
openAICompatMutex sync.RWMutex
}
// GetOpenAICompatProviders safely returns a copy of the provider names
func (h *BaseAPIHandler) GetOpenAICompatProviders() []string {
h.openAICompatMutex.RLock()
defer h.openAICompatMutex.RUnlock()
result := make([]string, len(h.openAICompatProviders))
copy(result, h.openAICompatProviders)
return result
}
// SetOpenAICompatProviders safely sets the provider names
func (h *BaseAPIHandler) SetOpenAICompatProviders(providers []string) {
h.openAICompatMutex.Lock()
defer h.openAICompatMutex.Unlock()
h.openAICompatProviders = make([]string, len(providers))
copy(h.openAICompatProviders, providers)
}
// NewBaseAPIHandlers creates a new API handlers instance.
@@ -63,11 +82,12 @@ type BaseAPIHandler struct {
// Returns:
// - *BaseAPIHandler: A new API handlers instance
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
return &BaseAPIHandler{
Cfg: cfg,
AuthManager: authManager,
OpenAICompatProviders: openAICompatProviders,
h := &BaseAPIHandler{
Cfg: cfg,
AuthManager: authManager,
}
h.SetOpenAICompatProviders(openAICompatProviders)
return h
}
// UpdateClients updates the handlers' client list and configuration.
@@ -363,7 +383,7 @@ func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, mode
}
// Check if the provider is a configured openai-compatibility provider
for _, pName := range h.OpenAICompatProviders {
for _, pName := range h.GetOpenAICompatProviders() {
if pName == providerPart {
return providerPart, modelPart, true
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
@@ -127,6 +128,18 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
}
}
// Fetch project ID via loadCodeAssist (same approach as Gemini CLI)
projectID := ""
if tokenResp.AccessToken != "" {
fetchedProjectID, errProject := fetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
if errProject != nil {
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
} else {
projectID = fetchedProjectID
log.Infof("antigravity: obtained project ID %s", projectID)
}
}
now := time.Now()
metadata := map[string]any{
"type": "antigravity",
@@ -139,6 +152,9 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
if email != "" {
metadata["email"] = email
}
if projectID != "" {
metadata["project_id"] = projectID
}
fileName := sanitizeAntigravityFileName(email)
label := email
@@ -147,6 +163,9 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
}
fmt.Println("Antigravity authentication successful")
if projectID != "" {
fmt.Printf("Using GCP project: %s\n", projectID)
}
return &coreauth.Auth{
ID: fileName,
Provider: "antigravity",
@@ -291,3 +310,89 @@ func sanitizeAntigravityFileName(email string) string {
replacer := strings.NewReplacer("@", "_", ".", "_")
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
}
// Antigravity API constants for project discovery
const (
antigravityAPIEndpoint = "https://cloudcode-pa.googleapis.com"
antigravityAPIVersion = "v1internal"
antigravityAPIUserAgent = "google-api-nodejs-client/9.15.1"
antigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
antigravityClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
)
// FetchAntigravityProjectID exposes project discovery for external callers.
func FetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
return fetchAntigravityProjectID(ctx, accessToken, httpClient)
}
// fetchAntigravityProjectID retrieves the project ID for the authenticated user via loadCodeAssist.
// This uses the same approach as Gemini CLI to get the cloudaicompanionProject.
func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
// Call loadCodeAssist to get the project
loadReqBody := map[string]any{
"metadata": map[string]string{
"ideType": "IDE_UNSPECIFIED",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
}
rawBody, errMarshal := json.Marshal(loadReqBody)
if errMarshal != nil {
return "", fmt.Errorf("marshal request body: %w", errMarshal)
}
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", antigravityAPIEndpoint, antigravityAPIVersion)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", antigravityAPIUserAgent)
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
req.Header.Set("Client-Metadata", antigravityClientMetadata)
resp, errDo := httpClient.Do(req)
if errDo != nil {
return "", fmt.Errorf("execute request: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
}
}()
bodyBytes, errRead := io.ReadAll(resp.Body)
if errRead != nil {
return "", fmt.Errorf("read response: %w", errRead)
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
}
var loadResp map[string]any
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
return "", fmt.Errorf("decode response: %w", errDecode)
}
// Extract projectID from response
projectID := ""
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
projectID = strings.TrimSpace(id)
}
if projectID == "" {
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
if id, okID := projectMap["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
}
}
if projectID == "" {
return "", fmt.Errorf("no cloudaicompanionProject in response")
}
return projectID, nil
}

357
sdk/auth/kiro.go Normal file
View File

@@ -0,0 +1,357 @@
package auth
import (
"context"
"fmt"
"strings"
"time"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// extractKiroIdentifier extracts a meaningful identifier for file naming.
// Returns account name if provided, otherwise profile ARN ID.
// All extracted values are sanitized to prevent path injection attacks.
func extractKiroIdentifier(accountName, profileArn string) string {
// Priority 1: Use account name if provided
if accountName != "" {
return kiroauth.SanitizeEmailForFilename(accountName)
}
// Priority 2: Use profile ARN ID part (sanitized to prevent path injection)
if profileArn != "" {
parts := strings.Split(profileArn, "/")
if len(parts) >= 2 {
// Sanitize the ARN component to prevent path traversal
return kiroauth.SanitizeEmailForFilename(parts[len(parts)-1])
}
}
// Fallback: timestamp
return fmt.Sprintf("%d", time.Now().UnixNano()%100000)
}
// KiroAuthenticator implements OAuth authentication for Kiro with Google login.
type KiroAuthenticator struct{}
// NewKiroAuthenticator constructs a Kiro authenticator.
func NewKiroAuthenticator() *KiroAuthenticator {
return &KiroAuthenticator{}
}
// Provider returns the provider key for the authenticator.
func (a *KiroAuthenticator) Provider() string {
return "kiro"
}
// RefreshLead indicates how soon before expiry a refresh should be attempted.
func (a *KiroAuthenticator) RefreshLead() *time.Duration {
d := 30 * time.Minute
return &d
}
// Login performs OAuth login for Kiro with AWS Builder ID.
func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if cfg == nil {
return nil, fmt.Errorf("kiro auth: configuration is required")
}
oauth := kiroauth.NewKiroOAuth(cfg)
// Use AWS Builder ID device code flow
tokenData, err := oauth.LoginWithBuilderID(ctx)
if err != nil {
return nil, fmt.Errorf("login failed: %w", err)
}
// Parse expires_at
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
if err != nil {
expiresAt = time.Now().Add(1 * time.Hour)
}
// Extract identifier for file naming
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
now := time.Now()
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
record := &coreauth.Auth{
ID: fileName,
Provider: "kiro",
FileName: fileName,
Label: "kiro-aws",
Status: coreauth.StatusActive,
CreatedAt: now,
UpdatedAt: now,
Metadata: map[string]any{
"type": "kiro",
"access_token": tokenData.AccessToken,
"refresh_token": tokenData.RefreshToken,
"profile_arn": tokenData.ProfileArn,
"expires_at": tokenData.ExpiresAt,
"auth_method": tokenData.AuthMethod,
"provider": tokenData.Provider,
"client_id": tokenData.ClientID,
"client_secret": tokenData.ClientSecret,
"email": tokenData.Email,
},
Attributes: map[string]string{
"profile_arn": tokenData.ProfileArn,
"source": "aws-builder-id",
"email": tokenData.Email,
},
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
}
if tokenData.Email != "" {
fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email)
} else {
fmt.Println("\n✓ Kiro authentication completed successfully!")
}
return record, nil
}
// LoginWithGoogle performs OAuth login for Kiro with Google.
// This uses a custom protocol handler (kiro://) to receive the callback.
func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if cfg == nil {
return nil, fmt.Errorf("kiro auth: configuration is required")
}
oauth := kiroauth.NewKiroOAuth(cfg)
// Use Google OAuth flow with protocol handler
tokenData, err := oauth.LoginWithGoogle(ctx)
if err != nil {
return nil, fmt.Errorf("google login failed: %w", err)
}
// Parse expires_at
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
if err != nil {
expiresAt = time.Now().Add(1 * time.Hour)
}
// Extract identifier for file naming
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
now := time.Now()
fileName := fmt.Sprintf("kiro-google-%s.json", idPart)
record := &coreauth.Auth{
ID: fileName,
Provider: "kiro",
FileName: fileName,
Label: "kiro-google",
Status: coreauth.StatusActive,
CreatedAt: now,
UpdatedAt: now,
Metadata: map[string]any{
"type": "kiro",
"access_token": tokenData.AccessToken,
"refresh_token": tokenData.RefreshToken,
"profile_arn": tokenData.ProfileArn,
"expires_at": tokenData.ExpiresAt,
"auth_method": tokenData.AuthMethod,
"provider": tokenData.Provider,
"email": tokenData.Email,
},
Attributes: map[string]string{
"profile_arn": tokenData.ProfileArn,
"source": "google-oauth",
"email": tokenData.Email,
},
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
}
if tokenData.Email != "" {
fmt.Printf("\n✓ Kiro Google authentication completed successfully! (Account: %s)\n", tokenData.Email)
} else {
fmt.Println("\n✓ Kiro Google authentication completed successfully!")
}
return record, nil
}
// LoginWithGitHub performs OAuth login for Kiro with GitHub.
// This uses a custom protocol handler (kiro://) to receive the callback.
func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if cfg == nil {
return nil, fmt.Errorf("kiro auth: configuration is required")
}
oauth := kiroauth.NewKiroOAuth(cfg)
// Use GitHub OAuth flow with protocol handler
tokenData, err := oauth.LoginWithGitHub(ctx)
if err != nil {
return nil, fmt.Errorf("github login failed: %w", err)
}
// Parse expires_at
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
if err != nil {
expiresAt = time.Now().Add(1 * time.Hour)
}
// Extract identifier for file naming
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
now := time.Now()
fileName := fmt.Sprintf("kiro-github-%s.json", idPart)
record := &coreauth.Auth{
ID: fileName,
Provider: "kiro",
FileName: fileName,
Label: "kiro-github",
Status: coreauth.StatusActive,
CreatedAt: now,
UpdatedAt: now,
Metadata: map[string]any{
"type": "kiro",
"access_token": tokenData.AccessToken,
"refresh_token": tokenData.RefreshToken,
"profile_arn": tokenData.ProfileArn,
"expires_at": tokenData.ExpiresAt,
"auth_method": tokenData.AuthMethod,
"provider": tokenData.Provider,
"email": tokenData.Email,
},
Attributes: map[string]string{
"profile_arn": tokenData.ProfileArn,
"source": "github-oauth",
"email": tokenData.Email,
},
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
}
if tokenData.Email != "" {
fmt.Printf("\n✓ Kiro GitHub authentication completed successfully! (Account: %s)\n", tokenData.Email)
} else {
fmt.Println("\n✓ Kiro GitHub authentication completed successfully!")
}
return record, nil
}
// ImportFromKiroIDE imports token from Kiro IDE's token file.
func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.Config) (*coreauth.Auth, error) {
tokenData, err := kiroauth.LoadKiroIDEToken()
if err != nil {
return nil, fmt.Errorf("failed to load Kiro IDE token: %w", err)
}
// Parse expires_at
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
if err != nil {
expiresAt = time.Now().Add(1 * time.Hour)
}
// Extract email from JWT if not already set (for imported tokens)
if tokenData.Email == "" {
tokenData.Email = kiroauth.ExtractEmailFromJWT(tokenData.AccessToken)
}
// Extract identifier for file naming
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
// Sanitize provider to prevent path traversal (defense-in-depth)
provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider)))
if provider == "" {
provider = "imported" // Fallback for legacy tokens without provider
}
now := time.Now()
fileName := fmt.Sprintf("kiro-%s-%s.json", provider, idPart)
record := &coreauth.Auth{
ID: fileName,
Provider: "kiro",
FileName: fileName,
Label: fmt.Sprintf("kiro-%s", provider),
Status: coreauth.StatusActive,
CreatedAt: now,
UpdatedAt: now,
Metadata: map[string]any{
"type": "kiro",
"access_token": tokenData.AccessToken,
"refresh_token": tokenData.RefreshToken,
"profile_arn": tokenData.ProfileArn,
"expires_at": tokenData.ExpiresAt,
"auth_method": tokenData.AuthMethod,
"provider": tokenData.Provider,
"email": tokenData.Email,
},
Attributes: map[string]string{
"profile_arn": tokenData.ProfileArn,
"source": "kiro-ide-import",
"email": tokenData.Email,
},
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
}
// Display the email if extracted
if tokenData.Email != "" {
fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s, Account: %s)\n", tokenData.Provider, tokenData.Email)
} else {
fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s)\n", tokenData.Provider)
}
return record, nil
}
// Refresh refreshes an expired Kiro token using AWS SSO OIDC.
func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) (*coreauth.Auth, error) {
if auth == nil || auth.Metadata == nil {
return nil, fmt.Errorf("invalid auth record")
}
refreshToken, ok := auth.Metadata["refresh_token"].(string)
if !ok || refreshToken == "" {
return nil, fmt.Errorf("refresh token not found")
}
clientID, _ := auth.Metadata["client_id"].(string)
clientSecret, _ := auth.Metadata["client_secret"].(string)
authMethod, _ := auth.Metadata["auth_method"].(string)
var tokenData *kiroauth.KiroTokenData
var err error
// Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint
if clientID != "" && clientSecret != "" && authMethod == "builder-id" {
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
} else {
// Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
oauth := kiroauth.NewKiroOAuth(cfg)
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
}
if err != nil {
return nil, fmt.Errorf("token refresh failed: %w", err)
}
// Parse expires_at
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
if err != nil {
expiresAt = time.Now().Add(1 * time.Hour)
}
// Clone auth to avoid mutating the input parameter
updated := auth.Clone()
now := time.Now()
updated.UpdatedAt = now
updated.LastRefreshedAt = now
updated.Metadata["access_token"] = tokenData.AccessToken
updated.Metadata["refresh_token"] = tokenData.RefreshToken
updated.Metadata["expires_at"] = tokenData.ExpiresAt
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute)
return updated, nil
}

View File

@@ -74,3 +74,16 @@ func (m *Manager) Login(ctx context.Context, provider string, cfg *config.Config
}
return record, savedPath, nil
}
// SaveAuth persists an auth record directly without going through the login flow.
func (m *Manager) SaveAuth(record *coreauth.Auth, cfg *config.Config) (string, error) {
if m.store == nil {
return "", fmt.Errorf("no store configured")
}
if cfg != nil {
if dirSetter, ok := m.store.(interface{ SetBaseDir(string) }); ok {
dirSetter.SetBaseDir(cfg.AuthDir)
}
}
return m.store.Save(context.Background(), record)
}

View File

@@ -14,6 +14,7 @@ func init() {
registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() })
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
}

View File

@@ -379,6 +379,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
case "iflow":
s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg))
case "kiro":
s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg))
case "github-copilot":
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
default:
@@ -500,7 +502,7 @@ func (s *Service) Run(ctx context.Context) error {
}()
time.Sleep(100 * time.Millisecond)
fmt.Printf("API server started successfully on: %d\n", s.cfg.Port)
fmt.Printf("API server started successfully on: %s:%d\n", s.cfg.Host, s.cfg.Port)
if s.hooks.OnAfterStart != nil {
s.hooks.OnAfterStart(s)
@@ -725,6 +727,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "github-copilot":
models = registry.GetGitHubCopilotModels()
models = applyExcludedModels(models, excluded)
case "kiro":
models = registry.GetKiroModels()
models = applyExcludedModels(models, excluded)
default:
// Handle OpenAI-compatibility providers by name using config
if s.cfg != nil {
@@ -783,7 +788,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
Created: time.Now().Unix(),
OwnedBy: compat.Name,
Type: "openai-compatibility",
DisplayName: m.Name,
DisplayName: modelID,
})
}
// Register and return

827
test/amp_management_test.go Normal file
View File

@@ -0,0 +1,827 @@
package test
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
func init() {
gin.SetMode(gin.TestMode)
}
// newAmpTestHandler creates a test handler with default ampcode configuration.
func newAmpTestHandler(t *testing.T) (*management.Handler, string) {
t.Helper()
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
cfg := &config.Config{
AmpCode: config.AmpCode{
UpstreamURL: "https://example.com",
UpstreamAPIKey: "test-api-key-12345",
RestrictManagementToLocalhost: true,
ForceModelMappings: false,
ModelMappings: []config.AmpModelMapping{
{From: "gpt-4", To: "gemini-pro"},
},
},
}
if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil {
t.Fatalf("failed to write config file: %v", err)
}
h := management.NewHandler(cfg, configPath, nil)
return h, configPath
}
// setupAmpRouter creates a test router with all ampcode management endpoints.
func setupAmpRouter(h *management.Handler) *gin.Engine {
r := gin.New()
mgmt := r.Group("/v0/management")
{
mgmt.GET("/ampcode", h.GetAmpCode)
mgmt.GET("/ampcode/upstream-url", h.GetAmpUpstreamURL)
mgmt.PUT("/ampcode/upstream-url", h.PutAmpUpstreamURL)
mgmt.DELETE("/ampcode/upstream-url", h.DeleteAmpUpstreamURL)
mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey)
mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey)
mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey)
mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost)
mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost)
mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings)
mgmt.PUT("/ampcode/model-mappings", h.PutAmpModelMappings)
mgmt.PATCH("/ampcode/model-mappings", h.PatchAmpModelMappings)
mgmt.DELETE("/ampcode/model-mappings", h.DeleteAmpModelMappings)
mgmt.GET("/ampcode/force-model-mappings", h.GetAmpForceModelMappings)
mgmt.PUT("/ampcode/force-model-mappings", h.PutAmpForceModelMappings)
}
return r
}
// TestGetAmpCode verifies GET /v0/management/ampcode returns full ampcode config.
func TestGetAmpCode(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]config.AmpCode
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
ampcode := resp["ampcode"]
if ampcode.UpstreamURL != "https://example.com" {
t.Errorf("expected upstream-url %q, got %q", "https://example.com", ampcode.UpstreamURL)
}
if len(ampcode.ModelMappings) != 1 {
t.Errorf("expected 1 model mapping, got %d", len(ampcode.ModelMappings))
}
}
// TestGetAmpUpstreamURL verifies GET /v0/management/ampcode/upstream-url returns the upstream URL.
func TestGetAmpUpstreamURL(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp["upstream-url"] != "https://example.com" {
t.Errorf("expected %q, got %q", "https://example.com", resp["upstream-url"])
}
}
// TestPutAmpUpstreamURL verifies PUT /v0/management/ampcode/upstream-url updates the upstream URL.
func TestPutAmpUpstreamURL(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": "https://new-upstream.com"}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
}
}
// TestDeleteAmpUpstreamURL verifies DELETE /v0/management/ampcode/upstream-url clears the upstream URL.
func TestDeleteAmpUpstreamURL(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestGetAmpUpstreamAPIKey verifies GET /v0/management/ampcode/upstream-api-key returns the API key.
func TestGetAmpUpstreamAPIKey(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]any
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
key := resp["upstream-api-key"].(string)
if key != "test-api-key-12345" {
t.Errorf("expected key %q, got %q", "test-api-key-12345", key)
}
}
// TestPutAmpUpstreamAPIKey verifies PUT /v0/management/ampcode/upstream-api-key updates the API key.
func TestPutAmpUpstreamAPIKey(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": "new-secret-key"}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key.
func TestDeleteAmpUpstreamAPIKey(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestGetAmpRestrictManagementToLocalhost verifies GET returns the localhost restriction setting.
func TestGetAmpRestrictManagementToLocalhost(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]bool
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp["restrict-management-to-localhost"] != true {
t.Error("expected restrict-management-to-localhost to be true")
}
}
// TestPutAmpRestrictManagementToLocalhost verifies PUT updates the localhost restriction setting.
func TestPutAmpRestrictManagementToLocalhost(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": false}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestGetAmpModelMappings verifies GET /v0/management/ampcode/model-mappings returns all mappings.
func TestGetAmpModelMappings(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
mappings := resp["model-mappings"]
if len(mappings) != 1 {
t.Fatalf("expected 1 mapping, got %d", len(mappings))
}
if mappings[0].From != "gpt-4" || mappings[0].To != "gemini-pro" {
t.Errorf("unexpected mapping: %+v", mappings[0])
}
}
// TestPutAmpModelMappings verifies PUT /v0/management/ampcode/model-mappings replaces all mappings.
func TestPutAmpModelMappings(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": [{"from": "claude-3", "to": "gpt-4o"}, {"from": "gemini", "to": "claude"}]}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
}
}
// TestPatchAmpModelMappings verifies PATCH updates existing mappings and adds new ones.
func TestPatchAmpModelMappings(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": [{"from": "gpt-4", "to": "updated-model"}, {"from": "new-model", "to": "target"}]}`
req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
}
}
// TestDeleteAmpModelMappings_Specific verifies DELETE removes specified mappings by "from" field.
func TestDeleteAmpModelMappings_Specific(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": ["gpt-4"]}`
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestDeleteAmpModelMappings_All verifies DELETE with empty body removes all mappings.
func TestDeleteAmpModelMappings_All(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestGetAmpForceModelMappings verifies GET returns the force-model-mappings setting.
func TestGetAmpForceModelMappings(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]bool
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp["force-model-mappings"] != false {
t.Error("expected force-model-mappings to be false")
}
}
// TestPutAmpForceModelMappings verifies PUT updates the force-model-mappings setting.
func TestPutAmpForceModelMappings(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": true}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestPutAmpModelMappings_VerifyState verifies PUT replaces mappings and state is persisted.
func TestPutAmpModelMappings_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": [{"from": "model-a", "to": "model-b"}, {"from": "model-c", "to": "model-d"}, {"from": "model-e", "to": "model-f"}]}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("PUT failed: status %d, body: %s", w.Code, w.Body.String())
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
mappings := resp["model-mappings"]
if len(mappings) != 3 {
t.Fatalf("expected 3 mappings, got %d", len(mappings))
}
expected := map[string]string{"model-a": "model-b", "model-c": "model-d", "model-e": "model-f"}
for _, m := range mappings {
if expected[m.From] != m.To {
t.Errorf("mapping %q -> expected %q, got %q", m.From, expected[m.From], m.To)
}
}
}
// TestPatchAmpModelMappings_VerifyState verifies PATCH merges mappings correctly.
func TestPatchAmpModelMappings_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": [{"from": "gpt-4", "to": "updated-target"}, {"from": "new-model", "to": "new-target"}]}`
req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("PATCH failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
mappings := resp["model-mappings"]
if len(mappings) != 2 {
t.Fatalf("expected 2 mappings (1 updated + 1 new), got %d", len(mappings))
}
found := make(map[string]string)
for _, m := range mappings {
found[m.From] = m.To
}
if found["gpt-4"] != "updated-target" {
t.Errorf("gpt-4 should map to updated-target, got %q", found["gpt-4"])
}
if found["new-model"] != "new-target" {
t.Errorf("new-model should map to new-target, got %q", found["new-model"])
}
}
// TestDeleteAmpModelMappings_VerifyState verifies DELETE removes specific mappings and keeps others.
func TestDeleteAmpModelMappings_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
putBody := `{"value": [{"from": "a", "to": "1"}, {"from": "b", "to": "2"}, {"from": "c", "to": "3"}]}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
delBody := `{"value": ["a", "c"]}`
req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody))
req.Header.Set("Content-Type", "application/json")
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("DELETE failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
mappings := resp["model-mappings"]
if len(mappings) != 1 {
t.Fatalf("expected 1 mapping remaining, got %d", len(mappings))
}
if mappings[0].From != "b" || mappings[0].To != "2" {
t.Errorf("expected b->2, got %s->%s", mappings[0].From, mappings[0].To)
}
}
// TestDeleteAmpModelMappings_NonExistent verifies DELETE with non-existent mapping doesn't affect existing ones.
func TestDeleteAmpModelMappings_NonExistent(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
delBody := `{"value": ["non-existent-model"]}`
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if len(resp["model-mappings"]) != 1 {
t.Errorf("original mapping should remain, got %d mappings", len(resp["model-mappings"]))
}
}
// TestPutAmpModelMappings_Empty verifies PUT with empty array clears all mappings.
func TestPutAmpModelMappings_Empty(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": []}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if len(resp["model-mappings"]) != 0 {
t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"]))
}
}
// TestPutAmpUpstreamURL_VerifyState verifies PUT updates upstream URL and persists state.
func TestPutAmpUpstreamURL_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": "https://new-api.example.com"}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("PUT failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["upstream-url"] != "https://new-api.example.com" {
t.Errorf("expected %q, got %q", "https://new-api.example.com", resp["upstream-url"])
}
}
// TestDeleteAmpUpstreamURL_VerifyState verifies DELETE clears upstream URL.
func TestDeleteAmpUpstreamURL_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("DELETE failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["upstream-url"] != "" {
t.Errorf("expected empty string, got %q", resp["upstream-url"])
}
}
// TestPutAmpUpstreamAPIKey_VerifyState verifies PUT updates API key and persists state.
func TestPutAmpUpstreamAPIKey_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": "new-secret-api-key-xyz"}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("PUT failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["upstream-api-key"] != "new-secret-api-key-xyz" {
t.Errorf("expected %q, got %q", "new-secret-api-key-xyz", resp["upstream-api-key"])
}
}
// TestDeleteAmpUpstreamAPIKey_VerifyState verifies DELETE clears API key.
func TestDeleteAmpUpstreamAPIKey_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("DELETE failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["upstream-api-key"] != "" {
t.Errorf("expected empty string, got %q", resp["upstream-api-key"])
}
}
// TestPutAmpRestrictManagementToLocalhost_VerifyState verifies PUT updates localhost restriction.
func TestPutAmpRestrictManagementToLocalhost_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": false}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("PUT failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string]bool
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["restrict-management-to-localhost"] != false {
t.Error("expected false after update")
}
}
// TestPutAmpForceModelMappings_VerifyState verifies PUT updates force-model-mappings setting.
func TestPutAmpForceModelMappings_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": true}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("PUT failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string]bool
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["force-model-mappings"] != true {
t.Error("expected true after update")
}
}
// TestPutBoolField_EmptyObject verifies PUT with empty object returns 400.
func TestPutBoolField_EmptyObject(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status %d for empty object, got %d", http.StatusBadRequest, w.Code)
}
}
// TestComplexMappingsWorkflow tests a full workflow: PUT, PATCH, DELETE, and GET.
func TestComplexMappingsWorkflow(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
putBody := `{"value": [{"from": "m1", "to": "t1"}, {"from": "m2", "to": "t2"}, {"from": "m3", "to": "t3"}, {"from": "m4", "to": "t4"}]}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
patchBody := `{"value": [{"from": "m2", "to": "t2-updated"}, {"from": "m5", "to": "t5"}]}`
req = httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(patchBody))
req.Header.Set("Content-Type", "application/json")
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
delBody := `{"value": ["m1", "m3"]}`
req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody))
req.Header.Set("Content-Type", "application/json")
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
mappings := resp["model-mappings"]
if len(mappings) != 3 {
t.Fatalf("expected 3 mappings (m2, m4, m5), got %d", len(mappings))
}
expected := map[string]string{"m2": "t2-updated", "m4": "t4", "m5": "t5"}
found := make(map[string]string)
for _, m := range mappings {
found[m.From] = m.To
}
for from, to := range expected {
if found[from] != to {
t.Errorf("mapping %s: expected %q, got %q", from, to, found[from])
}
}
}
// TestNilHandlerGetAmpCode verifies handler works with empty config.
func TestNilHandlerGetAmpCode(t *testing.T) {
cfg := &config.Config{}
h := management.NewHandler(cfg, "", nil)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestEmptyConfigGetAmpModelMappings verifies GET returns empty array for fresh config.
func TestEmptyConfigGetAmpModelMappings(t *testing.T) {
cfg := &config.Config{}
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil {
t.Fatalf("failed to write config: %v", err)
}
h := management.NewHandler(cfg, configPath, nil)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if len(resp["model-mappings"]) != 0 {
t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"]))
}
}