mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-30 01:06:39 +00:00
Compare commits
128 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b27a175fef | ||
|
|
8d5f89ccfd | ||
|
|
084e2666cb | ||
|
|
2eb2dbb266 | ||
|
|
e717939edb | ||
|
|
7758a86d1e | ||
|
|
9f72a875f8 | ||
|
|
94d61c7b2b | ||
|
|
f999650322 | ||
|
|
1249b07eb8 | ||
|
|
38319b0483 | ||
|
|
6b37f33d31 | ||
|
|
af543238aa | ||
|
|
2de27c560b | ||
|
|
773ed6cc64 | ||
|
|
a594338bc5 | ||
|
|
f25f419e5a | ||
|
|
1fd1ccca17 | ||
|
|
b7e382008f | ||
|
|
70d6b95097 | ||
|
|
9b202b6c1c | ||
|
|
6a66b6801a | ||
|
|
5b6d201408 | ||
|
|
5ec9b5e5a9 | ||
|
|
5db3b58717 | ||
|
|
1fa5514d56 | ||
|
|
347769b3e3 | ||
|
|
3cfe7008a2 | ||
|
|
c353606860 | ||
|
|
2ba31ecc2d | ||
|
|
da23ddb061 | ||
|
|
39b6b3b289 | ||
|
|
c600519fa4 | ||
|
|
e5312fb5a2 | ||
|
|
36380846a7 | ||
|
|
92df0cada9 | ||
|
|
96b55acff8 | ||
|
|
608cd8ee3d | ||
|
|
9f41894573 | ||
|
|
bb45fee1cf | ||
|
|
af00304b0c | ||
|
|
5c3a013cd1 | ||
|
|
ab9e9442ec | ||
|
|
6ad188921c | ||
|
|
15ed98d6a9 | ||
|
|
a283545b6b | ||
|
|
3efbd865a8 | ||
|
|
aee659fb66 | ||
|
|
5aa386d8b9 | ||
|
|
0adc0ee6aa | ||
|
|
92f13fc316 | ||
|
|
05cfa16e5f | ||
|
|
93a6e2d920 | ||
|
|
6b49580716 | ||
|
|
de77903915 | ||
|
|
56ed0d8d90 | ||
|
|
0d4f32a881 | ||
|
|
42e818ce05 | ||
|
|
2d4c54ba54 | ||
|
|
e9eb4db8bb | ||
|
|
d26ed069fa | ||
|
|
afcab5efda | ||
|
|
1770c491db | ||
|
|
a0c6cffb0d | ||
|
|
2bf9e08b31 | ||
|
|
f56bfaa689 | ||
|
|
5d716dc796 | ||
|
|
f81ff16022 | ||
|
|
6cf1d8a947 | ||
|
|
a174d015f2 | ||
|
|
9c09128e00 | ||
|
|
68cbe20664 | ||
|
|
549c0c2c5a | ||
|
|
f092801b61 | ||
|
|
15353a6b6a | ||
|
|
1b638b3629 | ||
|
|
6f5f81753d | ||
|
|
76af454034 | ||
|
|
e54d2f6b2a | ||
|
|
bfc738b76a | ||
|
|
396899a530 | ||
|
|
04f0070a80 | ||
|
|
f383840cf9 | ||
|
|
239fc4a8c4 | ||
|
|
fd29ab418a | ||
|
|
df91408919 | ||
|
|
7a628426dc | ||
|
|
aa6c7facab | ||
|
|
8ba4c7c7ed | ||
|
|
56b4d7a76e | ||
|
|
b211c3546d | ||
|
|
edc654edf9 | ||
|
|
08586334af | ||
|
|
a4804b358f | ||
|
|
7ea14479fb | ||
|
|
54af96d321 | ||
|
|
22579155c5 | ||
|
|
c04c3832a4 | ||
|
|
5ffbd54755 | ||
|
|
5d12d4ce33 | ||
|
|
b73e53d6c4 | ||
|
|
b06463c6d9 | ||
|
|
5eb8453e91 | ||
|
|
f77c22e6ff | ||
|
|
df83ba877f | ||
|
|
9583f6b1c5 | ||
|
|
02d8a1cfec | ||
|
|
92f033dec0 | ||
|
|
0ebabf5152 | ||
|
|
4b01ecba2e | ||
|
|
d7564173dd | ||
|
|
f241124599 | ||
|
|
c44c46dd80 | ||
|
|
aa810ee719 | ||
|
|
412148af0e | ||
|
|
5d2baf6058 | ||
|
|
d28258501a | ||
|
|
55cd31fb96 | ||
|
|
d138df07bf | ||
|
|
c5df8e7897 | ||
|
|
d4d529833d | ||
|
|
caa48e7c6f | ||
|
|
acdfb3bceb | ||
|
|
89d68962b1 | ||
|
|
691cdb6bdf | ||
|
|
8064cba288 | ||
|
|
361443db10 | ||
|
|
d6352dd4d4 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,5 +1,6 @@
|
|||||||
# Binaries
|
# Binaries
|
||||||
cli-proxy-api
|
cli-proxy-api
|
||||||
|
cliproxy
|
||||||
*.exe
|
*.exe
|
||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
@@ -31,6 +32,7 @@ GEMINI.md
|
|||||||
.vscode/*
|
.vscode/*
|
||||||
.claude/*
|
.claude/*
|
||||||
.serena/*
|
.serena/*
|
||||||
|
.mcp/cache/
|
||||||
|
|
||||||
# macOS
|
# macOS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ The Plus release stays in lockstep with the mainline features.
|
|||||||
|
|
||||||
## Differences from the Mainline
|
## Differences from the Mainline
|
||||||
|
|
||||||
- Added GitHub Copilot support (OAuth login), provided by [em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
||||||
|
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,8 @@
|
|||||||
|
|
||||||
## 与主线版本版本差异
|
## 与主线版本版本差异
|
||||||
|
|
||||||
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
||||||
|
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)提供
|
||||||
|
|
||||||
## 贡献
|
## 贡献
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,19 @@ func init() {
|
|||||||
buildinfo.BuildDate = BuildDate
|
buildinfo.BuildDate = BuildDate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setKiroIncognitoMode sets the incognito browser mode for Kiro authentication.
|
||||||
|
// Kiro defaults to incognito mode for multi-account support.
|
||||||
|
// Users can explicitly override with --incognito or --no-incognito flags.
|
||||||
|
func setKiroIncognitoMode(cfg *config.Config, useIncognito, noIncognito bool) {
|
||||||
|
if useIncognito {
|
||||||
|
cfg.IncognitoBrowser = true
|
||||||
|
} else if noIncognito {
|
||||||
|
cfg.IncognitoBrowser = false
|
||||||
|
} else {
|
||||||
|
cfg.IncognitoBrowser = true // Kiro default
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// main is the entry point of the application.
|
// main is the entry point of the application.
|
||||||
// It parses command-line flags, loads configuration, and starts the appropriate
|
// It parses command-line flags, loads configuration, and starts the appropriate
|
||||||
// service based on the provided flags (login, codex-login, or server mode).
|
// service based on the provided flags (login, codex-login, or server mode).
|
||||||
@@ -62,11 +75,17 @@ func main() {
|
|||||||
var iflowCookie bool
|
var iflowCookie bool
|
||||||
var noBrowser bool
|
var noBrowser bool
|
||||||
var antigravityLogin bool
|
var antigravityLogin bool
|
||||||
|
var kiroLogin bool
|
||||||
|
var kiroGoogleLogin bool
|
||||||
|
var kiroAWSLogin bool
|
||||||
|
var kiroImport bool
|
||||||
var githubCopilotLogin bool
|
var githubCopilotLogin bool
|
||||||
var projectID string
|
var projectID string
|
||||||
var vertexImport string
|
var vertexImport string
|
||||||
var configPath string
|
var configPath string
|
||||||
var password string
|
var password string
|
||||||
|
var noIncognito bool
|
||||||
|
var useIncognito bool
|
||||||
|
|
||||||
// Define command-line flags for different operation modes.
|
// Define command-line flags for different operation modes.
|
||||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||||
@@ -76,7 +95,13 @@ func main() {
|
|||||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||||
|
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
||||||
|
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
||||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||||
|
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
||||||
|
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
||||||
|
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||||
|
flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
|
||||||
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
||||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||||
@@ -141,7 +166,8 @@ func main() {
|
|||||||
|
|
||||||
wd, err := os.Getwd()
|
wd, err := os.Getwd()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to get working directory: %v", err)
|
log.Errorf("failed to get working directory: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load environment variables from .env if present.
|
// Load environment variables from .env if present.
|
||||||
@@ -235,13 +261,15 @@ func main() {
|
|||||||
})
|
})
|
||||||
cancel()
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to initialize postgres token store: %v", err)
|
log.Errorf("failed to initialize postgres token store: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
examplePath := filepath.Join(wd, "config.example.yaml")
|
examplePath := filepath.Join(wd, "config.example.yaml")
|
||||||
ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
|
ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
if errBootstrap := pgStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
|
if errBootstrap := pgStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
|
||||||
cancel()
|
cancel()
|
||||||
log.Fatalf("failed to bootstrap postgres-backed config: %v", errBootstrap)
|
log.Errorf("failed to bootstrap postgres-backed config: %v", errBootstrap)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
configFilePath = pgStoreInst.ConfigPath()
|
configFilePath = pgStoreInst.ConfigPath()
|
||||||
@@ -264,7 +292,8 @@ func main() {
|
|||||||
if strings.Contains(resolvedEndpoint, "://") {
|
if strings.Contains(resolvedEndpoint, "://") {
|
||||||
parsed, errParse := url.Parse(resolvedEndpoint)
|
parsed, errParse := url.Parse(resolvedEndpoint)
|
||||||
if errParse != nil {
|
if errParse != nil {
|
||||||
log.Fatalf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse)
|
log.Errorf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
switch strings.ToLower(parsed.Scheme) {
|
switch strings.ToLower(parsed.Scheme) {
|
||||||
case "http":
|
case "http":
|
||||||
@@ -272,10 +301,12 @@ func main() {
|
|||||||
case "https":
|
case "https":
|
||||||
useSSL = true
|
useSSL = true
|
||||||
default:
|
default:
|
||||||
log.Fatalf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme)
|
log.Errorf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if parsed.Host == "" {
|
if parsed.Host == "" {
|
||||||
log.Fatalf("object store endpoint %q is missing host information", objectStoreEndpoint)
|
log.Errorf("object store endpoint %q is missing host information", objectStoreEndpoint)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
resolvedEndpoint = parsed.Host
|
resolvedEndpoint = parsed.Host
|
||||||
if parsed.Path != "" && parsed.Path != "/" {
|
if parsed.Path != "" && parsed.Path != "/" {
|
||||||
@@ -294,13 +325,15 @@ func main() {
|
|||||||
}
|
}
|
||||||
objectStoreInst, err = store.NewObjectTokenStore(objCfg)
|
objectStoreInst, err = store.NewObjectTokenStore(objCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to initialize object token store: %v", err)
|
log.Errorf("failed to initialize object token store: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
examplePath := filepath.Join(wd, "config.example.yaml")
|
examplePath := filepath.Join(wd, "config.example.yaml")
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
if errBootstrap := objectStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
|
if errBootstrap := objectStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
|
||||||
cancel()
|
cancel()
|
||||||
log.Fatalf("failed to bootstrap object-backed config: %v", errBootstrap)
|
log.Errorf("failed to bootstrap object-backed config: %v", errBootstrap)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
configFilePath = objectStoreInst.ConfigPath()
|
configFilePath = objectStoreInst.ConfigPath()
|
||||||
@@ -325,7 +358,8 @@ func main() {
|
|||||||
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
|
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
|
||||||
gitStoreInst.SetBaseDir(authDir)
|
gitStoreInst.SetBaseDir(authDir)
|
||||||
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
|
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
|
||||||
log.Fatalf("failed to prepare git token store: %v", errRepo)
|
log.Errorf("failed to prepare git token store: %v", errRepo)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
configFilePath = gitStoreInst.ConfigPath()
|
configFilePath = gitStoreInst.ConfigPath()
|
||||||
if configFilePath == "" {
|
if configFilePath == "" {
|
||||||
@@ -334,17 +368,21 @@ func main() {
|
|||||||
if _, statErr := os.Stat(configFilePath); errors.Is(statErr, fs.ErrNotExist) {
|
if _, statErr := os.Stat(configFilePath); errors.Is(statErr, fs.ErrNotExist) {
|
||||||
examplePath := filepath.Join(wd, "config.example.yaml")
|
examplePath := filepath.Join(wd, "config.example.yaml")
|
||||||
if _, errExample := os.Stat(examplePath); errExample != nil {
|
if _, errExample := os.Stat(examplePath); errExample != nil {
|
||||||
log.Fatalf("failed to find template config file: %v", errExample)
|
log.Errorf("failed to find template config file: %v", errExample)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if errCopy := misc.CopyConfigTemplate(examplePath, configFilePath); errCopy != nil {
|
if errCopy := misc.CopyConfigTemplate(examplePath, configFilePath); errCopy != nil {
|
||||||
log.Fatalf("failed to bootstrap git-backed config: %v", errCopy)
|
log.Errorf("failed to bootstrap git-backed config: %v", errCopy)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if errCommit := gitStoreInst.PersistConfig(context.Background()); errCommit != nil {
|
if errCommit := gitStoreInst.PersistConfig(context.Background()); errCommit != nil {
|
||||||
log.Fatalf("failed to commit initial git-backed config: %v", errCommit)
|
log.Errorf("failed to commit initial git-backed config: %v", errCommit)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
log.Infof("git-backed config initialized from template: %s", configFilePath)
|
log.Infof("git-backed config initialized from template: %s", configFilePath)
|
||||||
} else if statErr != nil {
|
} else if statErr != nil {
|
||||||
log.Fatalf("failed to inspect git-backed config: %v", statErr)
|
log.Errorf("failed to inspect git-backed config: %v", statErr)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
|
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -357,13 +395,15 @@ func main() {
|
|||||||
} else {
|
} else {
|
||||||
wd, err = os.Getwd()
|
wd, err = os.Getwd()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to get working directory: %v", err)
|
log.Errorf("failed to get working directory: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
configFilePath = filepath.Join(wd, "config.yaml")
|
configFilePath = filepath.Join(wd, "config.yaml")
|
||||||
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
|
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to load config: %v", err)
|
log.Errorf("failed to load config: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
cfg = &config.Config{}
|
cfg = &config.Config{}
|
||||||
@@ -393,7 +433,8 @@ func main() {
|
|||||||
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||||
|
|
||||||
if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||||
log.Fatalf("failed to configure log output: %v", err)
|
log.Errorf("failed to configure log output: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
|
log.Infof("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
|
||||||
@@ -402,7 +443,8 @@ func main() {
|
|||||||
util.SetLogLevel(cfg)
|
util.SetLogLevel(cfg)
|
||||||
|
|
||||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
||||||
log.Fatalf("failed to resolve auth directory: %v", errResolveAuthDir)
|
log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
|
||||||
|
return
|
||||||
} else {
|
} else {
|
||||||
cfg.AuthDir = resolvedAuthDir
|
cfg.AuthDir = resolvedAuthDir
|
||||||
}
|
}
|
||||||
@@ -453,6 +495,26 @@ func main() {
|
|||||||
cmd.DoIFlowLogin(cfg, options)
|
cmd.DoIFlowLogin(cfg, options)
|
||||||
} else if iflowCookie {
|
} else if iflowCookie {
|
||||||
cmd.DoIFlowCookieAuth(cfg, options)
|
cmd.DoIFlowCookieAuth(cfg, options)
|
||||||
|
} else if kiroLogin {
|
||||||
|
// For Kiro auth, default to incognito mode for multi-account support
|
||||||
|
// Users can explicitly override with --no-incognito
|
||||||
|
// Note: This config mutation is safe - auth commands exit after completion
|
||||||
|
// and don't share config with StartService (which is in the else branch)
|
||||||
|
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||||
|
cmd.DoKiroLogin(cfg, options)
|
||||||
|
} else if kiroGoogleLogin {
|
||||||
|
// For Kiro auth, default to incognito mode for multi-account support
|
||||||
|
// Users can explicitly override with --no-incognito
|
||||||
|
// Note: This config mutation is safe - auth commands exit after completion
|
||||||
|
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||||
|
cmd.DoKiroGoogleLogin(cfg, options)
|
||||||
|
} else if kiroAWSLogin {
|
||||||
|
// For Kiro auth, default to incognito mode for multi-account support
|
||||||
|
// Users can explicitly override with --no-incognito
|
||||||
|
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||||
|
cmd.DoKiroAWSLogin(cfg, options)
|
||||||
|
} else if kiroImport {
|
||||||
|
cmd.DoKiroImport(cfg, options)
|
||||||
} else {
|
} else {
|
||||||
// In cloud deploy mode without config file, just wait for shutdown signals
|
// In cloud deploy mode without config file, just wait for shutdown signals
|
||||||
if isCloudDeploy && !configFileExists {
|
if isCloudDeploy && !configFileExists {
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
|
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
|
||||||
|
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
|
||||||
|
host: ""
|
||||||
|
|
||||||
# Server port
|
# Server port
|
||||||
port: 8317
|
port: 8317
|
||||||
|
|
||||||
@@ -32,6 +36,11 @@ api-keys:
|
|||||||
# Enable debug logging
|
# Enable debug logging
|
||||||
debug: false
|
debug: false
|
||||||
|
|
||||||
|
# Open OAuth URLs in incognito/private browser mode.
|
||||||
|
# Useful when you want to login with a different account without logging out from your current session.
|
||||||
|
# Default: false (but Kiro auth defaults to true for multi-account support)
|
||||||
|
incognito-browser: true
|
||||||
|
|
||||||
# When true, write application logs to rotating files instead of stdout
|
# When true, write application logs to rotating files instead of stdout
|
||||||
logging-to-file: false
|
logging-to-file: false
|
||||||
|
|
||||||
@@ -99,6 +108,16 @@ ws-auth: false
|
|||||||
# - "*-think" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
|
# - "*-think" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
|
||||||
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
|
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
|
||||||
|
|
||||||
|
# Kiro (AWS CodeWhisperer) configuration
|
||||||
|
# Note: Kiro API currently only operates in us-east-1 region
|
||||||
|
#kiro:
|
||||||
|
# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file
|
||||||
|
# agent-task-type: "" # optional: "vibe" or empty (API default)
|
||||||
|
# - access-token: "aoaAAAAA..." # or provide tokens directly
|
||||||
|
# refresh-token: "aorAAAAA..."
|
||||||
|
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
|
||||||
|
# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override
|
||||||
|
|
||||||
# OpenAI compatibility providers
|
# OpenAI compatibility providers
|
||||||
# openai-compatibility:
|
# openai-compatibility:
|
||||||
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
||||||
@@ -134,6 +153,8 @@ ws-auth: false
|
|||||||
# upstream-api-key: ""
|
# upstream-api-key: ""
|
||||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended)
|
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended)
|
||||||
# restrict-management-to-localhost: true
|
# restrict-management-to-localhost: true
|
||||||
|
# # Force model mappings to run before checking local API keys (default: false)
|
||||||
|
# force-model-mappings: false
|
||||||
# # Amp Model Mappings
|
# # Amp Model Mappings
|
||||||
# # Route unavailable Amp models to alternative models available in your local proxy.
|
# # Route unavailable Amp models to alternative models available in your local proxy.
|
||||||
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
|
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
|
||||||
|
|||||||
3
go.mod
3
go.mod
@@ -13,14 +13,15 @@ require (
|
|||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/klauspost/compress v1.17.4
|
github.com/klauspost/compress v1.17.4
|
||||||
github.com/minio/minio-go/v7 v7.0.66
|
github.com/minio/minio-go/v7 v7.0.66
|
||||||
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
|
||||||
github.com/sirupsen/logrus v1.9.3
|
github.com/sirupsen/logrus v1.9.3
|
||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
github.com/tidwall/sjson v1.2.5
|
github.com/tidwall/sjson v1.2.5
|
||||||
github.com/tiktoken-go/tokenizer v0.7.0
|
github.com/tiktoken-go/tokenizer v0.7.0
|
||||||
golang.org/x/crypto v0.43.0
|
golang.org/x/crypto v0.43.0
|
||||||
golang.org/x/net v0.46.0
|
golang.org/x/net v0.46.0
|
||||||
golang.org/x/oauth2 v0.30.0
|
golang.org/x/oauth2 v0.30.0
|
||||||
|
golang.org/x/term v0.36.0
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|||||||
5
go.sum
5
go.sum
@@ -116,6 +116,8 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
|
|||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
||||||
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
||||||
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||||
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||||
@@ -126,8 +128,6 @@ github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
|
|||||||
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
|
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
|
||||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
|
|
||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
|
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
@@ -169,6 +169,7 @@ golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKl
|
|||||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
|||||||
@@ -36,9 +36,32 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
oauthStatus = make(map[string]string)
|
oauthStatus = make(map[string]string)
|
||||||
|
oauthStatusMutex sync.RWMutex
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// getOAuthStatus safely retrieves an OAuth status
|
||||||
|
func getOAuthStatus(key string) (string, bool) {
|
||||||
|
oauthStatusMutex.RLock()
|
||||||
|
defer oauthStatusMutex.RUnlock()
|
||||||
|
status, ok := oauthStatus[key]
|
||||||
|
return status, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// setOAuthStatus safely sets an OAuth status
|
||||||
|
func setOAuthStatus(key string, status string) {
|
||||||
|
oauthStatusMutex.Lock()
|
||||||
|
defer oauthStatusMutex.Unlock()
|
||||||
|
oauthStatus[key] = status
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteOAuthStatus safely deletes an OAuth status
|
||||||
|
func deleteOAuthStatus(key string) {
|
||||||
|
oauthStatusMutex.Lock()
|
||||||
|
defer oauthStatusMutex.Unlock()
|
||||||
|
delete(oauthStatus, key)
|
||||||
|
}
|
||||||
|
|
||||||
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -713,14 +736,16 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
// Generate PKCE codes
|
// Generate PKCE codes
|
||||||
pkceCodes, err := claude.GeneratePKCECodes()
|
pkceCodes, err := claude.GeneratePKCECodes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to generate PKCE codes: %v", err)
|
log.Errorf("Failed to generate PKCE codes: %v", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate random state parameter
|
// Generate random state parameter
|
||||||
state, err := misc.GenerateRandomState()
|
state, err := misc.GenerateRandomState()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to generate state parameter: %v", err)
|
log.Errorf("Failed to generate state parameter: %v", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -730,7 +755,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
// Generate authorization URL (then override redirect_uri to reuse server port)
|
// Generate authorization URL (then override redirect_uri to reuse server port)
|
||||||
authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes)
|
authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to generate authorization URL: %v", err)
|
log.Errorf("Failed to generate authorization URL: %v", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -760,7 +786,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
deadline := time.Now().Add(timeout)
|
deadline := time.Now().Add(timeout)
|
||||||
for {
|
for {
|
||||||
if time.Now().After(deadline) {
|
if time.Now().After(deadline) {
|
||||||
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
setOAuthStatus(state, "Timeout waiting for OAuth callback")
|
||||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||||
}
|
}
|
||||||
data, errRead := os.ReadFile(path)
|
data, errRead := os.ReadFile(path)
|
||||||
@@ -785,13 +811,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
if errStr := resultMap["error"]; errStr != "" {
|
if errStr := resultMap["error"]; errStr != "" {
|
||||||
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
|
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||||
log.Error(claude.GetUserFriendlyMessage(oauthErr))
|
log.Error(claude.GetUserFriendlyMessage(oauthErr))
|
||||||
oauthStatus[state] = "Bad request"
|
setOAuthStatus(state, "Bad request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if resultMap["state"] != state {
|
if resultMap["state"] != state {
|
||||||
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
|
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
|
||||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||||
oauthStatus[state] = "State code error"
|
setOAuthStatus(state, "State code error")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -824,7 +850,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
||||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||||
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
setOAuthStatus(state, "Failed to exchange authorization code for tokens")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -835,7 +861,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
respBody, _ := io.ReadAll(resp.Body)
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||||
oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)
|
setOAuthStatus(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var tResp struct {
|
var tResp struct {
|
||||||
@@ -848,7 +874,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
||||||
log.Errorf("failed to parse token response: %v", errU)
|
log.Errorf("failed to parse token response: %v", errU)
|
||||||
oauthStatus[state] = "Failed to parse token response"
|
setOAuthStatus(state, "Failed to parse token response")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
bundle := &claude.ClaudeAuthBundle{
|
bundle := &claude.ClaudeAuthBundle{
|
||||||
@@ -872,8 +898,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
if errSave != nil {
|
if errSave != nil {
|
||||||
log.Fatalf("Failed to save authentication tokens: %v", errSave)
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||||
oauthStatus[state] = "Failed to save authentication tokens"
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -882,10 +908,10 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
fmt.Println("API key obtained and saved")
|
fmt.Println("API key obtained and saved")
|
||||||
}
|
}
|
||||||
fmt.Println("You can now use Claude services through this CLI")
|
fmt.Println("You can now use Claude services through this CLI")
|
||||||
delete(oauthStatus, state)
|
deleteOAuthStatus(state)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
oauthStatus[state] = ""
|
setOAuthStatus(state, "")
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -944,7 +970,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
for {
|
for {
|
||||||
if time.Now().After(deadline) {
|
if time.Now().After(deadline) {
|
||||||
log.Error("oauth flow timed out")
|
log.Error("oauth flow timed out")
|
||||||
oauthStatus[state] = "OAuth flow timed out"
|
setOAuthStatus(state, "OAuth flow timed out")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||||
@@ -953,13 +979,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
_ = os.Remove(waitFile)
|
_ = os.Remove(waitFile)
|
||||||
if errStr := m["error"]; errStr != "" {
|
if errStr := m["error"]; errStr != "" {
|
||||||
log.Errorf("Authentication failed: %s", errStr)
|
log.Errorf("Authentication failed: %s", errStr)
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authCode = m["code"]
|
authCode = m["code"]
|
||||||
if authCode == "" {
|
if authCode == "" {
|
||||||
log.Errorf("Authentication failed: code not found")
|
log.Errorf("Authentication failed: code not found")
|
||||||
oauthStatus[state] = "Authentication failed: code not found"
|
setOAuthStatus(state, "Authentication failed: code not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
@@ -971,7 +997,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
token, err := conf.Exchange(ctx, authCode)
|
token, err := conf.Exchange(ctx, authCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to exchange token: %v", err)
|
log.Errorf("Failed to exchange token: %v", err)
|
||||||
oauthStatus[state] = "Failed to exchange token"
|
setOAuthStatus(state, "Failed to exchange token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -982,7 +1008,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||||
if errNewRequest != nil {
|
if errNewRequest != nil {
|
||||||
log.Errorf("Could not get user info: %v", errNewRequest)
|
log.Errorf("Could not get user info: %v", errNewRequest)
|
||||||
oauthStatus[state] = "Could not get user info"
|
setOAuthStatus(state, "Could not get user info")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -991,7 +1017,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
resp, errDo := authHTTPClient.Do(req)
|
resp, errDo := authHTTPClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
log.Errorf("Failed to execute request: %v", errDo)
|
log.Errorf("Failed to execute request: %v", errDo)
|
||||||
oauthStatus[state] = "Failed to execute request"
|
setOAuthStatus(state, "Failed to execute request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -1003,7 +1029,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)
|
setOAuthStatus(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1012,7 +1038,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
fmt.Printf("Authenticated user email: %s\n", email)
|
fmt.Printf("Authenticated user email: %s\n", email)
|
||||||
} else {
|
} else {
|
||||||
fmt.Println("Failed to get user email from token")
|
fmt.Println("Failed to get user email from token")
|
||||||
oauthStatus[state] = "Failed to get user email from token"
|
setOAuthStatus(state, "Failed to get user email from token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
|
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
|
||||||
@@ -1020,7 +1046,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
jsonData, _ := json.Marshal(token)
|
jsonData, _ := json.Marshal(token)
|
||||||
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
|
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
|
||||||
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
|
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
|
||||||
oauthStatus[state] = "Failed to unmarshal token"
|
setOAuthStatus(state, "Failed to unmarshal token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1045,8 +1071,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
gemAuth := geminiAuth.NewGeminiAuth()
|
gemAuth := geminiAuth.NewGeminiAuth()
|
||||||
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
|
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
|
||||||
if errGetClient != nil {
|
if errGetClient != nil {
|
||||||
log.Fatalf("failed to get authenticated client: %v", errGetClient)
|
log.Errorf("failed to get authenticated client: %v", errGetClient)
|
||||||
oauthStatus[state] = "Failed to get authenticated client"
|
setOAuthStatus(state, "Failed to get authenticated client")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fmt.Println("Authentication successful.")
|
fmt.Println("Authentication successful.")
|
||||||
@@ -1056,12 +1082,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
||||||
if errAll != nil {
|
if errAll != nil {
|
||||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
||||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
setOAuthStatus(state, "Failed to complete Gemini CLI onboarding")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
||||||
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
||||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
setOAuthStatus(state, "Failed to verify Cloud AI API status")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ts.ProjectID = strings.Join(projects, ",")
|
ts.ProjectID = strings.Join(projects, ",")
|
||||||
@@ -1069,26 +1095,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
} else {
|
} else {
|
||||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
setOAuthStatus(state, "Failed to complete Gemini CLI onboarding")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||||
log.Error("Onboarding did not return a project ID")
|
log.Error("Onboarding did not return a project ID")
|
||||||
oauthStatus[state] = "Failed to resolve project ID"
|
setOAuthStatus(state, "Failed to resolve project ID")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||||
if errCheck != nil {
|
if errCheck != nil {
|
||||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
setOAuthStatus(state, "Failed to verify Cloud AI API status")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ts.Checked = isChecked
|
ts.Checked = isChecked
|
||||||
if !isChecked {
|
if !isChecked {
|
||||||
log.Error("Cloud AI API is not enabled for the selected project")
|
log.Error("Cloud AI API is not enabled for the selected project")
|
||||||
oauthStatus[state] = "Cloud AI API not enabled"
|
setOAuthStatus(state, "Cloud AI API not enabled")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1110,16 +1136,16 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
if errSave != nil {
|
if errSave != nil {
|
||||||
log.Fatalf("Failed to save token to file: %v", errSave)
|
log.Errorf("Failed to save token to file: %v", errSave)
|
||||||
oauthStatus[state] = "Failed to save token to file"
|
setOAuthStatus(state, "Failed to save token to file")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(oauthStatus, state)
|
deleteOAuthStatus(state)
|
||||||
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
|
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
oauthStatus[state] = ""
|
setOAuthStatus(state, "")
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1131,14 +1157,16 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
// Generate PKCE codes
|
// Generate PKCE codes
|
||||||
pkceCodes, err := codex.GeneratePKCECodes()
|
pkceCodes, err := codex.GeneratePKCECodes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to generate PKCE codes: %v", err)
|
log.Errorf("Failed to generate PKCE codes: %v", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate random state parameter
|
// Generate random state parameter
|
||||||
state, err := misc.GenerateRandomState()
|
state, err := misc.GenerateRandomState()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to generate state parameter: %v", err)
|
log.Errorf("Failed to generate state parameter: %v", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1148,7 +1176,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
// Generate authorization URL
|
// Generate authorization URL
|
||||||
authURL, err := openaiAuth.GenerateAuthURL(state, pkceCodes)
|
authURL, err := openaiAuth.GenerateAuthURL(state, pkceCodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to generate authorization URL: %v", err)
|
log.Errorf("Failed to generate authorization URL: %v", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1180,7 +1209,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
if time.Now().After(deadline) {
|
if time.Now().After(deadline) {
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
|
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
|
||||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||||
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
setOAuthStatus(state, "Timeout waiting for OAuth callback")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||||
@@ -1190,12 +1219,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
if errStr := m["error"]; errStr != "" {
|
if errStr := m["error"]; errStr != "" {
|
||||||
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
|
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||||
log.Error(codex.GetUserFriendlyMessage(oauthErr))
|
log.Error(codex.GetUserFriendlyMessage(oauthErr))
|
||||||
oauthStatus[state] = "Bad Request"
|
setOAuthStatus(state, "Bad Request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if m["state"] != state {
|
if m["state"] != state {
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
|
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
|
||||||
oauthStatus[state] = "State code error"
|
setOAuthStatus(state, "State code error")
|
||||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1226,14 +1255,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
resp, errDo := httpClient.Do(req)
|
resp, errDo := httpClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
||||||
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
setOAuthStatus(state, "Failed to exchange authorization code for tokens")
|
||||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)
|
setOAuthStatus(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
|
||||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1244,7 +1273,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
ExpiresIn int `json:"expires_in"`
|
ExpiresIn int `json:"expires_in"`
|
||||||
}
|
}
|
||||||
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
||||||
oauthStatus[state] = "Failed to parse token response"
|
setOAuthStatus(state, "Failed to parse token response")
|
||||||
log.Errorf("failed to parse token response: %v", errU)
|
log.Errorf("failed to parse token response: %v", errU)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1282,8 +1311,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
if errSave != nil {
|
if errSave != nil {
|
||||||
oauthStatus[state] = "Failed to save authentication tokens"
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||||
log.Fatalf("Failed to save authentication tokens: %v", errSave)
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
@@ -1291,10 +1320,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
fmt.Println("API key obtained and saved")
|
fmt.Println("API key obtained and saved")
|
||||||
}
|
}
|
||||||
fmt.Println("You can now use Codex services through this CLI")
|
fmt.Println("You can now use Codex services through this CLI")
|
||||||
delete(oauthStatus, state)
|
deleteOAuthStatus(state)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
oauthStatus[state] = ""
|
setOAuthStatus(state, "")
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1318,7 +1347,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
|
|
||||||
state, errState := misc.GenerateRandomState()
|
state, errState := misc.GenerateRandomState()
|
||||||
if errState != nil {
|
if errState != nil {
|
||||||
log.Fatalf("Failed to generate state parameter: %v", errState)
|
log.Errorf("Failed to generate state parameter: %v", errState)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1360,7 +1390,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
for {
|
for {
|
||||||
if time.Now().After(deadline) {
|
if time.Now().After(deadline) {
|
||||||
log.Error("oauth flow timed out")
|
log.Error("oauth flow timed out")
|
||||||
oauthStatus[state] = "OAuth flow timed out"
|
setOAuthStatus(state, "OAuth flow timed out")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
|
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
|
||||||
@@ -1369,18 +1399,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
_ = os.Remove(waitFile)
|
_ = os.Remove(waitFile)
|
||||||
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
||||||
log.Errorf("Authentication failed: %s", errStr)
|
log.Errorf("Authentication failed: %s", errStr)
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
|
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
|
||||||
log.Errorf("Authentication failed: state mismatch")
|
log.Errorf("Authentication failed: state mismatch")
|
||||||
oauthStatus[state] = "Authentication failed: state mismatch"
|
setOAuthStatus(state, "Authentication failed: state mismatch")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authCode = strings.TrimSpace(payload["code"])
|
authCode = strings.TrimSpace(payload["code"])
|
||||||
if authCode == "" {
|
if authCode == "" {
|
||||||
log.Error("Authentication failed: code not found")
|
log.Error("Authentication failed: code not found")
|
||||||
oauthStatus[state] = "Authentication failed: code not found"
|
setOAuthStatus(state, "Authentication failed: code not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
@@ -1399,7 +1429,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
||||||
if errNewRequest != nil {
|
if errNewRequest != nil {
|
||||||
log.Errorf("Failed to build token request: %v", errNewRequest)
|
log.Errorf("Failed to build token request: %v", errNewRequest)
|
||||||
oauthStatus[state] = "Failed to build token request"
|
setOAuthStatus(state, "Failed to build token request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -1407,7 +1437,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
resp, errDo := httpClient.Do(req)
|
resp, errDo := httpClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
log.Errorf("Failed to execute token request: %v", errDo)
|
log.Errorf("Failed to execute token request: %v", errDo)
|
||||||
oauthStatus[state] = "Failed to exchange token"
|
setOAuthStatus(state, "Failed to exchange token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -1419,7 +1449,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)
|
setOAuthStatus(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1431,7 +1461,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
||||||
log.Errorf("Failed to parse token response: %v", errDecode)
|
log.Errorf("Failed to parse token response: %v", errDecode)
|
||||||
oauthStatus[state] = "Failed to parse token response"
|
setOAuthStatus(state, "Failed to parse token response")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1440,7 +1470,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||||
if errInfoReq != nil {
|
if errInfoReq != nil {
|
||||||
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
||||||
oauthStatus[state] = "Failed to build user info request"
|
setOAuthStatus(state, "Failed to build user info request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
||||||
@@ -1448,7 +1478,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
infoResp, errInfo := httpClient.Do(infoReq)
|
infoResp, errInfo := httpClient.Do(infoReq)
|
||||||
if errInfo != nil {
|
if errInfo != nil {
|
||||||
log.Errorf("Failed to execute user info request: %v", errInfo)
|
log.Errorf("Failed to execute user info request: %v", errInfo)
|
||||||
oauthStatus[state] = "Failed to execute user info request"
|
setOAuthStatus(state, "Failed to execute user info request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -1467,11 +1497,22 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
} else {
|
} else {
|
||||||
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
||||||
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
||||||
oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)
|
setOAuthStatus(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
projectID := ""
|
||||||
|
if strings.TrimSpace(tokenResp.AccessToken) != "" {
|
||||||
|
fetchedProjectID, errProject := sdkAuth.FetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
|
||||||
|
if errProject != nil {
|
||||||
|
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
||||||
|
} else {
|
||||||
|
projectID = fetchedProjectID
|
||||||
|
log.Infof("antigravity: obtained project ID %s", projectID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
metadata := map[string]any{
|
metadata := map[string]any{
|
||||||
"type": "antigravity",
|
"type": "antigravity",
|
||||||
@@ -1484,6 +1525,9 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
if email != "" {
|
if email != "" {
|
||||||
metadata["email"] = email
|
metadata["email"] = email
|
||||||
}
|
}
|
||||||
|
if projectID != "" {
|
||||||
|
metadata["project_id"] = projectID
|
||||||
|
}
|
||||||
|
|
||||||
fileName := sanitizeAntigravityFileName(email)
|
fileName := sanitizeAntigravityFileName(email)
|
||||||
label := strings.TrimSpace(email)
|
label := strings.TrimSpace(email)
|
||||||
@@ -1500,17 +1544,20 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
if errSave != nil {
|
if errSave != nil {
|
||||||
log.Fatalf("Failed to save token to file: %v", errSave)
|
log.Errorf("Failed to save token to file: %v", errSave)
|
||||||
oauthStatus[state] = "Failed to save token to file"
|
setOAuthStatus(state, "Failed to save token to file")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(oauthStatus, state)
|
deleteOAuthStatus(state)
|
||||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
|
if projectID != "" {
|
||||||
|
fmt.Printf("Using GCP project: %s\n", projectID)
|
||||||
|
}
|
||||||
fmt.Println("You can now use Antigravity services through this CLI")
|
fmt.Println("You can now use Antigravity services through this CLI")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
oauthStatus[state] = ""
|
setOAuthStatus(state, "")
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1526,7 +1573,8 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
// Generate authorization URL
|
// Generate authorization URL
|
||||||
deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx)
|
deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to generate authorization URL: %v", err)
|
log.Errorf("Failed to generate authorization URL: %v", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authURL := deviceFlow.VerificationURIComplete
|
authURL := deviceFlow.VerificationURIComplete
|
||||||
@@ -1535,7 +1583,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
fmt.Println("Waiting for authentication...")
|
fmt.Println("Waiting for authentication...")
|
||||||
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
||||||
if errPollForToken != nil {
|
if errPollForToken != nil {
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
fmt.Printf("Authentication failed: %v\n", errPollForToken)
|
fmt.Printf("Authentication failed: %v\n", errPollForToken)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1553,17 +1601,17 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
if errSave != nil {
|
if errSave != nil {
|
||||||
log.Fatalf("Failed to save authentication tokens: %v", errSave)
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||||
oauthStatus[state] = "Failed to save authentication tokens"
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
fmt.Println("You can now use Qwen services through this CLI")
|
fmt.Println("You can now use Qwen services through this CLI")
|
||||||
delete(oauthStatus, state)
|
deleteOAuthStatus(state)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
oauthStatus[state] = ""
|
setOAuthStatus(state, "")
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1602,7 +1650,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|||||||
var resultMap map[string]string
|
var resultMap map[string]string
|
||||||
for {
|
for {
|
||||||
if time.Now().After(deadline) {
|
if time.Now().After(deadline) {
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
fmt.Println("Authentication failed: timeout waiting for callback")
|
fmt.Println("Authentication failed: timeout waiting for callback")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1615,26 +1663,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
|
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
fmt.Printf("Authentication failed: %s\n", errStr)
|
fmt.Printf("Authentication failed: %s\n", errStr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
|
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
fmt.Println("Authentication failed: state mismatch")
|
fmt.Println("Authentication failed: state mismatch")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
code := strings.TrimSpace(resultMap["code"])
|
code := strings.TrimSpace(resultMap["code"])
|
||||||
if code == "" {
|
if code == "" {
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
fmt.Println("Authentication failed: code missing")
|
fmt.Println("Authentication failed: code missing")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
|
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
|
||||||
if errExchange != nil {
|
if errExchange != nil {
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
fmt.Printf("Authentication failed: %v\n", errExchange)
|
fmt.Printf("Authentication failed: %v\n", errExchange)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1656,8 +1704,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|||||||
|
|
||||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
if errSave != nil {
|
if errSave != nil {
|
||||||
oauthStatus[state] = "Failed to save authentication tokens"
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||||
log.Fatalf("Failed to save authentication tokens: %v", errSave)
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1666,10 +1714,10 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|||||||
fmt.Println("API key obtained and saved")
|
fmt.Println("API key obtained and saved")
|
||||||
}
|
}
|
||||||
fmt.Println("You can now use iFlow services through this CLI")
|
fmt.Println("You can now use iFlow services through this CLI")
|
||||||
delete(oauthStatus, state)
|
deleteOAuthStatus(state)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
oauthStatus[state] = ""
|
setOAuthStatus(state, "")
|
||||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2086,6 +2134,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
return false, fmt.Errorf("project activation required: %s", errMessage)
|
return false, fmt.Errorf("project activation required: %s", errMessage)
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
@@ -2093,7 +2142,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
|||||||
|
|
||||||
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||||
state := c.Query("state")
|
state := c.Query("state")
|
||||||
if err, ok := oauthStatus[state]; ok {
|
if err, ok := getOAuthStatus(state); ok {
|
||||||
if err != "" {
|
if err != "" {
|
||||||
c.JSON(200, gin.H{"status": "error", "error": err})
|
c.JSON(200, gin.H{"status": "error", "error": err})
|
||||||
} else {
|
} else {
|
||||||
@@ -2103,5 +2152,5 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
|||||||
} else {
|
} else {
|
||||||
c.JSON(200, gin.H{"status": "ok"})
|
c.JSON(200, gin.H{"status": "ok"})
|
||||||
}
|
}
|
||||||
delete(oauthStatus, state)
|
deleteOAuthStatus(state)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,28 @@
|
|||||||
package management
|
package management
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPIPlus/releases/latest"
|
||||||
|
latestReleaseUserAgent = "CLIProxyAPIPlus"
|
||||||
|
)
|
||||||
|
|
||||||
func (h *Handler) GetConfig(c *gin.Context) {
|
func (h *Handler) GetConfig(c *gin.Context) {
|
||||||
if h == nil || h.cfg == nil {
|
if h == nil || h.cfg == nil {
|
||||||
c.JSON(200, gin.H{})
|
c.JSON(200, gin.H{})
|
||||||
@@ -20,6 +32,66 @@ func (h *Handler) GetConfig(c *gin.Context) {
|
|||||||
c.JSON(200, &cfgCopy)
|
c.JSON(200, &cfgCopy)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type releaseInfo struct {
|
||||||
|
TagName string `json:"tag_name"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLatestVersion returns the latest release version from GitHub without downloading assets.
|
||||||
|
func (h *Handler) GetLatestVersion(c *gin.Context) {
|
||||||
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
proxyURL := ""
|
||||||
|
if h != nil && h.cfg != nil {
|
||||||
|
proxyURL = strings.TrimSpace(h.cfg.ProxyURL)
|
||||||
|
}
|
||||||
|
if proxyURL != "" {
|
||||||
|
sdkCfg := &sdkconfig.SDKConfig{ProxyURL: proxyURL}
|
||||||
|
util.SetProxy(sdkCfg, client)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, latestReleaseURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "request_create_failed", "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Header.Set("Accept", "application/vnd.github+json")
|
||||||
|
req.Header.Set("User-Agent", latestReleaseUserAgent)
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": "request_failed", "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.WithError(errClose).Debug("failed to close latest version response body")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": "unexpected_status", "message": fmt.Sprintf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var info releaseInfo
|
||||||
|
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": "decode_failed", "message": errDecode.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
version := strings.TrimSpace(info.TagName)
|
||||||
|
if version == "" {
|
||||||
|
version = strings.TrimSpace(info.Name)
|
||||||
|
}
|
||||||
|
if version == "" {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": "invalid_response", "message": "missing release version"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"latest-version": version})
|
||||||
|
}
|
||||||
|
|
||||||
func WriteConfig(path string, data []byte) error {
|
func WriteConfig(path string, data []byte) error {
|
||||||
data = config.NormalizeCommentIndentation(data)
|
data = config.NormalizeCommentIndentation(data)
|
||||||
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||||
|
|||||||
@@ -706,3 +706,155 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
|
|||||||
}
|
}
|
||||||
entry.Models = normalized
|
entry.Models = normalized
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAmpCode returns the complete ampcode configuration.
|
||||||
|
func (h *Handler) GetAmpCode(c *gin.Context) {
|
||||||
|
if h == nil || h.cfg == nil {
|
||||||
|
c.JSON(200, gin.H{"ampcode": config.AmpCode{}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{"ampcode": h.cfg.AmpCode})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAmpUpstreamURL returns the ampcode upstream URL.
|
||||||
|
func (h *Handler) GetAmpUpstreamURL(c *gin.Context) {
|
||||||
|
if h == nil || h.cfg == nil {
|
||||||
|
c.JSON(200, gin.H{"upstream-url": ""})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{"upstream-url": h.cfg.AmpCode.UpstreamURL})
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutAmpUpstreamURL updates the ampcode upstream URL.
|
||||||
|
func (h *Handler) PutAmpUpstreamURL(c *gin.Context) {
|
||||||
|
h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamURL = strings.TrimSpace(v) })
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAmpUpstreamURL clears the ampcode upstream URL.
|
||||||
|
func (h *Handler) DeleteAmpUpstreamURL(c *gin.Context) {
|
||||||
|
h.cfg.AmpCode.UpstreamURL = ""
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAmpUpstreamAPIKey returns the ampcode upstream API key.
|
||||||
|
func (h *Handler) GetAmpUpstreamAPIKey(c *gin.Context) {
|
||||||
|
if h == nil || h.cfg == nil {
|
||||||
|
c.JSON(200, gin.H{"upstream-api-key": ""})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{"upstream-api-key": h.cfg.AmpCode.UpstreamAPIKey})
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutAmpUpstreamAPIKey updates the ampcode upstream API key.
|
||||||
|
func (h *Handler) PutAmpUpstreamAPIKey(c *gin.Context) {
|
||||||
|
h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamAPIKey = strings.TrimSpace(v) })
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAmpUpstreamAPIKey clears the ampcode upstream API key.
|
||||||
|
func (h *Handler) DeleteAmpUpstreamAPIKey(c *gin.Context) {
|
||||||
|
h.cfg.AmpCode.UpstreamAPIKey = ""
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAmpRestrictManagementToLocalhost returns the localhost restriction setting.
|
||||||
|
func (h *Handler) GetAmpRestrictManagementToLocalhost(c *gin.Context) {
|
||||||
|
if h == nil || h.cfg == nil {
|
||||||
|
c.JSON(200, gin.H{"restrict-management-to-localhost": true})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{"restrict-management-to-localhost": h.cfg.AmpCode.RestrictManagementToLocalhost})
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutAmpRestrictManagementToLocalhost updates the localhost restriction setting.
|
||||||
|
func (h *Handler) PutAmpRestrictManagementToLocalhost(c *gin.Context) {
|
||||||
|
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.RestrictManagementToLocalhost = v })
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAmpModelMappings returns the ampcode model mappings.
|
||||||
|
func (h *Handler) GetAmpModelMappings(c *gin.Context) {
|
||||||
|
if h == nil || h.cfg == nil {
|
||||||
|
c.JSON(200, gin.H{"model-mappings": []config.AmpModelMapping{}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{"model-mappings": h.cfg.AmpCode.ModelMappings})
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutAmpModelMappings replaces all ampcode model mappings.
|
||||||
|
func (h *Handler) PutAmpModelMappings(c *gin.Context) {
|
||||||
|
var body struct {
|
||||||
|
Value []config.AmpModelMapping `json:"value"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&body); err != nil {
|
||||||
|
c.JSON(400, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.cfg.AmpCode.ModelMappings = body.Value
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PatchAmpModelMappings adds or updates model mappings.
|
||||||
|
func (h *Handler) PatchAmpModelMappings(c *gin.Context) {
|
||||||
|
var body struct {
|
||||||
|
Value []config.AmpModelMapping `json:"value"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&body); err != nil {
|
||||||
|
c.JSON(400, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
existing := make(map[string]int)
|
||||||
|
for i, m := range h.cfg.AmpCode.ModelMappings {
|
||||||
|
existing[strings.TrimSpace(m.From)] = i
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, newMapping := range body.Value {
|
||||||
|
from := strings.TrimSpace(newMapping.From)
|
||||||
|
if idx, ok := existing[from]; ok {
|
||||||
|
h.cfg.AmpCode.ModelMappings[idx] = newMapping
|
||||||
|
} else {
|
||||||
|
h.cfg.AmpCode.ModelMappings = append(h.cfg.AmpCode.ModelMappings, newMapping)
|
||||||
|
existing[from] = len(h.cfg.AmpCode.ModelMappings) - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAmpModelMappings removes specified model mappings by "from" field.
|
||||||
|
func (h *Handler) DeleteAmpModelMappings(c *gin.Context) {
|
||||||
|
var body struct {
|
||||||
|
Value []string `json:"value"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&body); err != nil || len(body.Value) == 0 {
|
||||||
|
h.cfg.AmpCode.ModelMappings = nil
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
toRemove := make(map[string]bool)
|
||||||
|
for _, from := range body.Value {
|
||||||
|
toRemove[strings.TrimSpace(from)] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
newMappings := make([]config.AmpModelMapping, 0, len(h.cfg.AmpCode.ModelMappings))
|
||||||
|
for _, m := range h.cfg.AmpCode.ModelMappings {
|
||||||
|
if !toRemove[strings.TrimSpace(m.From)] {
|
||||||
|
newMappings = append(newMappings, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.cfg.AmpCode.ModelMappings = newMappings
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAmpForceModelMappings returns whether model mappings are forced.
|
||||||
|
func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
|
||||||
|
if h == nil || h.cfg == nil {
|
||||||
|
c.JSON(200, gin.H{"force-model-mappings": false})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings})
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutAmpForceModelMappings updates the force model mappings setting.
|
||||||
|
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
|
||||||
|
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
|
||||||
|
}
|
||||||
|
|||||||
@@ -240,16 +240,6 @@ func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) {
|
|||||||
Value *bool `json:"value"`
|
Value *bool `json:"value"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||||
var m map[string]any
|
|
||||||
if err2 := c.ShouldBindJSON(&m); err2 == nil {
|
|
||||||
for _, v := range m {
|
|
||||||
if b, ok := v.(bool); ok {
|
|
||||||
set(b)
|
|
||||||
h.persist(c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -112,5 +112,10 @@ func shouldLogRequest(path string) bool {
|
|||||||
if strings.HasPrefix(path, "/v0/management") || strings.HasPrefix(path, "/management") {
|
if strings.HasPrefix(path, "/v0/management") || strings.HasPrefix(path, "/management") {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(path, "/api") {
|
||||||
|
return strings.HasPrefix(path, "/api/provider")
|
||||||
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -232,7 +232,16 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
|||||||
w.streamDone = nil
|
w.streamDone = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write API Request and Response to the streaming log before closing
|
||||||
if w.streamWriter != nil {
|
if w.streamWriter != nil {
|
||||||
|
apiRequest := w.extractAPIRequest(c)
|
||||||
|
if len(apiRequest) > 0 {
|
||||||
|
_ = w.streamWriter.WriteAPIRequest(apiRequest)
|
||||||
|
}
|
||||||
|
apiResponse := w.extractAPIResponse(c)
|
||||||
|
if len(apiResponse) > 0 {
|
||||||
|
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||||
|
}
|
||||||
if err := w.streamWriter.Close(); err != nil {
|
if err := w.streamWriter.Close(); err != nil {
|
||||||
w.streamWriter = nil
|
w.streamWriter = nil
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -100,6 +100,16 @@ func (m *AmpModule) Name() string {
|
|||||||
return "amp-routing"
|
return "amp-routing"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// forceModelMappings returns whether model mappings should take precedence over local API keys
|
||||||
|
func (m *AmpModule) forceModelMappings() bool {
|
||||||
|
m.configMu.RLock()
|
||||||
|
defer m.configMu.RUnlock()
|
||||||
|
if m.lastConfig == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return m.lastConfig.ForceModelMappings
|
||||||
|
}
|
||||||
|
|
||||||
// Register sets up Amp routes if configured.
|
// Register sets up Amp routes if configured.
|
||||||
// This implements the RouteModuleV2 interface with Context.
|
// This implements the RouteModuleV2 interface with Context.
|
||||||
// Routes are registered only once via sync.Once for idempotent behavior.
|
// Routes are registered only once via sync.Once for idempotent behavior.
|
||||||
@@ -126,6 +136,9 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
|||||||
// Always register provider aliases - these work without an upstream
|
// Always register provider aliases - these work without an upstream
|
||||||
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
||||||
|
|
||||||
|
// Register management proxy routes once; middleware will gate access when upstream is unavailable.
|
||||||
|
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler)
|
||||||
|
|
||||||
// If no upstream URL, skip proxy routes but provider aliases are still available
|
// If no upstream URL, skip proxy routes but provider aliases are still available
|
||||||
if upstreamURL == "" {
|
if upstreamURL == "" {
|
||||||
log.Debug("amp upstream proxy disabled (no upstream URL configured)")
|
log.Debug("amp upstream proxy disabled (no upstream URL configured)")
|
||||||
@@ -134,27 +147,11 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create secret source with precedence: config > env > file
|
if err := m.enableUpstreamProxy(upstreamURL, &settings); err != nil {
|
||||||
// Cache secrets for 5 minutes to reduce file I/O
|
|
||||||
if m.secretSource == nil {
|
|
||||||
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create reverse proxy with gzip handling via ModifyResponse
|
|
||||||
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
|
||||||
if err != nil {
|
|
||||||
regErr = fmt.Errorf("failed to create amp proxy: %w", err)
|
regErr = fmt.Errorf("failed to create amp proxy: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m.setProxy(proxy)
|
|
||||||
m.enabled = true
|
|
||||||
|
|
||||||
// Register management proxy routes (requires upstream)
|
|
||||||
// Uses dynamic middleware that checks m.IsRestrictedToLocalhost() for hot-reload support
|
|
||||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler)
|
|
||||||
|
|
||||||
log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
|
|
||||||
log.Debug("amp provider alias routes registered")
|
log.Debug("amp provider alias routes registered")
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -188,18 +185,30 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
|||||||
oldSettings := m.lastConfig
|
oldSettings := m.lastConfig
|
||||||
m.configMu.RUnlock()
|
m.configMu.RUnlock()
|
||||||
|
|
||||||
// Track what changed for logging
|
if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
|
||||||
var changes []string
|
m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
|
||||||
|
if !newSettings.RestrictManagementToLocalhost {
|
||||||
|
log.Warnf("amp management routes now accessible from any IP - this is insecure!")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
|
||||||
|
oldUpstreamURL := ""
|
||||||
|
if oldSettings != nil {
|
||||||
|
oldUpstreamURL = strings.TrimSpace(oldSettings.UpstreamURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.enabled && newUpstreamURL != "" {
|
||||||
|
if err := m.enableUpstreamProxy(newUpstreamURL, &newSettings); err != nil {
|
||||||
|
log.Errorf("amp config: failed to enable upstream proxy for %s: %v", newUpstreamURL, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check model mappings change
|
// Check model mappings change
|
||||||
modelMappingsChanged := m.hasModelMappingsChanged(oldSettings, &newSettings)
|
modelMappingsChanged := m.hasModelMappingsChanged(oldSettings, &newSettings)
|
||||||
if modelMappingsChanged {
|
if modelMappingsChanged {
|
||||||
if m.modelMapper != nil {
|
if m.modelMapper != nil {
|
||||||
m.modelMapper.UpdateMappings(newSettings.ModelMappings)
|
m.modelMapper.UpdateMappings(newSettings.ModelMappings)
|
||||||
changes = append(changes, "model-mappings")
|
|
||||||
if m.enabled {
|
|
||||||
log.Infof("amp config partial reload: model mappings updated (%d entries)", len(newSettings.ModelMappings))
|
|
||||||
}
|
|
||||||
} else if m.enabled {
|
} else if m.enabled {
|
||||||
log.Warnf("amp model mapper not initialized, skipping model mapping update")
|
log.Warnf("amp model mapper not initialized, skipping model mapping update")
|
||||||
}
|
}
|
||||||
@@ -207,25 +216,16 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
|||||||
|
|
||||||
if m.enabled {
|
if m.enabled {
|
||||||
// Check upstream URL change - now supports hot-reload
|
// Check upstream URL change - now supports hot-reload
|
||||||
newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
|
|
||||||
oldUpstreamURL := ""
|
|
||||||
if oldSettings != nil {
|
|
||||||
oldUpstreamURL = strings.TrimSpace(oldSettings.UpstreamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
if newUpstreamURL == "" && oldUpstreamURL != "" {
|
if newUpstreamURL == "" && oldUpstreamURL != "" {
|
||||||
log.Warn("amp upstream URL removed from config, proxy has been disabled")
|
|
||||||
m.setProxy(nil)
|
m.setProxy(nil)
|
||||||
changes = append(changes, "upstream-url(disabled)")
|
m.enabled = false
|
||||||
} else if newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" {
|
} else if oldUpstreamURL != "" && newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" {
|
||||||
// Recreate proxy with new URL
|
// Recreate proxy with new URL
|
||||||
proxy, err := createReverseProxy(newUpstreamURL, m.secretSource)
|
proxy, err := createReverseProxy(newUpstreamURL, m.secretSource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err)
|
log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err)
|
||||||
} else {
|
} else {
|
||||||
m.setProxy(proxy)
|
m.setProxy(proxy)
|
||||||
changes = append(changes, "upstream-url")
|
|
||||||
log.Infof("amp config partial reload: upstream URL updated (%s -> %s)", oldUpstreamURL, newUpstreamURL)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,22 +236,10 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
|||||||
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||||
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
|
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
|
||||||
ms.InvalidateCache()
|
ms.InvalidateCache()
|
||||||
changes = append(changes, "upstream-api-key")
|
|
||||||
log.Debug("amp config partial reload: secret cache invalidated")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check restrict-management-to-localhost change - now supports hot-reload
|
|
||||||
if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
|
|
||||||
m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
|
|
||||||
changes = append(changes, "restrict-management-to-localhost")
|
|
||||||
if newSettings.RestrictManagementToLocalhost {
|
|
||||||
log.Infof("amp config partial reload: management routes now restricted to localhost")
|
|
||||||
} else {
|
|
||||||
log.Warnf("amp config partial reload: management routes now accessible from any IP - this is insecure!")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store current config for next comparison
|
// Store current config for next comparison
|
||||||
@@ -260,13 +248,26 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
|||||||
m.lastConfig = &settingsCopy
|
m.lastConfig = &settingsCopy
|
||||||
m.configMu.Unlock()
|
m.configMu.Unlock()
|
||||||
|
|
||||||
// Log summary if any changes detected
|
return nil
|
||||||
if len(changes) > 0 {
|
}
|
||||||
log.Debugf("amp config partial reload completed: %v", changes)
|
|
||||||
} else {
|
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
|
||||||
log.Debug("amp config checked: no changes detected")
|
if m.secretSource == nil {
|
||||||
|
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
||||||
|
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||||
|
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
|
||||||
|
ms.InvalidateCache()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.setProxy(proxy)
|
||||||
|
m.enabled = true
|
||||||
|
|
||||||
|
log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package amp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -11,6 +10,8 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AmpRouteType represents the type of routing decision made for an Amp request
|
// AmpRouteType represents the type of routing decision made for an Amp request
|
||||||
@@ -27,6 +28,9 @@ const (
|
|||||||
RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
|
RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MappedModelContextKey is the Gin context key for passing mapped model names.
|
||||||
|
const MappedModelContextKey = "mapped_model"
|
||||||
|
|
||||||
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
||||||
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
||||||
fields := log.Fields{
|
fields := log.Fields{
|
||||||
@@ -48,48 +52,54 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
|
|||||||
case RouteTypeLocalProvider:
|
case RouteTypeLocalProvider:
|
||||||
fields["cost"] = "free"
|
fields["cost"] = "free"
|
||||||
fields["source"] = "local_oauth"
|
fields["source"] = "local_oauth"
|
||||||
log.WithFields(fields).Infof("[amp] using local provider for model: %s", requestedModel)
|
log.WithFields(fields).Debugf("amp using local provider for model: %s", requestedModel)
|
||||||
|
|
||||||
case RouteTypeModelMapping:
|
case RouteTypeModelMapping:
|
||||||
fields["cost"] = "free"
|
fields["cost"] = "free"
|
||||||
fields["source"] = "local_oauth"
|
fields["source"] = "local_oauth"
|
||||||
fields["mapping"] = requestedModel + " -> " + resolvedModel
|
fields["mapping"] = requestedModel + " -> " + resolvedModel
|
||||||
log.WithFields(fields).Infof("[amp] model mapped: %s -> %s", requestedModel, resolvedModel)
|
// model mapping already logged in mapper; avoid duplicate here
|
||||||
|
|
||||||
case RouteTypeAmpCredits:
|
case RouteTypeAmpCredits:
|
||||||
fields["cost"] = "amp_credits"
|
fields["cost"] = "amp_credits"
|
||||||
fields["source"] = "ampcode.com"
|
fields["source"] = "ampcode.com"
|
||||||
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||||
log.WithFields(fields).Warnf("[amp] forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||||
|
|
||||||
case RouteTypeNoProvider:
|
case RouteTypeNoProvider:
|
||||||
fields["cost"] = "none"
|
fields["cost"] = "none"
|
||||||
fields["source"] = "error"
|
fields["source"] = "error"
|
||||||
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||||
log.WithFields(fields).Warnf("[amp] no provider available for model_id: %s", requestedModel)
|
log.WithFields(fields).Warnf("no provider available for model_id: %s", requestedModel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
||||||
// when the model's provider is not available in CLIProxyAPI
|
// when the model's provider is not available in CLIProxyAPI
|
||||||
type FallbackHandler struct {
|
type FallbackHandler struct {
|
||||||
getProxy func() *httputil.ReverseProxy
|
getProxy func() *httputil.ReverseProxy
|
||||||
modelMapper ModelMapper
|
modelMapper ModelMapper
|
||||||
|
forceModelMappings func() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFallbackHandler creates a new fallback handler wrapper
|
// NewFallbackHandler creates a new fallback handler wrapper
|
||||||
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
|
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
|
||||||
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
|
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
|
||||||
return &FallbackHandler{
|
return &FallbackHandler{
|
||||||
getProxy: getProxy,
|
getProxy: getProxy,
|
||||||
|
forceModelMappings: func() bool { return false },
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
||||||
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper) *FallbackHandler {
|
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
|
||||||
|
if forceModelMappings == nil {
|
||||||
|
forceModelMappings = func() bool { return false }
|
||||||
|
}
|
||||||
return &FallbackHandler{
|
return &FallbackHandler{
|
||||||
getProxy: getProxy,
|
getProxy: getProxy,
|
||||||
modelMapper: mapper,
|
modelMapper: mapper,
|
||||||
|
forceModelMappings: forceModelMappings,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,32 +136,65 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
// Normalize model (handles Gemini thinking suffixes)
|
// Normalize model (handles Gemini thinking suffixes)
|
||||||
normalizedModel, _ := util.NormalizeGeminiThinkingModel(modelName)
|
normalizedModel, _ := util.NormalizeGeminiThinkingModel(modelName)
|
||||||
|
|
||||||
// Check if we have providers for this model
|
|
||||||
providers := util.GetProviderName(normalizedModel)
|
|
||||||
|
|
||||||
// Track resolved model for logging (may change if mapping is applied)
|
// Track resolved model for logging (may change if mapping is applied)
|
||||||
resolvedModel := normalizedModel
|
resolvedModel := normalizedModel
|
||||||
usedMapping := false
|
usedMapping := false
|
||||||
|
var providers []string
|
||||||
|
|
||||||
if len(providers) == 0 {
|
// Check if model mappings should be forced ahead of local API keys
|
||||||
// No providers configured - check if we have a model mapping
|
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
|
||||||
|
|
||||||
|
if forceMappings {
|
||||||
|
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||||
|
// This allows users to route Amp requests to their preferred OAuth providers
|
||||||
if fh.modelMapper != nil {
|
if fh.modelMapper != nil {
|
||||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||||
// Mapping found - rewrite the model in request body
|
// Mapping found - check if we have a provider for the mapped model
|
||||||
bodyBytes = rewriteModelInBody(bodyBytes, mappedModel)
|
mappedProviders := util.GetProviderName(mappedModel)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
if len(mappedProviders) > 0 {
|
||||||
resolvedModel = mappedModel
|
// Mapping found and provider available - rewrite the model in request body
|
||||||
usedMapping = true
|
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
// Get providers for the mapped model
|
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||||
providers = util.GetProviderName(mappedModel)
|
c.Set(MappedModelContextKey, mappedModel)
|
||||||
|
resolvedModel = mappedModel
|
||||||
// Continue to handler with remapped model
|
usedMapping = true
|
||||||
goto handleRequest
|
providers = mappedProviders
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// No mapping found - check if we have a proxy for fallback
|
// If no mapping applied, check for local providers
|
||||||
|
if !usedMapping {
|
||||||
|
providers = util.GetProviderName(normalizedModel)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// DEFAULT MODE: Check local providers first, then mappings as fallback
|
||||||
|
providers = util.GetProviderName(normalizedModel)
|
||||||
|
|
||||||
|
if len(providers) == 0 {
|
||||||
|
// No providers configured - check if we have a model mapping
|
||||||
|
if fh.modelMapper != nil {
|
||||||
|
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||||
|
// Mapping found - check if we have a provider for the mapped model
|
||||||
|
mappedProviders := util.GetProviderName(mappedModel)
|
||||||
|
if len(mappedProviders) > 0 {
|
||||||
|
// Mapping found and provider available - rewrite the model in request body
|
||||||
|
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||||
|
c.Set(MappedModelContextKey, mappedModel)
|
||||||
|
resolvedModel = mappedModel
|
||||||
|
usedMapping = true
|
||||||
|
providers = mappedProviders
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no providers available, fallback to ampcode.com
|
||||||
|
if len(providers) == 0 {
|
||||||
proxy := fh.getProxy()
|
proxy := fh.getProxy()
|
||||||
if proxy != nil {
|
if proxy != nil {
|
||||||
// Log: Forwarding to ampcode.com (uses Amp credits)
|
// Log: Forwarding to ampcode.com (uses Amp credits)
|
||||||
@@ -169,8 +212,6 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath)
|
logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
handleRequest:
|
|
||||||
|
|
||||||
// Log the routing decision
|
// Log the routing decision
|
||||||
providerName := ""
|
providerName := ""
|
||||||
if len(providers) > 0 {
|
if len(providers) > 0 {
|
||||||
@@ -179,59 +220,62 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
|
|
||||||
if usedMapping {
|
if usedMapping {
|
||||||
// Log: Model was mapped to another model
|
// Log: Model was mapped to another model
|
||||||
|
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
||||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||||
|
rewriter := NewResponseRewriter(c.Writer, normalizedModel)
|
||||||
|
c.Writer = rewriter
|
||||||
|
// Filter Anthropic-Beta header only for local handling paths
|
||||||
|
filterAntropicBetaHeader(c)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
handler(c)
|
||||||
|
rewriter.Flush()
|
||||||
|
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, normalizedModel)
|
||||||
} else if len(providers) > 0 {
|
} else if len(providers) > 0 {
|
||||||
// Log: Using local provider (free)
|
// Log: Using local provider (free)
|
||||||
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
||||||
|
// Filter Anthropic-Beta header only for local handling paths
|
||||||
|
filterAntropicBetaHeader(c)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
handler(c)
|
||||||
|
} else {
|
||||||
|
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
handler(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Providers available or no proxy for fallback, restore body and use normal handler
|
|
||||||
// Filter Anthropic-Beta header to remove features requiring special subscription
|
|
||||||
// This is needed when using local providers (bypassing the Amp proxy)
|
|
||||||
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
|
||||||
filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07")
|
|
||||||
if filtered != "" {
|
|
||||||
c.Request.Header.Set("Anthropic-Beta", filtered)
|
|
||||||
} else {
|
|
||||||
c.Request.Header.Del("Anthropic-Beta")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
handler(c)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// rewriteModelInBody replaces the model name in a JSON request body
|
// filterAntropicBetaHeader filters Anthropic-Beta header to remove features requiring special subscription
|
||||||
func rewriteModelInBody(body []byte, newModel string) []byte {
|
// This is needed when using local providers (bypassing the Amp proxy)
|
||||||
var payload map[string]interface{}
|
func filterAntropicBetaHeader(c *gin.Context) {
|
||||||
if err := json.Unmarshal(body, &payload); err != nil {
|
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
||||||
log.Warnf("amp model mapping: failed to parse body for rewrite: %v", err)
|
if filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07"); filtered != "" {
|
||||||
|
c.Request.Header.Set("Anthropic-Beta", filtered)
|
||||||
|
} else {
|
||||||
|
c.Request.Header.Del("Anthropic-Beta")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewriteModelInRequest replaces the model name in a JSON request body
|
||||||
|
func rewriteModelInRequest(body []byte, newModel string) []byte {
|
||||||
|
if !gjson.GetBytes(body, "model").Exists() {
|
||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
result, err := sjson.SetBytes(body, "model", newModel)
|
||||||
if _, exists := payload["model"]; exists {
|
if err != nil {
|
||||||
payload["model"] = newModel
|
log.Warnf("amp model mapping: failed to rewrite model in request body: %v", err)
|
||||||
newBody, err := json.Marshal(payload)
|
return body
|
||||||
if err != nil {
|
|
||||||
log.Warnf("amp model mapping: failed to marshal rewritten body: %v", err)
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
return newBody
|
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
return body
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractModelFromRequest attempts to extract the model name from various request formats
|
// extractModelFromRequest attempts to extract the model name from various request formats
|
||||||
func extractModelFromRequest(body []byte, c *gin.Context) string {
|
func extractModelFromRequest(body []byte, c *gin.Context) string {
|
||||||
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
||||||
var payload map[string]interface{}
|
// Check common model field names
|
||||||
if err := json.Unmarshal(body, &payload); err == nil {
|
if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String {
|
||||||
// Check common model field names
|
return result.String()
|
||||||
if model, ok := payload["model"].(string); ok {
|
|
||||||
return model
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// For Gemini requests, model is in the URL path
|
// For Gemini requests, model is in the URL path
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// createGeminiBridgeHandler creates a handler that bridges AMP CLI's non-standard Gemini paths
|
// createGeminiBridgeHandler creates a handler that bridges AMP CLI's non-standard Gemini paths
|
||||||
@@ -15,16 +14,31 @@ import (
|
|||||||
//
|
//
|
||||||
// This extracts the model+method from the AMP path and sets it as the :action parameter
|
// This extracts the model+method from the AMP path and sets it as the :action parameter
|
||||||
// so the standard Gemini handler can process it.
|
// so the standard Gemini handler can process it.
|
||||||
func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.HandlerFunc {
|
//
|
||||||
|
// The handler parameter should be a Gemini-compatible handler that expects the :action param.
|
||||||
|
func createGeminiBridgeHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// Get the full path from the catch-all parameter
|
// Get the full path from the catch-all parameter
|
||||||
path := c.Param("path")
|
path := c.Param("path")
|
||||||
|
|
||||||
// Extract model:method from AMP CLI path format
|
// Extract model:method from AMP CLI path format
|
||||||
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
||||||
if idx := strings.Index(path, "/models/"); idx >= 0 {
|
const modelsPrefix = "/models/"
|
||||||
// Extract everything after "/models/"
|
if idx := strings.Index(path, modelsPrefix); idx >= 0 {
|
||||||
actionPart := path[idx+8:] // Skip "/models/"
|
// Extract everything after modelsPrefix
|
||||||
|
actionPart := path[idx+len(modelsPrefix):]
|
||||||
|
|
||||||
|
// Check if model was mapped by FallbackHandler
|
||||||
|
if mappedModel, exists := c.Get(MappedModelContextKey); exists {
|
||||||
|
if strModel, ok := mappedModel.(string); ok && strModel != "" {
|
||||||
|
// Replace the model part in the action
|
||||||
|
// actionPart is like "model-name:method"
|
||||||
|
if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 {
|
||||||
|
method := actionPart[colonIdx:] // ":method"
|
||||||
|
actionPart = strModel + method
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Set this as the :action parameter that the Gemini handler expects
|
// Set this as the :action parameter that the Gemini handler expects
|
||||||
c.Params = append(c.Params, gin.Param{
|
c.Params = append(c.Params, gin.Param{
|
||||||
@@ -32,8 +46,8 @@ func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.Handl
|
|||||||
Value: actionPart,
|
Value: actionPart,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Call the standard Gemini handler
|
// Call the handler
|
||||||
geminiHandler.GeminiHandler(c)
|
handler(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
93
internal/api/modules/amp/gemini_bridge_test.go
Normal file
93
internal/api/modules/amp/gemini_bridge_test.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package amp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
mappedModel string // empty string means no mapping
|
||||||
|
expectedAction string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no_mapping_uses_url_model",
|
||||||
|
path: "/publishers/google/models/gemini-pro:generateContent",
|
||||||
|
mappedModel: "",
|
||||||
|
expectedAction: "gemini-pro:generateContent",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mapped_model_replaces_url_model",
|
||||||
|
path: "/publishers/google/models/gemini-exp:generateContent",
|
||||||
|
mappedModel: "gemini-2.0-flash",
|
||||||
|
expectedAction: "gemini-2.0-flash:generateContent",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mapping_preserves_method",
|
||||||
|
path: "/publishers/google/models/gemini-2.5-preview:streamGenerateContent",
|
||||||
|
mappedModel: "gemini-flash",
|
||||||
|
expectedAction: "gemini-flash:streamGenerateContent",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var capturedAction string
|
||||||
|
|
||||||
|
mockGeminiHandler := func(c *gin.Context) {
|
||||||
|
capturedAction = c.Param("action")
|
||||||
|
c.JSON(http.StatusOK, gin.H{"captured": capturedAction})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the actual createGeminiBridgeHandler function
|
||||||
|
bridgeHandler := createGeminiBridgeHandler(mockGeminiHandler)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
if tt.mappedModel != "" {
|
||||||
|
r.Use(func(c *gin.Context) {
|
||||||
|
c.Set(MappedModelContextKey, tt.mappedModel)
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1"+tt.path, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
if capturedAction != tt.expectedAction {
|
||||||
|
t.Errorf("Expected action '%s', got '%s'", tt.expectedAction, capturedAction)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
mockHandler := func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
|
}
|
||||||
|
bridgeHandler := createGeminiBridgeHandler(mockHandler)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/invalid/path", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("Expected status 400 for invalid path, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -66,7 +66,6 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
|
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
|
||||||
log.Debugf("amp model mapping: resolved %s -> %s", requestedModel, targetModel)
|
|
||||||
return targetModel
|
return targetModel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
98
internal/api/modules/amp/response_rewriter.go
Normal file
98
internal/api/modules/amp/response_rewriter.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package amp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
|
||||||
|
// It's used to rewrite model names in responses when model mapping is used
|
||||||
|
type ResponseRewriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
body *bytes.Buffer
|
||||||
|
originalModel string
|
||||||
|
isStreaming bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResponseRewriter creates a new response rewriter for model name substitution
|
||||||
|
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
|
||||||
|
return &ResponseRewriter{
|
||||||
|
ResponseWriter: w,
|
||||||
|
body: &bytes.Buffer{},
|
||||||
|
originalModel: originalModel,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write intercepts response writes and buffers them for model name replacement
|
||||||
|
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
||||||
|
// Detect streaming on first write
|
||||||
|
if rw.body.Len() == 0 && !rw.isStreaming {
|
||||||
|
contentType := rw.Header().Get("Content-Type")
|
||||||
|
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
||||||
|
strings.Contains(contentType, "stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rw.isStreaming {
|
||||||
|
return rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||||
|
}
|
||||||
|
return rw.body.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush writes the buffered response with model names rewritten
|
||||||
|
func (rw *ResponseRewriter) Flush() {
|
||||||
|
if rw.isStreaming {
|
||||||
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rw.body.Len() > 0 {
|
||||||
|
if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
|
||||||
|
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelFieldPaths lists all JSON paths where model name may appear
|
||||||
|
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
||||||
|
|
||||||
|
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
||||||
|
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||||
|
if rw.originalModel == "" {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
for _, path := range modelFieldPaths {
|
||||||
|
if gjson.GetBytes(data, path).Exists() {
|
||||||
|
data, _ = sjson.SetBytes(data, path, rw.originalModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewriteStreamChunk rewrites model names in SSE stream chunks
|
||||||
|
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
||||||
|
if rw.originalModel == "" {
|
||||||
|
return chunk
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE format: "data: {json}\n\n"
|
||||||
|
lines := bytes.Split(chunk, []byte("\n"))
|
||||||
|
for i, line := range lines {
|
||||||
|
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||||
|
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
||||||
|
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||||
|
// Rewrite JSON in the data line
|
||||||
|
rewritten := rw.rewriteModelInResponse(jsonData)
|
||||||
|
lines[i] = append([]byte("data: "), rewritten...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Join(lines, []byte("\n"))
|
||||||
|
}
|
||||||
@@ -1,12 +1,14 @@
|
|||||||
package amp
|
package amp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||||
@@ -78,6 +80,21 @@ func noCORSMiddleware() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// managementAvailabilityMiddleware short-circuits management routes when the upstream
|
||||||
|
// proxy is disabled, preventing noisy localhost warnings and accidental exposure.
|
||||||
|
func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if m.getProxy() == nil {
|
||||||
|
logging.SkipGinRequestLogging(c)
|
||||||
|
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{
|
||||||
|
"error": "amp upstream proxy not available",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// registerManagementRoutes registers Amp management proxy routes
|
// registerManagementRoutes registers Amp management proxy routes
|
||||||
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
||||||
// Uses dynamic middleware and proxy getter for hot-reload support.
|
// Uses dynamic middleware and proxy getter for hot-reload support.
|
||||||
@@ -85,19 +102,28 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
|||||||
ampAPI := engine.Group("/api")
|
ampAPI := engine.Group("/api")
|
||||||
|
|
||||||
// Always disable CORS for management routes to prevent browser-based attacks
|
// Always disable CORS for management routes to prevent browser-based attacks
|
||||||
ampAPI.Use(noCORSMiddleware())
|
ampAPI.Use(m.managementAvailabilityMiddleware(), noCORSMiddleware())
|
||||||
|
|
||||||
// Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
|
// Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
|
||||||
ampAPI.Use(m.localhostOnlyMiddleware())
|
ampAPI.Use(m.localhostOnlyMiddleware())
|
||||||
|
|
||||||
if m.IsRestrictedToLocalhost() {
|
if !m.IsRestrictedToLocalhost() {
|
||||||
log.Info("amp management routes restricted to localhost only (CORS disabled)")
|
|
||||||
} else {
|
|
||||||
log.Warn("amp management routes are NOT restricted to localhost - this is insecure!")
|
log.Warn("amp management routes are NOT restricted to localhost - this is insecure!")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
||||||
proxyHandler := func(c *gin.Context) {
|
proxyHandler := func(c *gin.Context) {
|
||||||
|
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
||||||
|
defer func() {
|
||||||
|
if rec := recover(); rec != nil {
|
||||||
|
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
||||||
|
// Upstream already wrote the status (often 404) before the client/stream ended.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
panic(rec)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
proxy := m.getProxy()
|
proxy := m.getProxy()
|
||||||
if proxy == nil {
|
if proxy == nil {
|
||||||
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
|
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
|
||||||
@@ -127,8 +153,10 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
|||||||
|
|
||||||
// Root-level routes that AMP CLI expects without /api prefix
|
// Root-level routes that AMP CLI expects without /api prefix
|
||||||
// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
|
// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
|
||||||
rootMiddleware := []gin.HandlerFunc{noCORSMiddleware(), m.localhostOnlyMiddleware()}
|
rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
|
||||||
|
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
||||||
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
||||||
|
engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...)
|
||||||
|
|
||||||
// Root-level auth routes for CLI login flow
|
// Root-level auth routes for CLI login flow
|
||||||
// Amp uses multiple auth routes: /auth/cli-login, /auth/callback, /auth/sign-in, /auth/logout
|
// Amp uses multiple auth routes: /auth/cli-login, /auth/callback, /auth/sign-in, /auth/logout
|
||||||
@@ -141,30 +169,22 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
|||||||
// We bridge these to our standard Gemini handler to enable local OAuth.
|
// We bridge these to our standard Gemini handler to enable local OAuth.
|
||||||
// If no local OAuth is available, falls back to ampcode.com proxy.
|
// If no local OAuth is available, falls back to ampcode.com proxy.
|
||||||
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
||||||
geminiBridge := createGeminiBridgeHandler(geminiHandlers)
|
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
|
||||||
geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy {
|
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||||
return m.getProxy()
|
return m.getProxy()
|
||||||
})
|
}, m.modelMapper, m.forceModelMappings)
|
||||||
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
||||||
|
|
||||||
// Route POST model calls through Gemini bridge when a local provider exists, otherwise proxy.
|
// Route POST model calls through Gemini bridge with FallbackHandler.
|
||||||
|
// FallbackHandler checks provider -> mapping -> proxy fallback automatically.
|
||||||
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
|
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
|
||||||
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
|
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||||
if c.Request.Method == "POST" {
|
if c.Request.Method == "POST" {
|
||||||
// Attempt to extract the model name from the AMP-style path
|
|
||||||
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||||
modelPart := path[strings.Index(path, "/models/")+len("/models/"):]
|
// POST with /models/ path -> use Gemini bridge with fallback handler
|
||||||
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
|
// FallbackHandler will check provider/mapping and proxy if needed
|
||||||
modelPart = modelPart[:colonIdx]
|
geminiV1Beta1Handler(c)
|
||||||
}
|
return
|
||||||
if modelPart != "" {
|
|
||||||
normalized, _ := util.NormalizeGeminiThinkingModel(modelPart)
|
|
||||||
// Only handle locally when we have a provider; otherwise fall back to proxy
|
|
||||||
if providers := util.GetProviderName(normalized); len(providers) > 0 {
|
|
||||||
geminiV1Beta1Handler(c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Non-POST or no local provider available -> proxy upstream
|
// Non-POST or no local provider available -> proxy upstream
|
||||||
@@ -190,7 +210,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
|||||||
// Also includes model mapping support for routing unavailable models to alternatives
|
// Also includes model mapping support for routing unavailable models to alternatives
|
||||||
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||||
return m.getProxy()
|
return m.getProxy()
|
||||||
}, m.modelMapper)
|
}, m.modelMapper, m.forceModelMappings)
|
||||||
|
|
||||||
// Provider-specific routes under /api/provider/:provider
|
// Provider-specific routes under /api/provider/:provider
|
||||||
ampProviders := engine.Group("/api/provider")
|
ampProviders := engine.Group("/api/provider")
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
|||||||
{"/api/meta", http.MethodGet},
|
{"/api/meta", http.MethodGet},
|
||||||
{"/api/telemetry", http.MethodGet},
|
{"/api/telemetry", http.MethodGet},
|
||||||
{"/api/threads", http.MethodGet},
|
{"/api/threads", http.MethodGet},
|
||||||
|
{"/threads/", http.MethodGet},
|
||||||
{"/threads.rss", http.MethodGet}, // Root-level route (no /api prefix)
|
{"/threads.rss", http.MethodGet}, // Root-level route (no /api prefix)
|
||||||
{"/api/otel", http.MethodGet},
|
{"/api/otel", http.MethodGet},
|
||||||
{"/api/tab", http.MethodGet},
|
{"/api/tab", http.MethodGet},
|
||||||
|
|||||||
@@ -300,7 +300,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
|
|
||||||
// Create HTTP server
|
// Create HTTP server
|
||||||
s.server = &http.Server{
|
s.server = &http.Server{
|
||||||
Addr: fmt.Sprintf(":%d", cfg.Port),
|
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||||
Handler: engine,
|
Handler: engine,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -472,6 +472,7 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.GET("/config", s.mgmt.GetConfig)
|
mgmt.GET("/config", s.mgmt.GetConfig)
|
||||||
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
|
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
|
||||||
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
|
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
|
||||||
|
mgmt.GET("/latest-version", s.mgmt.GetLatestVersion)
|
||||||
|
|
||||||
mgmt.GET("/debug", s.mgmt.GetDebug)
|
mgmt.GET("/debug", s.mgmt.GetDebug)
|
||||||
mgmt.PUT("/debug", s.mgmt.PutDebug)
|
mgmt.PUT("/debug", s.mgmt.PutDebug)
|
||||||
@@ -519,6 +520,26 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth)
|
mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth)
|
||||||
mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth)
|
mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth)
|
||||||
|
|
||||||
|
mgmt.GET("/ampcode", s.mgmt.GetAmpCode)
|
||||||
|
mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL)
|
||||||
|
mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
|
||||||
|
mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
|
||||||
|
mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL)
|
||||||
|
mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey)
|
||||||
|
mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
|
||||||
|
mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
|
||||||
|
mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey)
|
||||||
|
mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost)
|
||||||
|
mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
|
||||||
|
mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
|
||||||
|
mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings)
|
||||||
|
mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings)
|
||||||
|
mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings)
|
||||||
|
mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings)
|
||||||
|
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
|
||||||
|
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||||
|
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||||
|
|
||||||
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
||||||
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
||||||
mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)
|
mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)
|
||||||
@@ -901,7 +922,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
for _, p := range cfg.OpenAICompatibility {
|
for _, p := range cfg.OpenAICompatibility {
|
||||||
providerNames = append(providerNames, p.Name)
|
providerNames = append(providerNames, p.Name)
|
||||||
}
|
}
|
||||||
s.handlers.OpenAICompatProviders = providerNames
|
s.handlers.SetOpenAICompatProviders(providerNames)
|
||||||
|
|
||||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||||
|
|
||||||
|
|||||||
@@ -242,6 +242,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
|||||||
platformURL = "https://console.anthropic.com/"
|
platformURL = "https://console.anthropic.com/"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate platformURL to prevent XSS - only allow http/https URLs
|
||||||
|
if !isValidURL(platformURL) {
|
||||||
|
platformURL = "https://console.anthropic.com/"
|
||||||
|
}
|
||||||
|
|
||||||
// Generate success page HTML with dynamic content
|
// Generate success page HTML with dynamic content
|
||||||
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
||||||
|
|
||||||
@@ -251,6 +256,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isValidURL checks if the URL is a valid http/https URL to prevent XSS
|
||||||
|
func isValidURL(urlStr string) bool {
|
||||||
|
urlStr = strings.TrimSpace(urlStr)
|
||||||
|
return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
|
||||||
|
}
|
||||||
|
|
||||||
// generateSuccessHTML creates the HTML content for the success page.
|
// generateSuccessHTML creates the HTML content for the success page.
|
||||||
// It customizes the page based on whether additional setup is required
|
// It customizes the page based on whether additional setup is required
|
||||||
// and includes a link to the platform.
|
// and includes a link to the platform.
|
||||||
|
|||||||
@@ -239,6 +239,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
|||||||
platformURL = "https://platform.openai.com"
|
platformURL = "https://platform.openai.com"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate platformURL to prevent XSS - only allow http/https URLs
|
||||||
|
if !isValidURL(platformURL) {
|
||||||
|
platformURL = "https://platform.openai.com"
|
||||||
|
}
|
||||||
|
|
||||||
// Generate success page HTML with dynamic content
|
// Generate success page HTML with dynamic content
|
||||||
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
||||||
|
|
||||||
@@ -248,6 +253,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isValidURL checks if the URL is a valid http/https URL to prevent XSS
|
||||||
|
func isValidURL(urlStr string) bool {
|
||||||
|
urlStr = strings.TrimSpace(urlStr)
|
||||||
|
return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
|
||||||
|
}
|
||||||
|
|
||||||
// generateSuccessHTML creates the HTML content for the success page.
|
// generateSuccessHTML creates the HTML content for the success page.
|
||||||
// It customizes the page based on whether additional setup is required
|
// It customizes the page based on whether additional setup is required
|
||||||
// and includes a link to the platform.
|
// and includes a link to the platform.
|
||||||
|
|||||||
@@ -76,7 +76,8 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
|||||||
auth := &proxy.Auth{User: username, Password: password}
|
auth := &proxy.Auth{User: username, Password: password}
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
||||||
if errSOCKS5 != nil {
|
if errSOCKS5 != nil {
|
||||||
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||||
|
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
|
||||||
}
|
}
|
||||||
transport = &http.Transport{
|
transport = &http.Transport{
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
@@ -238,7 +239,11 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
|||||||
// Start the server in a goroutine.
|
// Start the server in a goroutine.
|
||||||
go func() {
|
go func() {
|
||||||
if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Fatalf("ListenAndServe(): %v", err)
|
log.Errorf("ListenAndServe(): %v", err)
|
||||||
|
select {
|
||||||
|
case errChan <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -28,10 +29,21 @@ const (
|
|||||||
iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey"
|
iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey"
|
||||||
|
|
||||||
// Client credentials provided by iFlow for the Code Assist integration.
|
// Client credentials provided by iFlow for the Code Assist integration.
|
||||||
iFlowOAuthClientID = "10009311001"
|
iFlowOAuthClientID = "10009311001"
|
||||||
iFlowOAuthClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
|
// Default client secret (can be overridden via IFLOW_CLIENT_SECRET env var)
|
||||||
|
defaultIFlowClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// getIFlowClientSecret returns the iFlow OAuth client secret.
|
||||||
|
// It first checks the IFLOW_CLIENT_SECRET environment variable,
|
||||||
|
// falling back to the default value if not set.
|
||||||
|
func getIFlowClientSecret() string {
|
||||||
|
if secret := os.Getenv("IFLOW_CLIENT_SECRET"); secret != "" {
|
||||||
|
return secret
|
||||||
|
}
|
||||||
|
return defaultIFlowClientSecret
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultAPIBaseURL is the canonical chat completions endpoint.
|
// DefaultAPIBaseURL is the canonical chat completions endpoint.
|
||||||
const DefaultAPIBaseURL = "https://apis.iflow.cn/v1"
|
const DefaultAPIBaseURL = "https://apis.iflow.cn/v1"
|
||||||
|
|
||||||
@@ -72,7 +84,7 @@ func (ia *IFlowAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectUR
|
|||||||
form.Set("code", code)
|
form.Set("code", code)
|
||||||
form.Set("redirect_uri", redirectURI)
|
form.Set("redirect_uri", redirectURI)
|
||||||
form.Set("client_id", iFlowOAuthClientID)
|
form.Set("client_id", iFlowOAuthClientID)
|
||||||
form.Set("client_secret", iFlowOAuthClientSecret)
|
form.Set("client_secret", getIFlowClientSecret())
|
||||||
|
|
||||||
req, err := ia.newTokenRequest(ctx, form)
|
req, err := ia.newTokenRequest(ctx, form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -88,7 +100,7 @@ func (ia *IFlowAuth) RefreshTokens(ctx context.Context, refreshToken string) (*I
|
|||||||
form.Set("grant_type", "refresh_token")
|
form.Set("grant_type", "refresh_token")
|
||||||
form.Set("refresh_token", refreshToken)
|
form.Set("refresh_token", refreshToken)
|
||||||
form.Set("client_id", iFlowOAuthClientID)
|
form.Set("client_id", iFlowOAuthClientID)
|
||||||
form.Set("client_secret", iFlowOAuthClientSecret)
|
form.Set("client_secret", getIFlowClientSecret())
|
||||||
|
|
||||||
req, err := ia.newTokenRequest(ctx, form)
|
req, err := ia.newTokenRequest(ctx, form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -104,7 +116,7 @@ func (ia *IFlowAuth) newTokenRequest(ctx context.Context, form url.Values) (*htt
|
|||||||
return nil, fmt.Errorf("iflow token: create request failed: %w", err)
|
return nil, fmt.Errorf("iflow token: create request failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + iFlowOAuthClientSecret))
|
basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + getIFlowClientSecret()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
req.Header.Set("Authorization", "Basic "+basic)
|
req.Header.Set("Authorization", "Basic "+basic)
|
||||||
@@ -309,17 +321,23 @@ func (ia *IFlowAuth) AuthenticateWithCookie(ctx context.Context, cookie string)
|
|||||||
return nil, fmt.Errorf("iflow cookie authentication: cookie is empty")
|
return nil, fmt.Errorf("iflow cookie authentication: cookie is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
// First, get initial API key information using GET request
|
// First, get initial API key information using GET request to obtain the name
|
||||||
keyInfo, err := ia.fetchAPIKeyInfo(ctx, cookie)
|
keyInfo, err := ia.fetchAPIKeyInfo(ctx, cookie)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("iflow cookie authentication: fetch initial API key info failed: %w", err)
|
return nil, fmt.Errorf("iflow cookie authentication: fetch initial API key info failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert to token data format
|
// Refresh the API key using POST request
|
||||||
|
refreshedKeyInfo, err := ia.RefreshAPIKey(ctx, cookie, keyInfo.Name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("iflow cookie authentication: refresh API key failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to token data format using refreshed key
|
||||||
data := &IFlowTokenData{
|
data := &IFlowTokenData{
|
||||||
APIKey: keyInfo.APIKey,
|
APIKey: refreshedKeyInfo.APIKey,
|
||||||
Expire: keyInfo.ExpireTime,
|
Expire: refreshedKeyInfo.ExpireTime,
|
||||||
Email: keyInfo.Name,
|
Email: refreshedKeyInfo.Name,
|
||||||
Cookie: cookie,
|
Cookie: cookie,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
301
internal/auth/kiro/aws.go
Normal file
301
internal/auth/kiro/aws.go
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||||
|
// It includes interfaces and implementations for token storage and authentication methods.
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
||||||
|
type PKCECodes struct {
|
||||||
|
// CodeVerifier is the cryptographically random string used to correlate
|
||||||
|
// the authorization request to the token request
|
||||||
|
CodeVerifier string `json:"code_verifier"`
|
||||||
|
// CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
|
||||||
|
CodeChallenge string `json:"code_challenge"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroTokenData holds OAuth token information from AWS CodeWhisperer (Kiro)
|
||||||
|
type KiroTokenData struct {
|
||||||
|
// AccessToken is the OAuth2 access token for API access
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
// RefreshToken is used to obtain new access tokens
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
// ProfileArn is the AWS CodeWhisperer profile ARN
|
||||||
|
ProfileArn string `json:"profileArn"`
|
||||||
|
// ExpiresAt is the timestamp when the token expires
|
||||||
|
ExpiresAt string `json:"expiresAt"`
|
||||||
|
// AuthMethod indicates the authentication method used (e.g., "builder-id", "social")
|
||||||
|
AuthMethod string `json:"authMethod"`
|
||||||
|
// Provider indicates the OAuth provider (e.g., "AWS", "Google")
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
// ClientID is the OIDC client ID (needed for token refresh)
|
||||||
|
ClientID string `json:"clientId,omitempty"`
|
||||||
|
// ClientSecret is the OIDC client secret (needed for token refresh)
|
||||||
|
ClientSecret string `json:"clientSecret,omitempty"`
|
||||||
|
// Email is the user's email address (used for file naming)
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroAuthBundle aggregates authentication data after OAuth flow completion
|
||||||
|
type KiroAuthBundle struct {
|
||||||
|
// TokenData contains the OAuth tokens from the authentication flow
|
||||||
|
TokenData KiroTokenData `json:"token_data"`
|
||||||
|
// LastRefresh is the timestamp of the last token refresh
|
||||||
|
LastRefresh string `json:"last_refresh"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroUsageInfo represents usage information from CodeWhisperer API
|
||||||
|
type KiroUsageInfo struct {
|
||||||
|
// SubscriptionTitle is the subscription plan name (e.g., "KIRO FREE")
|
||||||
|
SubscriptionTitle string `json:"subscription_title"`
|
||||||
|
// CurrentUsage is the current credit usage
|
||||||
|
CurrentUsage float64 `json:"current_usage"`
|
||||||
|
// UsageLimit is the maximum credit limit
|
||||||
|
UsageLimit float64 `json:"usage_limit"`
|
||||||
|
// NextReset is the timestamp of the next usage reset
|
||||||
|
NextReset string `json:"next_reset"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroModel represents a model available through the CodeWhisperer API
|
||||||
|
type KiroModel struct {
|
||||||
|
// ModelID is the unique identifier for the model
|
||||||
|
ModelID string `json:"modelId"`
|
||||||
|
// ModelName is the human-readable name
|
||||||
|
ModelName string `json:"modelName"`
|
||||||
|
// Description is the model description
|
||||||
|
Description string `json:"description"`
|
||||||
|
// RateMultiplier is the credit multiplier for this model
|
||||||
|
RateMultiplier float64 `json:"rateMultiplier"`
|
||||||
|
// RateUnit is the unit for rate calculation (e.g., "credit")
|
||||||
|
RateUnit string `json:"rateUnit"`
|
||||||
|
// MaxInputTokens is the maximum input token limit
|
||||||
|
MaxInputTokens int `json:"maxInputTokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroIDETokenFile is the default path to Kiro IDE's token file
|
||||||
|
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
|
||||||
|
|
||||||
|
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
|
||||||
|
func LoadKiroIDEToken() (*KiroTokenData, error) {
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenPath := filepath.Join(homeDir, KiroIDETokenFile)
|
||||||
|
data, err := os.ReadFile(tokenPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read Kiro IDE token file (%s): %w", tokenPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var token KiroTokenData
|
||||||
|
if err := json.Unmarshal(data, &token); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse Kiro IDE token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if token.AccessToken == "" {
|
||||||
|
return nil, fmt.Errorf("access token is empty in Kiro IDE token file")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadKiroTokenFromPath loads token data from a custom path.
|
||||||
|
// This supports multiple accounts by allowing different token files.
|
||||||
|
func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
|
||||||
|
// Expand ~ to home directory
|
||||||
|
if len(tokenPath) > 0 && tokenPath[0] == '~' {
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||||
|
}
|
||||||
|
tokenPath = filepath.Join(homeDir, tokenPath[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(tokenPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read token file (%s): %w", tokenPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var token KiroTokenData
|
||||||
|
if err := json.Unmarshal(data, &token); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if token.AccessToken == "" {
|
||||||
|
return nil, fmt.Errorf("access token is empty in token file")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListKiroTokenFiles lists all Kiro token files in the cache directory.
|
||||||
|
// This supports multiple accounts by finding all token files.
|
||||||
|
func ListKiroTokenFiles() ([]string, error) {
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache")
|
||||||
|
|
||||||
|
// Check if directory exists
|
||||||
|
if _, err := os.Stat(cacheDir); os.IsNotExist(err) {
|
||||||
|
return nil, nil // No token files
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(cacheDir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read cache directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenFiles []string
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := entry.Name()
|
||||||
|
// Look for kiro token files only (avoid matching unrelated AWS SSO cache files)
|
||||||
|
if strings.HasSuffix(name, ".json") && strings.HasPrefix(name, "kiro") {
|
||||||
|
tokenFiles = append(tokenFiles, filepath.Join(cacheDir, name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenFiles, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadAllKiroTokens loads all Kiro tokens from the cache directory.
|
||||||
|
// This supports multiple accounts.
|
||||||
|
func LoadAllKiroTokens() ([]*KiroTokenData, error) {
|
||||||
|
files, err := ListKiroTokenFiles()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokens []*KiroTokenData
|
||||||
|
for _, file := range files {
|
||||||
|
token, err := LoadKiroTokenFromPath(file)
|
||||||
|
if err != nil {
|
||||||
|
// Skip invalid token files
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tokens = append(tokens, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTClaims represents the claims we care about from a JWT token.
|
||||||
|
// JWT tokens from Kiro/AWS contain user information in the payload.
|
||||||
|
type JWTClaims struct {
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
Sub string `json:"sub,omitempty"`
|
||||||
|
PreferredUser string `json:"preferred_username,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Iss string `json:"iss,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractEmailFromJWT extracts the user's email from a JWT access token.
|
||||||
|
// JWT tokens typically have format: header.payload.signature
|
||||||
|
// The payload is base64url-encoded JSON containing user claims.
|
||||||
|
func ExtractEmailFromJWT(accessToken string) string {
|
||||||
|
if accessToken == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWT format: header.payload.signature
|
||||||
|
parts := strings.Split(accessToken, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the payload (second part)
|
||||||
|
payload := parts[1]
|
||||||
|
|
||||||
|
// Add padding if needed (base64url requires padding)
|
||||||
|
switch len(payload) % 4 {
|
||||||
|
case 2:
|
||||||
|
payload += "=="
|
||||||
|
case 3:
|
||||||
|
payload += "="
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||||
|
if err != nil {
|
||||||
|
// Try RawURLEncoding (no padding)
|
||||||
|
decoded, err = base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims JWTClaims
|
||||||
|
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return email if available
|
||||||
|
if claims.Email != "" {
|
||||||
|
return claims.Email
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to preferred_username (some providers use this)
|
||||||
|
if claims.PreferredUser != "" && strings.Contains(claims.PreferredUser, "@") {
|
||||||
|
return claims.PreferredUser
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to sub if it looks like an email
|
||||||
|
if claims.Sub != "" && strings.Contains(claims.Sub, "@") {
|
||||||
|
return claims.Sub
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeEmailForFilename sanitizes an email address for use in a filename.
|
||||||
|
// Replaces special characters with underscores and prevents path traversal attacks.
|
||||||
|
// Also handles URL-encoded characters to prevent encoded path traversal attempts.
|
||||||
|
func SanitizeEmailForFilename(email string) string {
|
||||||
|
if email == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
result := email
|
||||||
|
|
||||||
|
// First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.)
|
||||||
|
// This prevents encoded characters from bypassing the sanitization.
|
||||||
|
// Note: We replace % last to catch any remaining encodings including double-encoding (%252F)
|
||||||
|
result = strings.ReplaceAll(result, "%2F", "_") // /
|
||||||
|
result = strings.ReplaceAll(result, "%2f", "_")
|
||||||
|
result = strings.ReplaceAll(result, "%5C", "_") // \
|
||||||
|
result = strings.ReplaceAll(result, "%5c", "_")
|
||||||
|
result = strings.ReplaceAll(result, "%2E", "_") // .
|
||||||
|
result = strings.ReplaceAll(result, "%2e", "_")
|
||||||
|
result = strings.ReplaceAll(result, "%00", "_") // null byte
|
||||||
|
result = strings.ReplaceAll(result, "%", "_") // Catch remaining % to prevent double-encoding attacks
|
||||||
|
|
||||||
|
// Replace characters that are problematic in filenames
|
||||||
|
// Keep @ and . in middle but replace other special characters
|
||||||
|
for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} {
|
||||||
|
result = strings.ReplaceAll(result, char, "_")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prevent path traversal: replace leading dots in each path component
|
||||||
|
// This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd"
|
||||||
|
parts := strings.Split(result, "_")
|
||||||
|
for i, part := range parts {
|
||||||
|
for strings.HasPrefix(part, ".") {
|
||||||
|
part = "_" + part[1:]
|
||||||
|
}
|
||||||
|
parts[i] = part
|
||||||
|
}
|
||||||
|
result = strings.Join(parts, "_")
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
314
internal/auth/kiro/aws_auth.go
Normal file
314
internal/auth/kiro/aws_auth.go
Normal file
@@ -0,0 +1,314 @@
|
|||||||
|
// Package kiro provides OAuth2 authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||||
|
// This package implements token loading, refresh, and API communication with CodeWhisperer.
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// awsKiroEndpoint is used for CodeWhisperer management APIs (GetUsageLimits, ListProfiles, etc.)
|
||||||
|
// Note: This is different from the Amazon Q streaming endpoint (q.us-east-1.amazonaws.com)
|
||||||
|
// used in kiro_executor.go for GenerateAssistantResponse. Both endpoints are correct
|
||||||
|
// for their respective API operations.
|
||||||
|
awsKiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com"
|
||||||
|
defaultTokenFile = "~/.aws/sso/cache/kiro-auth-token.json"
|
||||||
|
targetGetUsage = "AmazonCodeWhispererService.GetUsageLimits"
|
||||||
|
targetListModels = "AmazonCodeWhispererService.ListAvailableModels"
|
||||||
|
targetGenerateChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KiroAuth handles AWS CodeWhisperer authentication and API communication.
// It provides methods for loading tokens, refreshing expired tokens,
// and communicating with the CodeWhisperer API.
type KiroAuth struct {
	// httpClient is the proxy-aware client used for all API calls.
	httpClient *http.Client
	// endpoint is the CodeWhisperer management-API base URL (awsKiroEndpoint).
	endpoint string
}
|
||||||
|
|
||||||
|
// NewKiroAuth creates a new Kiro authentication service.
|
||||||
|
// It initializes the HTTP client with proxy settings from the configuration.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration containing proxy settings
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *KiroAuth: A new Kiro authentication service instance
|
||||||
|
func NewKiroAuth(cfg *config.Config) *KiroAuth {
|
||||||
|
return &KiroAuth{
|
||||||
|
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}),
|
||||||
|
endpoint: awsKiroEndpoint,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadTokenFromFile loads token data from a file path.
|
||||||
|
// This method reads and parses the token file, expanding ~ to the home directory.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - tokenFile: Path to the token file (supports ~ expansion)
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *KiroTokenData: The parsed token data
|
||||||
|
// - error: An error if file reading or parsing fails
|
||||||
|
func (k *KiroAuth) LoadTokenFromFile(tokenFile string) (*KiroTokenData, error) {
|
||||||
|
// Expand ~ to home directory
|
||||||
|
if strings.HasPrefix(tokenFile, "~") {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||||
|
}
|
||||||
|
tokenFile = filepath.Join(home, tokenFile[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(tokenFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenData KiroTokenData
|
||||||
|
if err := json.Unmarshal(data, &tokenData); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokenData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTokenExpired checks if the token has expired.
|
||||||
|
// This method parses the expiration timestamp and compares it with the current time.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - tokenData: The token data to check
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - bool: True if the token has expired, false otherwise
|
||||||
|
func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool {
|
||||||
|
if tokenData.ExpiresAt == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
// Try alternate format
|
||||||
|
expiresAt, err = time.Parse("2006-01-02T15:04:05.000Z", tokenData.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Now().After(expiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeRequest sends a request to the CodeWhisperer API.
|
||||||
|
// This is an internal method for making authenticated API calls.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - ctx: The context for the request
|
||||||
|
// - target: The API target (e.g., "AmazonCodeWhispererService.GetUsageLimits")
|
||||||
|
// - accessToken: The OAuth access token
|
||||||
|
// - payload: The request payload
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - []byte: The response body
|
||||||
|
// - error: An error if the request fails
|
||||||
|
func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken string, payload interface{}) ([]byte, error) {
|
||||||
|
jsonBody, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, k.endpoint, strings.NewReader(string(jsonBody)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||||
|
req.Header.Set("x-amz-target", target)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := k.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("failed to close response body: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUsageLimits retrieves usage information from the CodeWhisperer API.
// This method fetches the current usage statistics and subscription information.
//
// Parameters:
//   - ctx: The context for the request
//   - tokenData: The token data containing access token and profile ARN
//
// Returns:
//   - *KiroUsageInfo: The usage information
//   - error: An error if the request fails
func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) {
	payload := map[string]interface{}{
		"origin":       "AI_EDITOR",
		"profileArn":   tokenData.ProfileArn,
		"resourceType": "AGENTIC_REQUEST",
	}

	body, err := k.makeRequest(ctx, targetGetUsage, tokenData.AccessToken, payload)
	if err != nil {
		return nil, err
	}

	// Anonymous struct mirrors only the response fields consumed here; the
	// JSON tags must match the service's field names exactly.
	var result struct {
		SubscriptionInfo struct {
			SubscriptionTitle string `json:"subscriptionTitle"`
		} `json:"subscriptionInfo"`
		UsageBreakdownList []struct {
			CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
			UsageLimitWithPrecision   float64 `json:"usageLimitWithPrecision"`
		} `json:"usageBreakdownList"`
		NextDateReset float64 `json:"nextDateReset"`
	}

	if err := json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("failed to parse usage response: %w", err)
	}

	usage := &KiroUsageInfo{
		SubscriptionTitle: result.SubscriptionInfo.SubscriptionTitle,
		// nextDateReset arrives as a JSON number; stored in string form.
		NextReset: fmt.Sprintf("%v", result.NextDateReset),
	}

	// Only the first breakdown entry is reported; an empty list leaves the
	// usage figures at zero.
	if len(result.UsageBreakdownList) > 0 {
		usage.CurrentUsage = result.UsageBreakdownList[0].CurrentUsageWithPrecision
		usage.UsageLimit = result.UsageBreakdownList[0].UsageLimitWithPrecision
	}

	return usage, nil
}
|
||||||
|
|
||||||
|
// ListAvailableModels retrieves available models from the CodeWhisperer API.
// This method fetches the list of AI models available for the authenticated user.
//
// Parameters:
//   - ctx: The context for the request
//   - tokenData: The token data containing access token and profile ARN
//
// Returns:
//   - []*KiroModel: The list of available models
//   - error: An error if the request fails
func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) {
	payload := map[string]interface{}{
		"origin":     "AI_EDITOR",
		"profileArn": tokenData.ProfileArn,
	}

	body, err := k.makeRequest(ctx, targetListModels, tokenData.AccessToken, payload)
	if err != nil {
		return nil, err
	}

	// Anonymous struct mirrors only the response fields consumed here; the
	// JSON tags must match the service's field names exactly.
	var result struct {
		Models []struct {
			ModelID        string  `json:"modelId"`
			ModelName      string  `json:"modelName"`
			Description    string  `json:"description"`
			RateMultiplier float64 `json:"rateMultiplier"`
			RateUnit       string  `json:"rateUnit"`
			TokenLimits    struct {
				MaxInputTokens int `json:"maxInputTokens"`
			} `json:"tokenLimits"`
		} `json:"models"`
	}

	if err := json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("failed to parse models response: %w", err)
	}

	// Flatten the wire format into the package's KiroModel type.
	models := make([]*KiroModel, 0, len(result.Models))
	for _, m := range result.Models {
		models = append(models, &KiroModel{
			ModelID:        m.ModelID,
			ModelName:      m.ModelName,
			Description:    m.Description,
			RateMultiplier: m.RateMultiplier,
			RateUnit:       m.RateUnit,
			MaxInputTokens: m.TokenLimits.MaxInputTokens,
		})
	}

	return models, nil
}
|
||||||
|
|
||||||
|
// CreateTokenStorage creates a new KiroTokenStorage from token data.
|
||||||
|
// This method converts the token data into a storage structure suitable for persistence.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - tokenData: The token data to convert
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *KiroTokenStorage: A new token storage instance
|
||||||
|
func (k *KiroAuth) CreateTokenStorage(tokenData *KiroTokenData) *KiroTokenStorage {
|
||||||
|
return &KiroTokenStorage{
|
||||||
|
AccessToken: tokenData.AccessToken,
|
||||||
|
RefreshToken: tokenData.RefreshToken,
|
||||||
|
ProfileArn: tokenData.ProfileArn,
|
||||||
|
ExpiresAt: tokenData.ExpiresAt,
|
||||||
|
AuthMethod: tokenData.AuthMethod,
|
||||||
|
Provider: tokenData.Provider,
|
||||||
|
LastRefresh: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateToken checks if the token is valid by making a test API call.
|
||||||
|
// This method verifies the token by attempting to fetch usage limits.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - ctx: The context for the request
|
||||||
|
// - tokenData: The token data to validate
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: An error if the token is invalid
|
||||||
|
func (k *KiroAuth) ValidateToken(ctx context.Context, tokenData *KiroTokenData) error {
|
||||||
|
_, err := k.GetUsageLimits(ctx, tokenData)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTokenStorage updates an existing token storage with new token data.
|
||||||
|
// This method refreshes the token storage with newly obtained access and refresh tokens.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - storage: The existing token storage to update
|
||||||
|
// - tokenData: The new token data to apply
|
||||||
|
func (k *KiroAuth) UpdateTokenStorage(storage *KiroTokenStorage, tokenData *KiroTokenData) {
|
||||||
|
storage.AccessToken = tokenData.AccessToken
|
||||||
|
storage.RefreshToken = tokenData.RefreshToken
|
||||||
|
storage.ProfileArn = tokenData.ProfileArn
|
||||||
|
storage.ExpiresAt = tokenData.ExpiresAt
|
||||||
|
storage.AuthMethod = tokenData.AuthMethod
|
||||||
|
storage.Provider = tokenData.Provider
|
||||||
|
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||||
|
}
|
||||||
161
internal/auth/kiro/aws_test.go
Normal file
161
internal/auth/kiro/aws_test.go
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractEmailFromJWT(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty token",
|
||||||
|
token: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid token format",
|
||||||
|
token: "not.a.valid.jwt",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid token - not base64",
|
||||||
|
token: "xxx.yyy.zzz",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid JWT with email",
|
||||||
|
token: createTestJWT(map[string]any{"email": "test@example.com", "sub": "user123"}),
|
||||||
|
expected: "test@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "JWT without email but with preferred_username",
|
||||||
|
token: createTestJWT(map[string]any{"preferred_username": "user@domain.com", "sub": "user123"}),
|
||||||
|
expected: "user@domain.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "JWT with email-like sub",
|
||||||
|
token: createTestJWT(map[string]any{"sub": "another@test.com"}),
|
||||||
|
expected: "another@test.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "JWT without any email fields",
|
||||||
|
token: createTestJWT(map[string]any{"sub": "user123", "name": "Test User"}),
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ExtractEmailFromJWT(tt.token)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("ExtractEmailFromJWT() = %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSanitizeEmailForFilename verifies that emails are made filesystem-safe:
// ordinary addresses pass through, while separators, control bytes, and
// URL-encoded traversal sequences are neutralized with underscores.
func TestSanitizeEmailForFilename(t *testing.T) {
	tests := []struct {
		name     string
		email    string
		expected string
	}{
		{
			name:     "Empty email",
			email:    "",
			expected: "",
		},
		{
			name:     "Simple email",
			email:    "user@example.com",
			expected: "user@example.com",
		},
		{
			name:     "Email with space",
			email:    "user name@example.com",
			expected: "user_name@example.com",
		},
		{
			name:     "Email with special chars",
			email:    "user:name@example.com",
			expected: "user_name@example.com",
		},
		{
			name:     "Email with multiple special chars",
			email:    "user/name:test@example.com",
			expected: "user_name_test@example.com",
		},
		{
			name:     "Path traversal attempt",
			email:    "../../../etc/passwd",
			expected: "_.__.__._etc_passwd",
		},
		{
			name:     "Path traversal with backslash",
			email:    `..\..\..\..\windows\system32`,
			expected: "_.__.__.__._windows_system32",
		},
		{
			name:     "Null byte injection attempt",
			email:    "user\x00@evil.com",
			expected: "user_@evil.com",
		},
		// URL-encoded path traversal tests
		{
			name:     "URL-encoded slash",
			email:    "user%2Fpath@example.com",
			expected: "user_path@example.com",
		},
		{
			name:     "URL-encoded backslash",
			email:    "user%5Cpath@example.com",
			expected: "user_path@example.com",
		},
		{
			name:     "URL-encoded dot",
			email:    "%2E%2E%2Fetc%2Fpasswd",
			expected: "___etc_passwd",
		},
		{
			name:     "URL-encoded null",
			email:    "user%00@evil.com",
			expected: "user_@evil.com",
		},
		{
			name:     "Double URL-encoding attack",
			email:    "%252F%252E%252E",
			expected: "_252F_252E_252E", // % replaced with _, remaining chars preserved (safe)
		},
		{
			name:     "Mixed case URL-encoding",
			email:    "%2f%2F%5c%5C",
			expected: "____",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := SanitizeEmailForFilename(tt.email)
			if result != tt.expected {
				t.Errorf("SanitizeEmailForFilename() = %q, want %q", result, tt.expected)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// createTestJWT builds an unsigned test JWT (header.payload.signature) whose
// payload carries the given claims. The signature segment is a fixed fake
// value; the result is only suitable for parsing tests.
func createTestJWT(claims map[string]any) string {
	enc := base64.RawURLEncoding

	payloadBytes, _ := json.Marshal(claims)

	segments := []string{
		enc.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`)),
		enc.EncodeToString(payloadBytes),
		enc.EncodeToString([]byte("fake-signature")),
	}
	return segments[0] + "." + segments[1] + "." + segments[2]
}
|
||||||
296
internal/auth/kiro/oauth.go
Normal file
296
internal/auth/kiro/oauth.go
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
// Package kiro provides OAuth2 authentication for Kiro using native Google login.
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"html"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// kiroAuthEndpoint is the Kiro desktop auth service hosting the
	// /oauth/token and /refreshToken endpoints used below.
	kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"

	// defaultCallbackPort is the preferred local port for the OAuth
	// redirect listener; a dynamic port is used if it is busy.
	defaultCallbackPort = 9876

	// authTimeout bounds how long the callback server waits for the user
	// to complete the browser login flow.
	authTimeout = 10 * time.Minute
)
|
||||||
|
|
||||||
|
// KiroTokenResponse represents the response from Kiro token endpoint.
type KiroTokenResponse struct {
	AccessToken  string `json:"accessToken"`  // bearer token for API calls
	RefreshToken string `json:"refreshToken"` // used to obtain a new access token
	ProfileArn   string `json:"profileArn"`   // CodeWhisperer profile ARN for the account
	ExpiresIn    int    `json:"expiresIn"`    // access-token lifetime in seconds; callers default it when <= 0
}
|
||||||
|
|
||||||
|
// KiroOAuth handles the OAuth flow for Kiro authentication.
type KiroOAuth struct {
	// httpClient performs token-endpoint requests; proxy-aware when cfg is set.
	httpClient *http.Client
	// cfg is the application configuration, passed on to sub-flows; may be nil.
	cfg *config.Config
}
|
||||||
|
|
||||||
|
// NewKiroOAuth creates a new Kiro OAuth handler.
|
||||||
|
func NewKiroOAuth(cfg *config.Config) *KiroOAuth {
|
||||||
|
client := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if cfg != nil {
|
||||||
|
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||||
|
}
|
||||||
|
return &KiroOAuth{
|
||||||
|
httpClient: client,
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateCodeVerifier returns a PKCE code verifier: 32 random bytes encoded
// as unpadded base64url (43 characters).
func generateCodeVerifier() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
|
||||||
|
|
||||||
|
// generateCodeChallenge derives the S256 PKCE challenge for a verifier:
// unpadded base64url of its SHA-256 digest (RFC 7636).
func generateCodeChallenge(verifier string) string {
	digest := sha256.Sum256([]byte(verifier))
	return base64.RawURLEncoding.EncodeToString(digest[:])
}
|
||||||
|
|
||||||
|
// generateState returns a random OAuth state parameter: 16 random bytes
// encoded as unpadded base64url (22 characters).
func generateState() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
|
||||||
|
|
||||||
|
// AuthResult contains the authorization code and state from callback.
type AuthResult struct {
	Code  string // authorization code returned by the provider
	State string // state parameter echoed back for CSRF protection
	Error string // non-empty when the provider reported an error or state mismatched
}
|
||||||
|
|
||||||
|
// startCallbackServer starts a local HTTP server to receive the OAuth callback.
// It returns the redirect URI to hand to the authorization request and a
// buffered channel that receives exactly one AuthResult. The server shuts
// itself down once a result is delivered, the context is cancelled, or
// authTimeout elapses.
func (o *KiroOAuth) startCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthResult, error) {
	// Try to find an available port - use localhost like Kiro does
	listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", defaultCallbackPort))
	if err != nil {
		// Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
		log.Warnf("kiro oauth: default port %d is busy, falling back to dynamic port", defaultCallbackPort)
		listener, err = net.Listen("tcp", "localhost:0")
		if err != nil {
			return "", nil, fmt.Errorf("failed to start callback server: %w", err)
		}
	}

	port := listener.Addr().(*net.TCPAddr).Port
	// Use http scheme for local callback server
	redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
	// Buffered so the handler can send without a receiver being ready.
	resultChan := make(chan AuthResult, 1)

	server := &http.Server{
		ReadHeaderTimeout: 10 * time.Second,
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
		code := r.URL.Query().Get("code")
		state := r.URL.Query().Get("state")
		errParam := r.URL.Query().Get("error")

		// Provider-reported error: surface it to the user and the caller.
		if errParam != "" {
			w.Header().Set("Content-Type", "text/html")
			w.WriteHeader(http.StatusBadRequest)
			// Escaped to prevent HTML injection via the error parameter.
			fmt.Fprintf(w, `<html><body><h1>Login Failed</h1><p>%s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
			resultChan <- AuthResult{Error: errParam}
			return
		}

		// CSRF guard: the state must match the one sent with the auth request.
		if state != expectedState {
			w.Header().Set("Content-Type", "text/html")
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprint(w, `<html><body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
			resultChan <- AuthResult{Error: "state mismatch"}
			return
		}

		w.Header().Set("Content-Type", "text/html")
		fmt.Fprint(w, `<html><body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p></body></html>`)
		resultChan <- AuthResult{Code: code, State: state}
	})

	server.Handler = mux

	go func() {
		if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
			log.Debugf("callback server error: %v", err)
		}
	}()

	// Watchdog: shut the server down on cancellation, timeout, or delivery.
	// NOTE(review): a result consumed here is drained from the buffered
	// channel before the caller can read it - callers appear to rely on the
	// first receive winning; confirm against call sites.
	go func() {
		select {
		case <-ctx.Done():
		case <-time.After(authTimeout):
		case <-resultChan:
		}
		_ = server.Shutdown(context.Background())
	}()

	return redirectURI, resultChan, nil
}
|
||||||
|
|
||||||
|
// LoginWithBuilderID performs OAuth login with AWS Builder ID using device code flow.
|
||||||
|
func (o *KiroOAuth) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) {
|
||||||
|
ssoClient := NewSSOOIDCClient(o.cfg)
|
||||||
|
return ssoClient.LoginWithBuilderID(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// exchangeCodeForToken exchanges the authorization code for tokens.
|
||||||
|
func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier, redirectURI string) (*KiroTokenData, error) {
|
||||||
|
payload := map[string]string{
|
||||||
|
"code": code,
|
||||||
|
"code_verifier": codeVerifier,
|
||||||
|
"redirect_uri": redirectURI,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenURL := kiroAuthEndpoint + "/oauth/token"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
|
||||||
|
|
||||||
|
resp, err := o.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("token request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp KiroTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate ExpiresIn - use default 1 hour if invalid
|
||||||
|
expiresIn := tokenResp.ExpiresIn
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
expiresIn = 3600
|
||||||
|
}
|
||||||
|
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ProfileArn: tokenResp.ProfileArn,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "social",
|
||||||
|
Provider: "", // Caller should preserve original provider
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken refreshes an expired access token.
|
||||||
|
func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
|
||||||
|
payload := map[string]string{
|
||||||
|
"refreshToken": refreshToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshURL := kiroAuthEndpoint + "/refreshToken"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
|
||||||
|
|
||||||
|
resp, err := o.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("refresh request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp KiroTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate ExpiresIn - use default 1 hour if invalid
|
||||||
|
expiresIn := tokenResp.ExpiresIn
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
expiresIn = 3600
|
||||||
|
}
|
||||||
|
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ProfileArn: tokenResp.ProfileArn,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "social",
|
||||||
|
Provider: "", // Caller should preserve original provider
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
|
||||||
|
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||||
|
func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
|
||||||
|
socialClient := NewSocialAuthClient(o.cfg)
|
||||||
|
return socialClient.LoginWithGoogle(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithGitHub performs OAuth login with GitHub using Kiro's social auth.
|
||||||
|
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||||
|
func (o *KiroOAuth) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) {
|
||||||
|
socialClient := NewSocialAuthClient(o.cfg)
|
||||||
|
return socialClient.LoginWithGitHub(ctx)
|
||||||
|
}
|
||||||
725
internal/auth/kiro/protocol_handler.go
Normal file
725
internal/auth/kiro/protocol_handler.go
Normal file
@@ -0,0 +1,725 @@
|
|||||||
|
// Package kiro provides custom protocol handler registration for Kiro OAuth.
|
||||||
|
// This enables the CLI to intercept kiro:// URIs for social authentication (Google/GitHub).
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"html"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// KiroProtocol is the custom URI scheme used by Kiro
	KiroProtocol = "kiro"

	// KiroAuthority is the URI authority for authentication callbacks
	KiroAuthority = "kiro.kiroAgent"

	// KiroAuthPath is the path for successful authentication
	KiroAuthPath = "/authenticate-success"

	// KiroRedirectURI is the full redirect URI for social auth
	// (scheme://authority/path assembled from the constants above).
	KiroRedirectURI = "kiro://kiro.kiroAgent/authenticate-success"

	// DefaultHandlerPort is the default port for the local callback server;
	// Start probes this port and the four above it.
	DefaultHandlerPort = 19876

	// HandlerTimeout is how long to wait for the OAuth callback
	HandlerTimeout = 10 * time.Minute
)
|
||||||
|
|
||||||
|
// ProtocolHandler manages the custom kiro:// protocol handler for OAuth callbacks.
type ProtocolHandler struct {
	port       int                // actual bound port of the callback server
	server     *http.Server       // local HTTP server receiving redirects
	listener   net.Listener       // bound TCP listener backing server
	resultChan chan *AuthCallback // buffered (1); delivers the OAuth callback
	stopChan   chan struct{}      // closed to signal shutdown
	mu         sync.Mutex         // guards running and start/stop transitions
	running    bool               // true while the callback server is up
}
|
||||||
|
|
||||||
|
// AuthCallback contains the OAuth callback parameters.
type AuthCallback struct {
	Code  string // authorization code returned by the provider
	State string // state parameter echoed back for CSRF protection
	Error string // non-empty when the provider reported an error
}
|
||||||
|
|
||||||
|
// NewProtocolHandler creates a new protocol handler.
|
||||||
|
func NewProtocolHandler() *ProtocolHandler {
|
||||||
|
return &ProtocolHandler{
|
||||||
|
port: DefaultHandlerPort,
|
||||||
|
resultChan: make(chan *AuthCallback, 1),
|
||||||
|
stopChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the local callback server that receives redirects from the protocol handler.
// It is idempotent: if the server is already running the bound port is returned
// unchanged. Otherwise it binds the first free loopback port in the range
// [DefaultHandlerPort, DefaultHandlerPort+4] (the handler scripts probe the
// same range) and spawns a watchdog goroutine that stops the server when ctx
// is cancelled, HandlerTimeout elapses, or Stop is called explicitly.
// Returns the bound port.
func (h *ProtocolHandler) Start(ctx context.Context) (int, error) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.running {
		return h.port, nil
	}

	// Drain any stale results from previous runs so WaitForCallback only
	// ever sees a callback produced by this session.
	select {
	case <-h.resultChan:
	default:
	}

	// Reset stopChan for reuse - close old channel first to unblock any waiting goroutines
	if h.stopChan != nil {
		select {
		case <-h.stopChan:
			// Already closed
		default:
			close(h.stopChan)
		}
	}
	h.stopChan = make(chan struct{})

	// Try ports in known range (must match handler script port range)
	var listener net.Listener
	var err error
	portRange := []int{DefaultHandlerPort, DefaultHandlerPort + 1, DefaultHandlerPort + 2, DefaultHandlerPort + 3, DefaultHandlerPort + 4}

	for _, port := range portRange {
		listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
		if err == nil {
			break
		}
		log.Debugf("kiro protocol handler: port %d busy, trying next", port)
	}

	if listener == nil {
		return 0, fmt.Errorf("failed to start callback server: all ports %d-%d are busy", DefaultHandlerPort, DefaultHandlerPort+4)
	}

	h.listener = listener
	h.port = listener.Addr().(*net.TCPAddr).Port

	mux := http.NewServeMux()
	mux.HandleFunc("/oauth/callback", h.handleCallback)

	h.server = &http.Server{
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Serve in the background; ErrServerClosed is the expected shutdown signal.
	go func() {
		if err := h.server.Serve(listener); err != nil && err != http.ErrServerClosed {
			log.Debugf("kiro protocol handler server error: %v", err)
		}
	}()

	h.running = true
	log.Debugf("kiro protocol handler started on port %d", h.port)

	// Auto-shutdown after context done, timeout, or explicit stop
	// Capture references to prevent race with new Start() calls
	currentStopChan := h.stopChan
	currentServer := h.server
	currentListener := h.listener
	go func() {
		select {
		case <-ctx.Done():
		case <-time.After(HandlerTimeout):
		case <-currentStopChan:
			return // Already stopped, exit goroutine
		}
		// Only stop if this is still the current server/listener instance
		// (a later Start may have replaced them after we were spawned).
		h.mu.Lock()
		if h.server == currentServer && h.listener == currentListener {
			h.mu.Unlock()
			h.Stop()
		} else {
			h.mu.Unlock()
		}
	}()

	return h.port, nil
}
|
||||||
|
|
||||||
|
// Stop stops the callback server. It is safe to call multiple times and is a
// no-op when the server is not running. Shutdown is given a 5-second grace
// period for in-flight requests.
func (h *ProtocolHandler) Stop() {
	h.mu.Lock()
	defer h.mu.Unlock()

	if !h.running {
		return
	}

	// Signal the auto-shutdown goroutine to exit.
	// This select pattern is safe because stopChan is only modified while holding h.mu,
	// and we hold the lock here. The select prevents panic from double-close.
	select {
	case <-h.stopChan:
		// Already closed
	default:
		close(h.stopChan)
	}

	if h.server != nil {
		// Graceful shutdown; the error is ignored because we are tearing
		// down regardless and the context bounds the wait.
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		_ = h.server.Shutdown(ctx)
	}

	h.running = false
	log.Debug("kiro protocol handler stopped")
}
|
||||||
|
|
||||||
|
// WaitForCallback waits for the OAuth callback and returns the result.
|
||||||
|
func (h *ProtocolHandler) WaitForCallback(ctx context.Context) (*AuthCallback, error) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(HandlerTimeout):
|
||||||
|
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||||
|
case result := <-h.resultChan:
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPort returns the port the handler is listening on.
|
||||||
|
func (h *ProtocolHandler) GetPort() int {
|
||||||
|
return h.port
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleCallback processes the OAuth callback from the protocol handler script.
// It forwards the query parameters to resultChan with a non-blocking send (so
// duplicate callbacks after the first are dropped) and renders a small HTML
// page telling the user whether login succeeded.
func (h *ProtocolHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
	code := r.URL.Query().Get("code")
	state := r.URL.Query().Get("state")
	errParam := r.URL.Query().Get("error")

	result := &AuthCallback{
		Code:  code,
		State: state,
		Error: errParam,
	}

	// Send result
	select {
	case h.resultChan <- result:
	default:
		// Channel full, ignore duplicate callbacks
	}

	// Send success response
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	if errParam != "" {
		w.WriteHeader(http.StatusBadRequest)
		// Escape the provider-supplied error string before embedding it in HTML.
		fmt.Fprintf(w, `<!DOCTYPE html>
<html>
<head><title>Login Failed</title></head>
<body>
<h1>Login Failed</h1>
<p>Error: %s</p>
<p>You can close this window.</p>
</body>
</html>`, html.EscapeString(errParam))
	} else {
		fmt.Fprint(w, `<!DOCTYPE html>
<html>
<head><title>Login Successful</title></head>
<body>
<h1>Login Successful!</h1>
<p>You can close this window and return to the terminal.</p>
<script>window.close();</script>
</body>
</html>`)
	}
}
|
||||||
|
|
||||||
|
// IsProtocolHandlerInstalled checks if the kiro:// protocol handler is installed.
|
||||||
|
func IsProtocolHandlerInstalled() bool {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux":
|
||||||
|
return isLinuxHandlerInstalled()
|
||||||
|
case "windows":
|
||||||
|
return isWindowsHandlerInstalled()
|
||||||
|
case "darwin":
|
||||||
|
return isDarwinHandlerInstalled()
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InstallProtocolHandler installs the kiro:// protocol handler for the current platform.
|
||||||
|
func InstallProtocolHandler(handlerPort int) error {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux":
|
||||||
|
return installLinuxHandler(handlerPort)
|
||||||
|
case "windows":
|
||||||
|
return installWindowsHandler(handlerPort)
|
||||||
|
case "darwin":
|
||||||
|
return installDarwinHandler(handlerPort)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UninstallProtocolHandler removes the kiro:// protocol handler.
|
||||||
|
func UninstallProtocolHandler() error {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux":
|
||||||
|
return uninstallLinuxHandler()
|
||||||
|
case "windows":
|
||||||
|
return uninstallWindowsHandler()
|
||||||
|
case "darwin":
|
||||||
|
return uninstallDarwinHandler()
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Linux Implementation ---
|
||||||
|
|
||||||
|
func getLinuxDesktopPath() string {
|
||||||
|
homeDir, _ := os.UserHomeDir()
|
||||||
|
return filepath.Join(homeDir, ".local", "share", "applications", "kiro-oauth-handler.desktop")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLinuxHandlerScriptPath() string {
|
||||||
|
homeDir, _ := os.UserHomeDir()
|
||||||
|
return filepath.Join(homeDir, ".local", "bin", "kiro-oauth-handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLinuxHandlerInstalled() bool {
|
||||||
|
desktopPath := getLinuxDesktopPath()
|
||||||
|
_, err := os.Stat(desktopPath)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// installLinuxHandler installs the kiro:// handler on Linux by writing a
// shell script to ~/.local/bin, a .desktop entry to
// ~/.local/share/applications, and registering the scheme with xdg-mime.
// handlerPort seeds the 5-port range the script probes — it must match the
// range used by ProtocolHandler.Start.
func installLinuxHandler(handlerPort int) error {
	// Create directories
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return err
	}

	binDir := filepath.Join(homeDir, ".local", "bin")
	appDir := filepath.Join(homeDir, ".local", "share", "applications")

	if err := os.MkdirAll(binDir, 0755); err != nil {
		return fmt.Errorf("failed to create bin directory: %w", err)
	}
	if err := os.MkdirAll(appDir, 0755); err != nil {
		return fmt.Errorf("failed to create applications directory: %w", err)
	}

	// Create handler script - tries multiple ports to handle dynamic port allocation
	scriptPath := getLinuxHandlerScriptPath()
	scriptContent := fmt.Sprintf(`#!/bin/bash
# Kiro OAuth Protocol Handler
# Handles kiro:// URIs - tries CLI first, then forwards to Kiro IDE

URL="$1"

# Check curl availability
if ! command -v curl &> /dev/null; then
    echo "Error: curl is required for Kiro OAuth handler" >&2
    exit 1
fi

# Extract code and state from URL
[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}"
[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}"
[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}"

# Try CLI proxy on multiple possible ports (default + dynamic range)
CLI_OK=0
for PORT in %d %d %d %d %d; do
    if [ -n "$ERROR" ]; then
        curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && CLI_OK=1 && break
    elif [ -n "$CODE" ] && [ -n "$STATE" ]; then
        curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && CLI_OK=1 && break
    fi
done

# If CLI not available, forward to Kiro IDE
if [ $CLI_OK -eq 0 ] && [ -x "/usr/share/kiro/kiro" ]; then
    /usr/share/kiro/kiro --open-url "$URL" &
fi
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)

	// 0755: the desktop entry must be able to execute this script.
	if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
		return fmt.Errorf("failed to write handler script: %w", err)
	}

	// Create .desktop file
	desktopPath := getLinuxDesktopPath()
	desktopContent := fmt.Sprintf(`[Desktop Entry]
Name=Kiro OAuth Handler
Comment=Handle kiro:// protocol for CLI Proxy API authentication
Exec=%s %%u
Type=Application
Terminal=false
NoDisplay=true
MimeType=x-scheme-handler/kiro;
Categories=Utility;
`, scriptPath)

	if err := os.WriteFile(desktopPath, []byte(desktopContent), 0644); err != nil {
		return fmt.Errorf("failed to write desktop file: %w", err)
	}

	// Register handler with xdg-mime
	cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
	if err := cmd.Run(); err != nil {
		// Non-fatal: user may register the handler manually.
		log.Warnf("xdg-mime registration failed (may need manual setup): %v", err)
	}

	// Update desktop database
	cmd = exec.Command("update-desktop-database", appDir)
	_ = cmd.Run() // Ignore errors, not critical

	log.Info("Kiro protocol handler installed for Linux")
	return nil
}
|
||||||
|
|
||||||
|
func uninstallLinuxHandler() error {
|
||||||
|
desktopPath := getLinuxDesktopPath()
|
||||||
|
scriptPath := getLinuxHandlerScriptPath()
|
||||||
|
|
||||||
|
if err := os.Remove(desktopPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("failed to remove desktop file: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.Remove(scriptPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("failed to remove handler script: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Kiro protocol handler uninstalled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Windows Implementation ---
|
||||||
|
|
||||||
|
// isWindowsHandlerInstalled reports whether the kiro protocol class key
// exists under HKCU by probing it with reg.exe.
func isWindowsHandlerInstalled() bool {
	err := exec.Command("reg", "query", `HKCU\Software\Classes\kiro`, "/ve").Run()
	return err == nil
}
|
||||||
|
|
||||||
|
// installWindowsHandler installs the kiro:// handler on Windows: it writes a
// PowerShell callback script plus a .bat wrapper under ~/.cliproxyapi and
// registers the scheme in the per-user registry (HKCU\Software\Classes\kiro).
// handlerPort seeds the 5-port range the script probes — it must match the
// range used by ProtocolHandler.Start.
func installWindowsHandler(handlerPort int) error {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return err
	}

	// Create handler script (PowerShell)
	scriptDir := filepath.Join(homeDir, ".cliproxyapi")
	if err := os.MkdirAll(scriptDir, 0755); err != nil {
		return fmt.Errorf("failed to create script directory: %w", err)
	}

	scriptPath := filepath.Join(scriptDir, "kiro-oauth-handler.ps1")
	scriptContent := fmt.Sprintf(`# Kiro OAuth Protocol Handler for Windows
param([string]$url)

# Load required assembly for HttpUtility
Add-Type -AssemblyName System.Web

# Parse URL parameters
$uri = [System.Uri]$url
$query = [System.Web.HttpUtility]::ParseQueryString($uri.Query)
$code = $query["code"]
$state = $query["state"]
$errorParam = $query["error"]

# Try multiple ports (default + dynamic range)
$ports = @(%d, %d, %d, %d, %d)
$success = $false

foreach ($port in $ports) {
    if ($success) { break }
    $callbackUrl = "http://127.0.0.1:$port/oauth/callback"
    try {
        if ($errorParam) {
            $fullUrl = $callbackUrl + "?error=" + $errorParam
            Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null
            $success = $true
        } elseif ($code -and $state) {
            $fullUrl = $callbackUrl + "?code=" + $code + "&state=" + $state
            Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null
            $success = $true
        }
    } catch {
        # Try next port
    }
}
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)

	if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil {
		return fmt.Errorf("failed to write handler script: %w", err)
	}

	// Create batch wrapper
	// (browsers invoke the registered command directly; the wrapper avoids
	// a visible PowerShell execution-policy prompt)
	batchPath := filepath.Join(scriptDir, "kiro-oauth-handler.bat")
	batchContent := fmt.Sprintf("@echo off\npowershell -ExecutionPolicy Bypass -File \"%s\" \"%%1\"\n", scriptPath)

	if err := os.WriteFile(batchPath, []byte(batchContent), 0644); err != nil {
		return fmt.Errorf("failed to write batch wrapper: %w", err)
	}

	// Register in Windows registry
	commands := [][]string{
		{"reg", "add", `HKCU\Software\Classes\kiro`, "/ve", "/d", "URL:Kiro Protocol", "/f"},
		{"reg", "add", `HKCU\Software\Classes\kiro`, "/v", "URL Protocol", "/d", "", "/f"},
		{"reg", "add", `HKCU\Software\Classes\kiro\shell`, "/f"},
		{"reg", "add", `HKCU\Software\Classes\kiro\shell\open`, "/f"},
		{"reg", "add", `HKCU\Software\Classes\kiro\shell\open\command`, "/ve", "/d", fmt.Sprintf("\"%s\" \"%%1\"", batchPath), "/f"},
	}

	for _, args := range commands {
		cmd := exec.Command(args[0], args[1:]...)
		if err := cmd.Run(); err != nil {
			return fmt.Errorf("failed to run registry command: %w", err)
		}
	}

	log.Info("Kiro protocol handler installed for Windows")
	return nil
}
|
||||||
|
|
||||||
|
func uninstallWindowsHandler() error {
|
||||||
|
// Remove registry keys
|
||||||
|
cmd := exec.Command("reg", "delete", `HKCU\Software\Classes\kiro`, "/f")
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
log.Warnf("failed to remove registry key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove scripts
|
||||||
|
homeDir, _ := os.UserHomeDir()
|
||||||
|
scriptDir := filepath.Join(homeDir, ".cliproxyapi")
|
||||||
|
_ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.ps1"))
|
||||||
|
_ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.bat"))
|
||||||
|
|
||||||
|
log.Info("Kiro protocol handler uninstalled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- macOS Implementation ---
|
||||||
|
|
||||||
|
func getDarwinAppPath() string {
|
||||||
|
homeDir, _ := os.UserHomeDir()
|
||||||
|
return filepath.Join(homeDir, "Applications", "KiroOAuthHandler.app")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDarwinHandlerInstalled() bool {
|
||||||
|
appPath := getDarwinAppPath()
|
||||||
|
_, err := os.Stat(appPath)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// installDarwinHandler installs the kiro:// handler on macOS by creating a
// minimal background-only app bundle (Info.plist declaring the "kiro" URL
// scheme plus a bash executable) and registering it with Launch Services.
// handlerPort seeds the 5-port range the script probes — it must match the
// range used by ProtocolHandler.Start.
func installDarwinHandler(handlerPort int) error {
	// Create app bundle structure
	appPath := getDarwinAppPath()
	contentsPath := filepath.Join(appPath, "Contents")
	macOSPath := filepath.Join(contentsPath, "MacOS")

	if err := os.MkdirAll(macOSPath, 0755); err != nil {
		return fmt.Errorf("failed to create app bundle: %w", err)
	}

	// Create Info.plist
	plistPath := filepath.Join(contentsPath, "Info.plist")
	plistContent := `<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
    <key>CFBundleIdentifier</key>
    <string>com.cliproxyapi.kiro-oauth-handler</string>
    <key>CFBundleName</key>
    <string>KiroOAuthHandler</string>
    <key>CFBundleExecutable</key>
    <string>kiro-oauth-handler</string>
    <key>CFBundleVersion</key>
    <string>1.0</string>
    <key>CFBundleURLTypes</key>
    <array>
        <dict>
            <key>CFBundleURLName</key>
            <string>Kiro Protocol</string>
            <key>CFBundleURLSchemes</key>
            <array>
                <string>kiro</string>
            </array>
        </dict>
    </array>
    <key>LSBackgroundOnly</key>
    <true/>
</dict>
</plist>`

	if err := os.WriteFile(plistPath, []byte(plistContent), 0644); err != nil {
		return fmt.Errorf("failed to write Info.plist: %w", err)
	}

	// Create executable script - tries multiple ports to handle dynamic port allocation
	execPath := filepath.Join(macOSPath, "kiro-oauth-handler")
	execContent := fmt.Sprintf(`#!/bin/bash
# Kiro OAuth Protocol Handler for macOS

URL="$1"

# Check curl availability (should always exist on macOS)
if [ ! -x /usr/bin/curl ]; then
    echo "Error: curl is required for Kiro OAuth handler" >&2
    exit 1
fi

# Extract code and state from URL
[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}"
[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}"
[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}"

# Try multiple ports (default + dynamic range)
for PORT in %d %d %d %d %d; do
    if [ -n "$ERROR" ]; then
        /usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && exit 0
    elif [ -n "$CODE" ] && [ -n "$STATE" ]; then
        /usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && exit 0
    fi
done
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)

	// 0755: Launch Services executes this file directly.
	if err := os.WriteFile(execPath, []byte(execContent), 0755); err != nil {
		return fmt.Errorf("failed to write executable: %w", err)
	}

	// Register the app with Launch Services
	cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister",
		"-f", appPath)
	if err := cmd.Run(); err != nil {
		// Non-fatal: Launch Services also discovers bundles under ~/Applications.
		log.Warnf("lsregister failed (handler may still work): %v", err)
	}

	log.Info("Kiro protocol handler installed for macOS")
	return nil
}
|
||||||
|
|
||||||
|
func uninstallDarwinHandler() error {
|
||||||
|
appPath := getDarwinAppPath()
|
||||||
|
|
||||||
|
// Unregister from Launch Services
|
||||||
|
cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister",
|
||||||
|
"-u", appPath)
|
||||||
|
_ = cmd.Run()
|
||||||
|
|
||||||
|
// Remove app bundle
|
||||||
|
if err := os.RemoveAll(appPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("failed to remove app bundle: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Kiro protocol handler uninstalled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseKiroURI parses a kiro:// URI and extracts the callback parameters.
|
||||||
|
func ParseKiroURI(rawURI string) (*AuthCallback, error) {
|
||||||
|
u, err := url.Parse(rawURI)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid URI: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Scheme != KiroProtocol {
|
||||||
|
return nil, fmt.Errorf("invalid scheme: expected %s, got %s", KiroProtocol, u.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Host != KiroAuthority {
|
||||||
|
return nil, fmt.Errorf("invalid authority: expected %s, got %s", KiroAuthority, u.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
query := u.Query()
|
||||||
|
return &AuthCallback{
|
||||||
|
Code: query.Get("code"),
|
||||||
|
State: query.Get("state"),
|
||||||
|
Error: query.Get("error"),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHandlerInstructions returns platform-specific instructions for manual handler setup.
// The text is shown to the user when automatic installation fails; it mirrors
// what the install* functions would have done for the current GOOS.
func GetHandlerInstructions() string {
	switch runtime.GOOS {
	case "linux":
		return `To manually set up the Kiro protocol handler on Linux:

1. Create ~/.local/share/applications/kiro-oauth-handler.desktop:
   [Desktop Entry]
   Name=Kiro OAuth Handler
   Exec=~/.local/bin/kiro-oauth-handler %u
   Type=Application
   Terminal=false
   MimeType=x-scheme-handler/kiro;

2. Create ~/.local/bin/kiro-oauth-handler (make it executable):
   #!/bin/bash
   URL="$1"
   # ... (see generated script for full content)

3. Run: xdg-mime default kiro-oauth-handler.desktop x-scheme-handler/kiro`

	case "windows":
		return `To manually set up the Kiro protocol handler on Windows:

1. Open Registry Editor (regedit.exe)
2. Create key: HKEY_CURRENT_USER\Software\Classes\kiro
3. Set default value to: URL:Kiro Protocol
4. Create string value "URL Protocol" with empty data
5. Create subkey: shell\open\command
6. Set default value to: "C:\path\to\handler.bat" "%1"`

	case "darwin":
		return `To manually set up the Kiro protocol handler on macOS:

1. Create ~/Applications/KiroOAuthHandler.app bundle
2. Add Info.plist with CFBundleURLTypes containing "kiro" scheme
3. Create executable in Contents/MacOS/
4. Run: /System/Library/.../lsregister -f ~/Applications/KiroOAuthHandler.app`

	default:
		return "Protocol handler setup is not supported on this platform."
	}
}
|
||||||
|
|
||||||
|
// SetupProtocolHandlerIfNeeded checks and installs the protocol handler if needed.
// When the handler is already registered it returns immediately; otherwise it
// attempts automatic installation and, on failure, prints the platform's
// manual setup instructions before returning the install error.
func SetupProtocolHandlerIfNeeded(handlerPort int) error {
	if IsProtocolHandlerInstalled() {
		log.Debug("Kiro protocol handler already installed")
		return nil
	}

	fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
	fmt.Println("║           Kiro Protocol Handler Setup Required           ║")
	fmt.Println("╚══════════════════════════════════════════════════════════╝")
	fmt.Println("\nTo enable Google/GitHub login, we need to install a protocol handler.")
	fmt.Println("This allows your browser to redirect back to the CLI after authentication.")
	fmt.Println("\nInstalling protocol handler...")

	if err := InstallProtocolHandler(handlerPort); err != nil {
		fmt.Printf("\n⚠ Automatic installation failed: %v\n", err)
		fmt.Println("\nManual setup instructions:")
		fmt.Println(strings.Repeat("-", 60))
		fmt.Println(GetHandlerInstructions())
		return err
	}

	fmt.Println("\n✓ Protocol handler installed successfully!")
	return nil
}
|
||||||
403
internal/auth/kiro/social_auth.go
Normal file
403
internal/auth/kiro/social_auth.go
Normal file
@@ -0,0 +1,403 @@
|
|||||||
|
// Package kiro provides social authentication (Google/GitHub) for Kiro via AuthServiceClient.
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// kiroAuthServiceEndpoint is the Kiro AuthService base URL used for the
	// hosted login page (/login), token exchange (/oauth/token), and token
	// refresh (/refreshToken).
	kiroAuthServiceEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"

	// socialAuthTimeout bounds how long the interactive OAuth flow may take.
	socialAuthTimeout = 10 * time.Minute
)
|
||||||
|
|
||||||
|
// SocialProvider represents the social login provider. The value is passed
// verbatim as the idp query parameter of the hosted login URL.
type SocialProvider string

const (
	// ProviderGoogle is Google OAuth provider
	ProviderGoogle SocialProvider = "Google"
	// ProviderGitHub is GitHub OAuth provider
	ProviderGitHub SocialProvider = "Github"
	// Note: AWS Builder ID is NOT supported by Kiro's auth service.
	// It only supports: Google, Github, Cognito
	// AWS Builder ID must use device code flow via SSO OIDC.
)
|
||||||
|
|
||||||
|
// CreateTokenRequest is sent to Kiro's /oauth/token endpoint to exchange an
// authorization code for tokens.
type CreateTokenRequest struct {
	Code           string `json:"code"`          // authorization code from the OAuth callback
	CodeVerifier   string `json:"code_verifier"` // PKCE verifier matching the challenge sent at login
	RedirectURI    string `json:"redirect_uri"`  // redirect URI used in the original login request
	InvitationCode string `json:"invitation_code,omitempty"` // optional; omitted when empty
}
|
||||||
|
|
||||||
|
// SocialTokenResponse from Kiro's /oauth/token endpoint for social auth.
// The same shape is returned by /refreshToken.
type SocialTokenResponse struct {
	AccessToken  string `json:"accessToken"`  // bearer token for subsequent API calls
	RefreshToken string `json:"refreshToken"` // token used to obtain new access tokens
	ProfileArn   string `json:"profileArn"`   // profile ARN returned by the service
	ExpiresIn    int    `json:"expiresIn"`    // access-token lifetime in seconds
}
|
||||||
|
|
||||||
|
// RefreshTokenRequest is sent to Kiro's /refreshToken endpoint.
type RefreshTokenRequest struct {
	RefreshToken string `json:"refreshToken"` // refresh token from a previous token response
}
|
||||||
|
|
||||||
|
// SocialAuthClient handles social authentication with Kiro.
type SocialAuthClient struct {
	httpClient      *http.Client     // HTTP client, optionally proxied via cfg
	cfg             *config.Config   // runtime configuration (may be nil)
	protocolHandler *ProtocolHandler // loopback server that receives kiro:// callbacks
}
|
||||||
|
|
||||||
|
// NewSocialAuthClient creates a new social auth client.
|
||||||
|
func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
|
||||||
|
client := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if cfg != nil {
|
||||||
|
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||||
|
}
|
||||||
|
return &SocialAuthClient{
|
||||||
|
httpClient: client,
|
||||||
|
cfg: cfg,
|
||||||
|
protocolHandler: NewProtocolHandler(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generatePKCE generates PKCE code verifier and challenge.
|
||||||
|
func generatePKCE() (verifier, challenge string, err error) {
|
||||||
|
// Generate 32 bytes of random data for verifier
|
||||||
|
b := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||||
|
}
|
||||||
|
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||||
|
|
||||||
|
// Generate SHA256 hash of verifier for challenge
|
||||||
|
h := sha256.Sum256([]byte(verifier))
|
||||||
|
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||||
|
|
||||||
|
return verifier, challenge, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateState generates a random state parameter.
|
||||||
|
func generateStateParam() (string, error) {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildLoginURL constructs the Kiro OAuth login URL.
|
||||||
|
// The login endpoint expects a GET request with query parameters.
|
||||||
|
// Format: /login?idp=Google&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...&prompt=select_account
|
||||||
|
// The prompt=select_account parameter forces the account selection screen even if already logged in.
|
||||||
|
func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, state string) string {
|
||||||
|
return fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
|
||||||
|
kiroAuthServiceEndpoint,
|
||||||
|
provider,
|
||||||
|
url.QueryEscape(redirectURI),
|
||||||
|
codeChallenge,
|
||||||
|
state,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// createToken exchanges the authorization code for tokens by POSTing the
// request as JSON to the AuthService /oauth/token endpoint.
// Returns an error for any non-200 status; the response body is logged at
// debug level only so tokens/details are not leaked into normal logs.
func (c *SocialAuthClient) createToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) {
	body, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal token request: %w", err)
	}

	tokenURL := kiroAuthServiceEndpoint + "/oauth/token"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body)))
	if err != nil {
		return nil, fmt.Errorf("failed to create token request: %w", err)
	}

	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0")

	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("token request failed: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read token response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody))
		return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode)
	}

	var tokenResp SocialTokenResponse
	if err := json.Unmarshal(respBody, &tokenResp); err != nil {
		return nil, fmt.Errorf("failed to parse token response: %w", err)
	}

	return &tokenResp, nil
}
|
||||||
|
|
||||||
|
// RefreshSocialToken refreshes an expired social auth token.
|
||||||
|
func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
|
||||||
|
body, err := json.Marshal(&RefreshTokenRequest{RefreshToken: refreshToken})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal refresh request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshURL := kiroAuthServiceEndpoint + "/refreshToken"
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("refresh request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp SocialTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate ExpiresIn - use default 1 hour if invalid
|
||||||
|
expiresIn := tokenResp.ExpiresIn
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
expiresIn = 3600 // Default 1 hour
|
||||||
|
}
|
||||||
|
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ProfileArn: tokenResp.ProfileArn,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "social",
|
||||||
|
Provider: "", // Caller should preserve original provider
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithSocial performs interactive OAuth login with the given social
// identity provider (Google, GitHub, ...) using PKCE.
//
// The flow: start a local callback server, ensure the kiro:// protocol
// handler is registered, open the browser at the hosted login URL, wait
// for the OAuth callback, validate state, then exchange the code for
// tokens. Blocks until the callback arrives or ctx is cancelled.
func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) {
	providerName := string(provider)

	fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
	fmt.Printf("║          Kiro Authentication (%s)                ║\n", providerName)
	fmt.Println("╚══════════════════════════════════════════════════════════╝")

	// Step 1: Setup protocol handler
	fmt.Println("\nSetting up authentication...")

	// Start the local callback server; stopped when this function returns.
	handlerPort, err := c.protocolHandler.Start(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to start callback server: %w", err)
	}
	defer c.protocolHandler.Stop()

	// Ensure protocol handler is installed and set as default
	if err := SetupProtocolHandlerIfNeeded(handlerPort); err != nil {
		fmt.Println("\n⚠ Protocol handler setup failed. Trying alternative method...")
		fmt.Println("  If you see a browser 'Open with' dialog, select your default browser.")
		fmt.Println("  For manual setup instructions, run: cliproxy kiro --help-protocol")
		log.Debugf("kiro: protocol handler setup error: %v", err)
		// Continue anyway - user might have set it up manually or select browser manually
	} else {
		// Force set our handler as default (prevents "Open with" dialog)
		forceDefaultProtocolHandler()
	}

	// Step 2: Generate PKCE codes (verifier kept locally, challenge sent)
	codeVerifier, codeChallenge, err := generatePKCE()
	if err != nil {
		return nil, fmt.Errorf("failed to generate PKCE: %w", err)
	}

	// Step 3: Generate state (CSRF protection, checked on callback below)
	state, err := generateStateParam()
	if err != nil {
		return nil, fmt.Errorf("failed to generate state: %w", err)
	}

	// Step 4: Build the login URL (Kiro uses GET request with query params)
	authURL := c.buildLoginURL(providerName, KiroRedirectURI, codeChallenge, state)

	// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
	// Incognito mode enables multi-account support by bypassing cached sessions
	if c.cfg != nil {
		browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
		if !c.cfg.IncognitoBrowser {
			log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
		} else {
			log.Debug("kiro: using incognito mode for multi-account support")
		}
	} else {
		browser.SetIncognitoMode(true) // Default to incognito if no config
		log.Debug("kiro: using incognito mode for multi-account support (default)")
	}

	// Step 5: Open browser for user authentication
	fmt.Println("\n════════════════════════════════════════════════════════════")
	fmt.Printf("  Opening browser for %s authentication...\n", providerName)
	fmt.Println("════════════════════════════════════════════════════════════")
	fmt.Printf("\n  URL: %s\n\n", authURL)

	// Browser failure is not fatal: the URL is printed above so the user
	// can open it manually.
	if err := browser.OpenURL(authURL); err != nil {
		log.Warnf("Could not open browser automatically: %v", err)
		fmt.Println("  ⚠ Could not open browser automatically.")
		fmt.Println("  Please open the URL above in your browser manually.")
	} else {
		fmt.Println("  (Browser opened automatically)")
	}

	fmt.Println("\n  Waiting for authentication callback...")

	// Step 6: Wait for callback (blocks until callback or ctx cancellation)
	callback, err := c.protocolHandler.WaitForCallback(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to receive callback: %w", err)
	}

	if callback.Error != "" {
		return nil, fmt.Errorf("authentication error: %s", callback.Error)
	}

	// Reject mismatched state (possible CSRF or stale callback).
	if callback.State != state {
		// Log state values for debugging, but don't expose in user-facing error
		log.Debugf("kiro: OAuth state mismatch - expected %s, got %s", state, callback.State)
		return nil, fmt.Errorf("OAuth state validation failed - please try again")
	}

	if callback.Code == "" {
		return nil, fmt.Errorf("no authorization code received")
	}

	fmt.Println("\n✓ Authorization received!")

	// Step 7: Exchange code for tokens
	fmt.Println("Exchanging code for tokens...")

	tokenReq := &CreateTokenRequest{
		Code:         callback.Code,
		CodeVerifier: codeVerifier,
		RedirectURI:  KiroRedirectURI,
	}

	tokenResp, err := c.createToken(ctx, tokenReq)
	if err != nil {
		return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
	}

	fmt.Println("\n✓ Authentication successful!")

	// Close the browser window (best effort; failure only logged)
	if err := browser.CloseBrowser(); err != nil {
		log.Debugf("Failed to close browser: %v", err)
	}

	// Validate ExpiresIn - use default 1 hour if invalid
	expiresIn := tokenResp.ExpiresIn
	if expiresIn <= 0 {
		expiresIn = 3600
	}
	expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)

	// Try to extract email from JWT access token first
	email := ExtractEmailFromJWT(tokenResp.AccessToken)

	// If no email in JWT, ask user for account label (only in interactive mode)
	if email == "" && isInteractiveTerminal() {
		fmt.Print("\n  Enter account label for file naming (optional, press Enter to skip): ")
		reader := bufio.NewReader(os.Stdin)
		var err error
		email, err = reader.ReadString('\n')
		if err != nil {
			// Read failure leaves email as whatever was read (possibly "").
			log.Debugf("Failed to read account label: %v", err)
		}
		email = strings.TrimSpace(email)
	}

	return &KiroTokenData{
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: tokenResp.RefreshToken,
		ProfileArn:   tokenResp.ProfileArn,
		ExpiresAt:    expiresAt.Format(time.RFC3339),
		AuthMethod:   "social",
		Provider:     providerName,
		Email:        email, // JWT email or user-provided label
	}, nil
}
|
||||||
|
|
||||||
|
// LoginWithGoogle performs OAuth login with Google. It is a convenience
// wrapper around LoginWithSocial with the Google provider.
func (c *SocialAuthClient) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
	return c.LoginWithSocial(ctx, ProviderGoogle)
}
|
||||||
|
|
||||||
|
// LoginWithGitHub performs OAuth login with GitHub. It is a convenience
// wrapper around LoginWithSocial with the GitHub provider.
func (c *SocialAuthClient) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) {
	return c.LoginWithSocial(ctx, ProviderGitHub)
}
|
||||||
|
|
||||||
|
// forceDefaultProtocolHandler sets our protocol handler as the default for kiro:// URLs.
|
||||||
|
// This prevents the "Open with" dialog from appearing on Linux.
|
||||||
|
// On non-Linux platforms, this is a no-op as they use different mechanisms.
|
||||||
|
func forceDefaultProtocolHandler() {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return // Non-Linux platforms use different handler mechanisms
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set our handler as default using xdg-mime
|
||||||
|
cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
log.Warnf("Failed to set default protocol handler: %v. You may see a handler selection dialog.", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isInteractiveTerminal checks if stdin is connected to an interactive terminal.
|
||||||
|
// Returns false in CI/automated environments or when stdin is piped.
|
||||||
|
func isInteractiveTerminal() bool {
|
||||||
|
return term.IsTerminal(int(os.Stdin.Fd()))
|
||||||
|
}
|
||||||
527
internal/auth/kiro/sso_oidc.go
Normal file
527
internal/auth/kiro/sso_oidc.go
Normal file
@@ -0,0 +1,527 @@
|
|||||||
|
// Package kiro provides AWS SSO OIDC authentication for Kiro.
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// ssoOIDCEndpoint is the AWS SSO OIDC service endpoint (us-east-1).
	ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com"

	// builderIDStartURL is Kiro's start URL for AWS Builder ID sign-in.
	builderIDStartURL = "https://view.awsapps.com/start"

	// pollInterval is the default delay between device-code token polls,
	// used when the server does not supply its own interval.
	pollInterval = 5 * time.Second
)
|
||||||
|
|
||||||
|
// SSOOIDCClient handles AWS SSO OIDC authentication (device code flow
// for AWS Builder ID).
type SSOOIDCClient struct {
	// httpClient is used for all requests to the SSO OIDC and
	// CodeWhisperer endpoints; may be proxy-aware (see NewSSOOIDCClient).
	httpClient *http.Client
	// cfg carries optional proxy/browser settings; may be nil.
	cfg *config.Config
}
|
||||||
|
|
||||||
|
// NewSSOOIDCClient creates a new SSO OIDC client.
|
||||||
|
func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient {
|
||||||
|
client := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if cfg != nil {
|
||||||
|
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||||
|
}
|
||||||
|
return &SSOOIDCClient{
|
||||||
|
httpClient: client,
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterClientResponse from AWS SSO OIDC RegisterClient.
type RegisterClientResponse struct {
	// ClientID identifies the registered OIDC client.
	ClientID string `json:"clientId"`
	// ClientSecret authenticates the registered client.
	ClientSecret string `json:"clientSecret"`
	// ClientIDIssuedAt is the issuance timestamp (presumably Unix epoch
	// seconds per the AWS API — confirm against the service docs).
	ClientIDIssuedAt int64 `json:"clientIdIssuedAt"`
	// ClientSecretExpiresAt is the secret's expiry timestamp (same
	// encoding caveat as ClientIDIssuedAt).
	ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"`
}
|
||||||
|
|
||||||
|
// StartDeviceAuthResponse from AWS SSO OIDC StartDeviceAuthorization.
type StartDeviceAuthResponse struct {
	// DeviceCode is used when polling CreateToken for the access token.
	DeviceCode string `json:"deviceCode"`
	// UserCode is the short code the user enters at VerificationURI.
	UserCode string `json:"userCode"`
	// VerificationURI is the page where the user enters UserCode.
	VerificationURI string `json:"verificationUri"`
	// VerificationURIComplete embeds the user code so it can be opened
	// directly in a browser.
	VerificationURIComplete string `json:"verificationUriComplete"`
	// ExpiresIn is the device code lifetime in seconds.
	ExpiresIn int `json:"expiresIn"`
	// Interval is the server-suggested polling interval in seconds.
	Interval int `json:"interval"`
}
|
||||||
|
|
||||||
|
// CreateTokenResponse from AWS SSO OIDC CreateToken.
type CreateTokenResponse struct {
	// AccessToken is the bearer token for API access.
	AccessToken string `json:"accessToken"`
	// TokenType is the token type (e.g. "Bearer" — confirm with service).
	TokenType string `json:"tokenType"`
	// ExpiresIn is the access token lifetime in seconds.
	ExpiresIn int `json:"expiresIn"`
	// RefreshToken is used to obtain new access tokens.
	RefreshToken string `json:"refreshToken"`
}
|
||||||
|
|
||||||
|
// RegisterClient registers a new OIDC client with AWS.
|
||||||
|
func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) {
|
||||||
|
// Generate unique client name for each registration to support multiple accounts
|
||||||
|
clientName := fmt.Sprintf("CLI-Proxy-API-%d", time.Now().UnixNano())
|
||||||
|
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"clientName": clientName,
|
||||||
|
"clientType": "public",
|
||||||
|
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations"},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result RegisterClientResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartDeviceAuthorization starts the device authorization flow.
|
||||||
|
func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID, clientSecret string) (*StartDeviceAuthResponse, error) {
|
||||||
|
payload := map[string]string{
|
||||||
|
"clientId": clientID,
|
||||||
|
"clientSecret": clientSecret,
|
||||||
|
"startUrl": builderIDStartURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/device_authorization", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result StartDeviceAuthResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateToken polls for the access token after user authorization using
// the device-code grant.
//
// While the user has not yet approved, the service replies 400 with
// "authorization_pending" (or "slow_down"); those are surfaced as errors
// whose messages LoginWithBuilderID matches with strings.Contains to keep
// polling — do NOT change these error strings without updating the caller.
func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, deviceCode string) (*CreateTokenResponse, error) {
	payload := map[string]string{
		"clientId":     clientID,
		"clientSecret": clientSecret,
		"deviceCode":   deviceCode,
		"grantType":    "urn:ietf:params:oauth:grant-type:device_code",
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body)))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	// Check for pending authorization
	if resp.StatusCode == http.StatusBadRequest {
		var errResp struct {
			Error string `json:"error"`
		}
		// Parse errors are deliberately ignored here: an unparseable 400
		// body falls through to the generic failure below.
		if json.Unmarshal(respBody, &errResp) == nil {
			if errResp.Error == "authorization_pending" {
				return nil, fmt.Errorf("authorization_pending")
			}
			if errResp.Error == "slow_down" {
				return nil, fmt.Errorf("slow_down")
			}
		}
		log.Debugf("create token failed: %s", string(respBody))
		return nil, fmt.Errorf("create token failed")
	}

	if resp.StatusCode != http.StatusOK {
		log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody))
		return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
	}

	var result CreateTokenResponse
	if err := json.Unmarshal(respBody, &result); err != nil {
		return nil, err
	}

	return &result, nil
}
|
||||||
|
|
||||||
|
// RefreshToken refreshes an access token using the refresh token.
|
||||||
|
func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (*KiroTokenData, error) {
|
||||||
|
payload := map[string]string{
|
||||||
|
"clientId": clientID,
|
||||||
|
"clientSecret": clientSecret,
|
||||||
|
"refreshToken": refreshToken,
|
||||||
|
"grantType": "refresh_token",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result CreateTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: result.AccessToken,
|
||||||
|
RefreshToken: result.RefreshToken,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "builder-id",
|
||||||
|
Provider: "AWS",
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithBuilderID performs the full device code flow for AWS Builder ID:
// register an OIDC client, start device authorization, open the browser at
// the verification URL, then poll CreateToken until the user approves, the
// code expires, or ctx is cancelled.
func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) {
	fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
	fmt.Println("║          Kiro Authentication (AWS Builder ID)            ║")
	fmt.Println("╚══════════════════════════════════════════════════════════╝")

	// Step 1: Register client (a fresh client per login, see RegisterClient)
	fmt.Println("\nRegistering client...")
	regResp, err := c.RegisterClient(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to register client: %w", err)
	}
	log.Debugf("Client registered: %s", regResp.ClientID)

	// Step 2: Start device authorization
	fmt.Println("Starting device authorization...")
	authResp, err := c.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
	if err != nil {
		return nil, fmt.Errorf("failed to start device auth: %w", err)
	}

	// Step 3: Show user the verification URL
	fmt.Printf("\n")
	fmt.Println("════════════════════════════════════════════════════════════")
	fmt.Printf("  Open this URL in your browser:\n")
	fmt.Printf("  %s\n", authResp.VerificationURIComplete)
	fmt.Println("════════════════════════════════════════════════════════════")
	fmt.Printf("\n  Or go to: %s\n", authResp.VerificationURI)
	fmt.Printf("  And enter code: %s\n\n", authResp.UserCode)

	// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
	// Incognito mode enables multi-account support by bypassing cached sessions
	if c.cfg != nil {
		browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
		if !c.cfg.IncognitoBrowser {
			log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
		} else {
			log.Debug("kiro: using incognito mode for multi-account support")
		}
	} else {
		browser.SetIncognitoMode(true) // Default to incognito if no config
		log.Debug("kiro: using incognito mode for multi-account support (default)")
	}

	// Open browser using cross-platform browser package; failure is not
	// fatal since the URL and code were printed above.
	if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil {
		log.Warnf("Could not open browser automatically: %v", err)
		fmt.Println("  Please open the URL manually in your browser.")
	} else {
		fmt.Println("  (Browser opened automatically)")
	}

	// Step 4: Poll for token
	fmt.Println("Waiting for authorization...")

	// Honor the server-suggested interval when provided; otherwise use the
	// package default.
	interval := pollInterval
	if authResp.Interval > 0 {
		interval = time.Duration(authResp.Interval) * time.Second
	}

	deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)

	for time.Now().Before(deadline) {
		select {
		case <-ctx.Done():
			browser.CloseBrowser() // Cleanup on cancel
			return nil, ctx.Err()
		case <-time.After(interval):
			tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
			if err != nil {
				// The error strings below are produced by CreateToken;
				// "continue" resumes the outer polling loop.
				errStr := err.Error()
				if strings.Contains(errStr, "authorization_pending") {
					fmt.Print(".")
					continue
				}
				if strings.Contains(errStr, "slow_down") {
					// Back off per the device-code protocol.
					interval += 5 * time.Second
					continue
				}
				// Close browser on error before returning
				browser.CloseBrowser()
				return nil, fmt.Errorf("token creation failed: %w", err)
			}

			fmt.Println("\n\n✓ Authorization successful!")

			// Close the browser window (best effort)
			if err := browser.CloseBrowser(); err != nil {
				log.Debugf("Failed to close browser: %v", err)
			}

			// Step 5: Get profile ARN from CodeWhisperer API (best effort,
			// may be "" — see fetchProfileArn)
			fmt.Println("Fetching profile information...")
			profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)

			// Extract email from JWT access token (may be "")
			email := ExtractEmailFromJWT(tokenResp.AccessToken)
			if email != "" {
				fmt.Printf("  Logged in as: %s\n", email)
			}

			expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)

			return &KiroTokenData{
				AccessToken:  tokenResp.AccessToken,
				RefreshToken: tokenResp.RefreshToken,
				ProfileArn:   profileArn,
				ExpiresAt:    expiresAt.Format(time.RFC3339),
				AuthMethod:   "builder-id",
				Provider:     "AWS",
				ClientID:     regResp.ClientID,
				ClientSecret: regResp.ClientSecret,
				Email:        email,
			}, nil
		}
	}

	// Close browser on timeout for better UX
	if err := browser.CloseBrowser(); err != nil {
		log.Debugf("Failed to close browser on timeout: %v", err)
	}
	return nil, fmt.Errorf("authorization timed out")
}
|
||||||
|
|
||||||
|
// fetchProfileArn retrieves the profile ARN from CodeWhisperer API.
|
||||||
|
// This is needed for file naming since AWS SSO OIDC doesn't return profile info.
|
||||||
|
func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string {
|
||||||
|
// Try ListProfiles API first
|
||||||
|
profileArn := c.tryListProfiles(ctx, accessToken)
|
||||||
|
if profileArn != "" {
|
||||||
|
return profileArn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: Try ListAvailableCustomizations
|
||||||
|
return c.tryListCustomizations(ctx, accessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) string {
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"origin": "AI_EDITOR",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||||
|
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListProfiles")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("ListProfiles failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("ListProfiles response: %s", string(respBody))
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Profiles []struct {
|
||||||
|
Arn string `json:"arn"`
|
||||||
|
} `json:"profiles"`
|
||||||
|
ProfileArn string `json:"profileArn"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ProfileArn != "" {
|
||||||
|
return result.ProfileArn
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Profiles) > 0 {
|
||||||
|
return result.Profiles[0].Arn
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken string) string {
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"origin": "AI_EDITOR",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||||
|
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListAvailableCustomizations")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("ListAvailableCustomizations failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("ListAvailableCustomizations response: %s", string(respBody))
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Customizations []struct {
|
||||||
|
Arn string `json:"arn"`
|
||||||
|
} `json:"customizations"`
|
||||||
|
ProfileArn string `json:"profileArn"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ProfileArn != "" {
|
||||||
|
return result.ProfileArn
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Customizations) > 0 {
|
||||||
|
return result.Customizations[0].Arn
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
72
internal/auth/kiro/token.go
Normal file
72
internal/auth/kiro/token.go
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KiroTokenStorage holds the persistent token data for Kiro
// authentication, serialized as JSON by SaveTokenToFile/LoadFromFile.
type KiroTokenStorage struct {
	// AccessToken is the OAuth2 access token for API access
	AccessToken string `json:"access_token"`
	// RefreshToken is used to obtain new access tokens
	RefreshToken string `json:"refresh_token"`
	// ProfileArn is the AWS CodeWhisperer profile ARN
	ProfileArn string `json:"profile_arn"`
	// ExpiresAt is the timestamp when the token expires (string-encoded;
	// presumably RFC3339 as produced by the auth flows — confirm)
	ExpiresAt string `json:"expires_at"`
	// AuthMethod indicates the authentication method used
	// (e.g. "social" or "builder-id")
	AuthMethod string `json:"auth_method"`
	// Provider indicates the OAuth provider
	Provider string `json:"provider"`
	// LastRefresh is the timestamp of the last token refresh
	LastRefresh string `json:"last_refresh"`
}
|
||||||
|
|
||||||
|
// SaveTokenToFile persists the token storage to the specified file path.
|
||||||
|
func (s *KiroTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||||
|
dir := filepath.Dir(authFilePath)
|
||||||
|
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||||
|
return fmt.Errorf("failed to create directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(s, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal token storage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(authFilePath, data, 0600); err != nil {
|
||||||
|
return fmt.Errorf("failed to write token file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadFromFile loads token storage from the specified file path.
|
||||||
|
func LoadFromFile(authFilePath string) (*KiroTokenStorage, error) {
|
||||||
|
data, err := os.ReadFile(authFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var storage KiroTokenStorage
|
||||||
|
if err := json.Unmarshal(data, &storage); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &storage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToTokenData converts storage to KiroTokenData for API use.
|
||||||
|
func (s *KiroTokenStorage) ToTokenData() *KiroTokenData {
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: s.AccessToken,
|
||||||
|
RefreshToken: s.RefreshToken,
|
||||||
|
ProfileArn: s.ProfileArn,
|
||||||
|
ExpiresAt: s.ExpiresAt,
|
||||||
|
AuthMethod: s.AuthMethod,
|
||||||
|
Provider: s.Provider,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,14 +6,49 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
pkgbrowser "github.com/pkg/browser"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/skratchdot/open-golang/open"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// incognitoMode controls whether to open URLs in incognito/private mode.
|
||||||
|
// This is useful for OAuth flows where you want to use a different account.
|
||||||
|
var incognitoMode bool
|
||||||
|
|
||||||
|
// lastBrowserProcess stores the last opened browser process for cleanup
|
||||||
|
var lastBrowserProcess *exec.Cmd
|
||||||
|
var browserMutex sync.Mutex
|
||||||
|
|
||||||
|
// SetIncognitoMode enables or disables incognito/private browsing mode.
|
||||||
|
func SetIncognitoMode(enabled bool) {
|
||||||
|
incognitoMode = enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsIncognitoMode returns whether incognito mode is enabled.
|
||||||
|
func IsIncognitoMode() bool {
|
||||||
|
return incognitoMode
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseBrowser closes the last opened browser process.
|
||||||
|
func CloseBrowser() error {
|
||||||
|
browserMutex.Lock()
|
||||||
|
defer browserMutex.Unlock()
|
||||||
|
|
||||||
|
if lastBrowserProcess == nil || lastBrowserProcess.Process == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := lastBrowserProcess.Process.Kill()
|
||||||
|
lastBrowserProcess = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// OpenURL opens the specified URL in the default web browser.
|
// OpenURL opens the specified URL in the default web browser.
|
||||||
// It first attempts to use a platform-agnostic library and falls back to
|
// It uses the pkg/browser library which provides robust cross-platform support
|
||||||
// platform-specific commands if that fails.
|
// for Windows, macOS, and Linux.
|
||||||
|
// If incognito mode is enabled, it will open in a private/incognito window.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - url: The URL to open.
|
// - url: The URL to open.
|
||||||
@@ -21,16 +56,22 @@ import (
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - An error if the URL cannot be opened, otherwise nil.
|
// - An error if the URL cannot be opened, otherwise nil.
|
||||||
func OpenURL(url string) error {
|
func OpenURL(url string) error {
|
||||||
fmt.Printf("Attempting to open URL in browser: %s\n", url)
|
log.Debugf("Opening URL in browser: %s (incognito=%v)", url, incognitoMode)
|
||||||
|
|
||||||
// Try using the open-golang library first
|
// If incognito mode is enabled, use platform-specific incognito commands
|
||||||
err := open.Run(url)
|
if incognitoMode {
|
||||||
|
log.Debug("Using incognito mode")
|
||||||
|
return openURLIncognito(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use pkg/browser for cross-platform support
|
||||||
|
err := pkgbrowser.OpenURL(url)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
log.Debug("Successfully opened URL using open-golang library")
|
log.Debug("Successfully opened URL using pkg/browser library")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("open-golang failed: %v, trying platform-specific commands", err)
|
log.Debugf("pkg/browser failed: %v, trying platform-specific commands", err)
|
||||||
|
|
||||||
// Fallback to platform-specific commands
|
// Fallback to platform-specific commands
|
||||||
return openURLPlatformSpecific(url)
|
return openURLPlatformSpecific(url)
|
||||||
@@ -78,18 +119,379 @@ func openURLPlatformSpecific(url string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// openURLIncognito opens a URL in incognito/private browsing mode.
|
||||||
|
// It first tries to detect the default browser and use its incognito flag.
|
||||||
|
// Falls back to a chain of known browsers if detection fails.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - url: The URL to open.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - An error if the URL cannot be opened, otherwise nil.
|
||||||
|
func openURLIncognito(url string) error {
|
||||||
|
// First, try to detect and use the default browser
|
||||||
|
if cmd := tryDefaultBrowserIncognito(url); cmd != nil {
|
||||||
|
log.Debugf("Using detected default browser: %s %v", cmd.Path, cmd.Args[1:])
|
||||||
|
if err := cmd.Start(); err == nil {
|
||||||
|
storeBrowserProcess(cmd)
|
||||||
|
log.Debug("Successfully opened URL in default browser's incognito mode")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Debugf("Failed to start default browser, trying fallback chain")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to known browser chain
|
||||||
|
cmd := tryFallbackBrowsersIncognito(url)
|
||||||
|
if cmd == nil {
|
||||||
|
log.Warn("No browser with incognito support found, falling back to normal mode")
|
||||||
|
return openURLPlatformSpecific(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Running incognito command: %s %v", cmd.Path, cmd.Args[1:])
|
||||||
|
err := cmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to open incognito browser: %v, falling back to normal mode", err)
|
||||||
|
return openURLPlatformSpecific(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
storeBrowserProcess(cmd)
|
||||||
|
log.Debug("Successfully opened URL in incognito/private mode")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// storeBrowserProcess safely stores the browser process for later cleanup.
|
||||||
|
func storeBrowserProcess(cmd *exec.Cmd) {
|
||||||
|
browserMutex.Lock()
|
||||||
|
lastBrowserProcess = cmd
|
||||||
|
browserMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryDefaultBrowserIncognito attempts to detect the default browser and return
|
||||||
|
// an exec.Cmd configured with the appropriate incognito flag.
|
||||||
|
func tryDefaultBrowserIncognito(url string) *exec.Cmd {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
return tryDefaultBrowserMacOS(url)
|
||||||
|
case "windows":
|
||||||
|
return tryDefaultBrowserWindows(url)
|
||||||
|
case "linux":
|
||||||
|
return tryDefaultBrowserLinux(url)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryDefaultBrowserMacOS detects the default browser on macOS.
|
||||||
|
func tryDefaultBrowserMacOS(url string) *exec.Cmd {
|
||||||
|
// Try to get default browser from Launch Services
|
||||||
|
out, err := exec.Command("defaults", "read", "com.apple.LaunchServices/com.apple.launchservices.secure", "LSHandlers").Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
output := string(out)
|
||||||
|
var browserName string
|
||||||
|
|
||||||
|
// Parse the output to find the http/https handler
|
||||||
|
if containsBrowserID(output, "com.google.chrome") {
|
||||||
|
browserName = "chrome"
|
||||||
|
} else if containsBrowserID(output, "org.mozilla.firefox") {
|
||||||
|
browserName = "firefox"
|
||||||
|
} else if containsBrowserID(output, "com.apple.safari") {
|
||||||
|
browserName = "safari"
|
||||||
|
} else if containsBrowserID(output, "com.brave.browser") {
|
||||||
|
browserName = "brave"
|
||||||
|
} else if containsBrowserID(output, "com.microsoft.edgemac") {
|
||||||
|
browserName = "edge"
|
||||||
|
}
|
||||||
|
|
||||||
|
return createMacOSIncognitoCmd(browserName, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
// containsBrowserID checks if the LaunchServices output contains a browser ID.
|
||||||
|
func containsBrowserID(output, bundleID string) bool {
|
||||||
|
return strings.Contains(output, bundleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// createMacOSIncognitoCmd creates the appropriate incognito command for macOS browsers.
|
||||||
|
func createMacOSIncognitoCmd(browserName, url string) *exec.Cmd {
|
||||||
|
switch browserName {
|
||||||
|
case "chrome":
|
||||||
|
// Try direct path first
|
||||||
|
chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
|
||||||
|
if _, err := exec.LookPath(chromePath); err == nil {
|
||||||
|
return exec.Command(chromePath, "--incognito", url)
|
||||||
|
}
|
||||||
|
return exec.Command("open", "-na", "Google Chrome", "--args", "--incognito", url)
|
||||||
|
case "firefox":
|
||||||
|
return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url)
|
||||||
|
case "safari":
|
||||||
|
// Safari doesn't have CLI incognito, try AppleScript
|
||||||
|
return tryAppleScriptSafariPrivate(url)
|
||||||
|
case "brave":
|
||||||
|
return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url)
|
||||||
|
case "edge":
|
||||||
|
return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryAppleScriptSafariPrivate attempts to open Safari in private browsing mode using AppleScript.
|
||||||
|
func tryAppleScriptSafariPrivate(url string) *exec.Cmd {
|
||||||
|
// AppleScript to open a new private window in Safari
|
||||||
|
script := fmt.Sprintf(`
|
||||||
|
tell application "Safari"
|
||||||
|
activate
|
||||||
|
tell application "System Events"
|
||||||
|
keystroke "n" using {command down, shift down}
|
||||||
|
delay 0.5
|
||||||
|
end tell
|
||||||
|
set URL of document 1 to "%s"
|
||||||
|
end tell
|
||||||
|
`, url)
|
||||||
|
|
||||||
|
cmd := exec.Command("osascript", "-e", script)
|
||||||
|
// Test if this approach works by checking if Safari is available
|
||||||
|
if _, err := exec.LookPath("/Applications/Safari.app/Contents/MacOS/Safari"); err != nil {
|
||||||
|
log.Debug("Safari not found, AppleScript private window not available")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Debug("Attempting Safari private window via AppleScript")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryDefaultBrowserWindows detects the default browser on Windows via registry.
|
||||||
|
func tryDefaultBrowserWindows(url string) *exec.Cmd {
|
||||||
|
// Query registry for default browser
|
||||||
|
out, err := exec.Command("reg", "query",
|
||||||
|
`HKEY_CURRENT_USER\Software\Microsoft\Windows\Shell\Associations\UrlAssociations\http\UserChoice`,
|
||||||
|
"/v", "ProgId").Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
output := string(out)
|
||||||
|
var browserName string
|
||||||
|
|
||||||
|
// Map ProgId to browser name
|
||||||
|
if strings.Contains(output, "ChromeHTML") {
|
||||||
|
browserName = "chrome"
|
||||||
|
} else if strings.Contains(output, "FirefoxURL") {
|
||||||
|
browserName = "firefox"
|
||||||
|
} else if strings.Contains(output, "MSEdgeHTM") {
|
||||||
|
browserName = "edge"
|
||||||
|
} else if strings.Contains(output, "BraveHTML") {
|
||||||
|
browserName = "brave"
|
||||||
|
}
|
||||||
|
|
||||||
|
return createWindowsIncognitoCmd(browserName, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
// createWindowsIncognitoCmd creates the appropriate incognito command for Windows browsers.
|
||||||
|
func createWindowsIncognitoCmd(browserName, url string) *exec.Cmd {
|
||||||
|
switch browserName {
|
||||||
|
case "chrome":
|
||||||
|
paths := []string{
|
||||||
|
"chrome",
|
||||||
|
`C:\Program Files\Google\Chrome\Application\chrome.exe`,
|
||||||
|
`C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`,
|
||||||
|
}
|
||||||
|
for _, p := range paths {
|
||||||
|
if _, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(p, "--incognito", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "firefox":
|
||||||
|
if path, err := exec.LookPath("firefox"); err == nil {
|
||||||
|
return exec.Command(path, "--private-window", url)
|
||||||
|
}
|
||||||
|
case "edge":
|
||||||
|
paths := []string{
|
||||||
|
"msedge",
|
||||||
|
`C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`,
|
||||||
|
`C:\Program Files\Microsoft\Edge\Application\msedge.exe`,
|
||||||
|
}
|
||||||
|
for _, p := range paths {
|
||||||
|
if _, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(p, "--inprivate", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "brave":
|
||||||
|
paths := []string{
|
||||||
|
`C:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe`,
|
||||||
|
`C:\Program Files (x86)\BraveSoftware\Brave-Browser\Application\brave.exe`,
|
||||||
|
}
|
||||||
|
for _, p := range paths {
|
||||||
|
if _, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(p, "--incognito", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryDefaultBrowserLinux detects the default browser on Linux using xdg-settings.
|
||||||
|
func tryDefaultBrowserLinux(url string) *exec.Cmd {
|
||||||
|
out, err := exec.Command("xdg-settings", "get", "default-web-browser").Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
desktop := string(out)
|
||||||
|
var browserName string
|
||||||
|
|
||||||
|
// Map .desktop file to browser name
|
||||||
|
if strings.Contains(desktop, "google-chrome") || strings.Contains(desktop, "chrome") {
|
||||||
|
browserName = "chrome"
|
||||||
|
} else if strings.Contains(desktop, "firefox") {
|
||||||
|
browserName = "firefox"
|
||||||
|
} else if strings.Contains(desktop, "chromium") {
|
||||||
|
browserName = "chromium"
|
||||||
|
} else if strings.Contains(desktop, "brave") {
|
||||||
|
browserName = "brave"
|
||||||
|
} else if strings.Contains(desktop, "microsoft-edge") || strings.Contains(desktop, "msedge") {
|
||||||
|
browserName = "edge"
|
||||||
|
}
|
||||||
|
|
||||||
|
return createLinuxIncognitoCmd(browserName, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
// createLinuxIncognitoCmd creates the appropriate incognito command for Linux browsers.
|
||||||
|
func createLinuxIncognitoCmd(browserName, url string) *exec.Cmd {
|
||||||
|
switch browserName {
|
||||||
|
case "chrome":
|
||||||
|
paths := []string{"google-chrome", "google-chrome-stable"}
|
||||||
|
for _, p := range paths {
|
||||||
|
if path, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(path, "--incognito", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "firefox":
|
||||||
|
paths := []string{"firefox", "firefox-esr"}
|
||||||
|
for _, p := range paths {
|
||||||
|
if path, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(path, "--private-window", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "chromium":
|
||||||
|
paths := []string{"chromium", "chromium-browser"}
|
||||||
|
for _, p := range paths {
|
||||||
|
if path, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(path, "--incognito", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "brave":
|
||||||
|
if path, err := exec.LookPath("brave-browser"); err == nil {
|
||||||
|
return exec.Command(path, "--incognito", url)
|
||||||
|
}
|
||||||
|
case "edge":
|
||||||
|
if path, err := exec.LookPath("microsoft-edge"); err == nil {
|
||||||
|
return exec.Command(path, "--inprivate", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryFallbackBrowsersIncognito tries a chain of known browsers as fallback.
|
||||||
|
func tryFallbackBrowsersIncognito(url string) *exec.Cmd {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
return tryFallbackBrowsersMacOS(url)
|
||||||
|
case "windows":
|
||||||
|
return tryFallbackBrowsersWindows(url)
|
||||||
|
case "linux":
|
||||||
|
return tryFallbackBrowsersLinuxChain(url)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryFallbackBrowsersMacOS tries known browsers on macOS.
|
||||||
|
func tryFallbackBrowsersMacOS(url string) *exec.Cmd {
|
||||||
|
// Try Chrome
|
||||||
|
chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
|
||||||
|
if _, err := exec.LookPath(chromePath); err == nil {
|
||||||
|
return exec.Command(chromePath, "--incognito", url)
|
||||||
|
}
|
||||||
|
// Try Firefox
|
||||||
|
if _, err := exec.LookPath("/Applications/Firefox.app/Contents/MacOS/firefox"); err == nil {
|
||||||
|
return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url)
|
||||||
|
}
|
||||||
|
// Try Brave
|
||||||
|
if _, err := exec.LookPath("/Applications/Brave Browser.app/Contents/MacOS/Brave Browser"); err == nil {
|
||||||
|
return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url)
|
||||||
|
}
|
||||||
|
// Try Edge
|
||||||
|
if _, err := exec.LookPath("/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge"); err == nil {
|
||||||
|
return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url)
|
||||||
|
}
|
||||||
|
// Last resort: try Safari with AppleScript
|
||||||
|
if cmd := tryAppleScriptSafariPrivate(url); cmd != nil {
|
||||||
|
log.Info("Using Safari with AppleScript for private browsing (may require accessibility permissions)")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryFallbackBrowsersWindows tries known browsers on Windows.
|
||||||
|
func tryFallbackBrowsersWindows(url string) *exec.Cmd {
|
||||||
|
// Chrome
|
||||||
|
chromePaths := []string{
|
||||||
|
"chrome",
|
||||||
|
`C:\Program Files\Google\Chrome\Application\chrome.exe`,
|
||||||
|
`C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`,
|
||||||
|
}
|
||||||
|
for _, p := range chromePaths {
|
||||||
|
if _, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(p, "--incognito", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Firefox
|
||||||
|
if path, err := exec.LookPath("firefox"); err == nil {
|
||||||
|
return exec.Command(path, "--private-window", url)
|
||||||
|
}
|
||||||
|
// Edge (usually available on Windows 10+)
|
||||||
|
edgePaths := []string{
|
||||||
|
"msedge",
|
||||||
|
`C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`,
|
||||||
|
`C:\Program Files\Microsoft\Edge\Application\msedge.exe`,
|
||||||
|
}
|
||||||
|
for _, p := range edgePaths {
|
||||||
|
if _, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(p, "--inprivate", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryFallbackBrowsersLinuxChain tries known browsers on Linux.
|
||||||
|
func tryFallbackBrowsersLinuxChain(url string) *exec.Cmd {
|
||||||
|
type browserConfig struct {
|
||||||
|
name string
|
||||||
|
flag string
|
||||||
|
}
|
||||||
|
browsers := []browserConfig{
|
||||||
|
{"google-chrome", "--incognito"},
|
||||||
|
{"google-chrome-stable", "--incognito"},
|
||||||
|
{"chromium", "--incognito"},
|
||||||
|
{"chromium-browser", "--incognito"},
|
||||||
|
{"firefox", "--private-window"},
|
||||||
|
{"firefox-esr", "--private-window"},
|
||||||
|
{"brave-browser", "--incognito"},
|
||||||
|
{"microsoft-edge", "--inprivate"},
|
||||||
|
}
|
||||||
|
for _, b := range browsers {
|
||||||
|
if path, err := exec.LookPath(b.name); err == nil {
|
||||||
|
return exec.Command(path, b.flag, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// IsAvailable checks if the system has a command available to open a web browser.
|
// IsAvailable checks if the system has a command available to open a web browser.
|
||||||
// It verifies the presence of necessary commands for the current operating system.
|
// It verifies the presence of necessary commands for the current operating system.
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - true if a browser can be opened, false otherwise.
|
// - true if a browser can be opened, false otherwise.
|
||||||
func IsAvailable() bool {
|
func IsAvailable() bool {
|
||||||
// First check if open-golang can work
|
|
||||||
testErr := open.Run("about:blank")
|
|
||||||
if testErr == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check platform-specific commands
|
// Check platform-specific commands
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "darwin":
|
case "darwin":
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ func newAuthManager() *sdkAuth.Manager {
|
|||||||
sdkAuth.NewQwenAuthenticator(),
|
sdkAuth.NewQwenAuthenticator(),
|
||||||
sdkAuth.NewIFlowAuthenticator(),
|
sdkAuth.NewIFlowAuthenticator(),
|
||||||
sdkAuth.NewAntigravityAuthenticator(),
|
sdkAuth.NewAntigravityAuthenticator(),
|
||||||
|
sdkAuth.NewKiroAuthenticator(),
|
||||||
sdkAuth.NewGitHubCopilotAuthenticator(),
|
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||||
)
|
)
|
||||||
return manager
|
return manager
|
||||||
|
|||||||
160
internal/cmd/kiro_login.go
Normal file
160
internal/cmd/kiro_login.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoKiroLogin triggers the Kiro authentication flow with Google OAuth.
|
||||||
|
// This is the default login method (same as --kiro-google-login).
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration
|
||||||
|
// - options: Login options including Prompt field
|
||||||
|
func DoKiroLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
// Use Google login as default
|
||||||
|
DoKiroGoogleLogin(cfg, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoKiroGoogleLogin triggers Kiro authentication with Google OAuth.
|
||||||
|
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration
|
||||||
|
// - options: Login options including prompts
|
||||||
|
func DoKiroGoogleLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: Kiro defaults to incognito mode for multi-account support.
|
||||||
|
// Users can override with --no-incognito if they want to use existing browser sessions.
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
|
||||||
|
// Use KiroAuthenticator with Google login
|
||||||
|
authenticator := sdkAuth.NewKiroAuthenticator()
|
||||||
|
record, err := authenticator.LoginWithGoogle(context.Background(), cfg, &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
Metadata: map[string]string{},
|
||||||
|
Prompt: options.Prompt,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Kiro Google authentication failed: %v", err)
|
||||||
|
fmt.Println("\nTroubleshooting:")
|
||||||
|
fmt.Println("1. Make sure the protocol handler is installed")
|
||||||
|
fmt.Println("2. Complete the Google login in the browser")
|
||||||
|
fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the auth record
|
||||||
|
savedPath, err := manager.SaveAuth(record, cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to save auth: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||||
|
}
|
||||||
|
fmt.Println("Kiro Google authentication successful!")
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoKiroAWSLogin triggers Kiro authentication with AWS Builder ID.
|
||||||
|
// This uses the device code flow for AWS SSO OIDC authentication.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration
|
||||||
|
// - options: Login options including prompts
|
||||||
|
func DoKiroAWSLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: Kiro defaults to incognito mode for multi-account support.
|
||||||
|
// Users can override with --no-incognito if they want to use existing browser sessions.
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
|
||||||
|
// Use KiroAuthenticator with AWS Builder ID login (device code flow)
|
||||||
|
authenticator := sdkAuth.NewKiroAuthenticator()
|
||||||
|
record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
Metadata: map[string]string{},
|
||||||
|
Prompt: options.Prompt,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Kiro AWS authentication failed: %v", err)
|
||||||
|
fmt.Println("\nTroubleshooting:")
|
||||||
|
fmt.Println("1. Make sure you have an AWS Builder ID")
|
||||||
|
fmt.Println("2. Complete the authorization in the browser")
|
||||||
|
fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the auth record
|
||||||
|
savedPath, err := manager.SaveAuth(record, cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to save auth: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||||
|
}
|
||||||
|
fmt.Println("Kiro AWS authentication successful!")
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoKiroImport imports Kiro token from Kiro IDE's token file.
|
||||||
|
// This is useful for users who have already logged in via Kiro IDE
|
||||||
|
// and want to use the same credentials in CLI Proxy API.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration
|
||||||
|
// - options: Login options (currently unused for import)
|
||||||
|
func DoKiroImport(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
|
||||||
|
// Use ImportFromKiroIDE instead of Login
|
||||||
|
authenticator := sdkAuth.NewKiroAuthenticator()
|
||||||
|
record, err := authenticator.ImportFromKiroIDE(context.Background(), cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Kiro token import failed: %v", err)
|
||||||
|
fmt.Println("\nMake sure you have logged in to Kiro IDE first:")
|
||||||
|
fmt.Println("1. Open Kiro IDE")
|
||||||
|
fmt.Println("2. Click 'Sign in with Google' (or GitHub)")
|
||||||
|
fmt.Println("3. Complete the login process")
|
||||||
|
fmt.Println("4. Run this command again")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the imported auth record
|
||||||
|
savedPath, err := manager.SaveAuth(record, cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to save auth: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
fmt.Printf("Imported as %s\n", record.Label)
|
||||||
|
}
|
||||||
|
fmt.Println("Kiro token import successful!")
|
||||||
|
}
|
||||||
@@ -65,20 +65,20 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
|||||||
authenticator := sdkAuth.NewGeminiAuthenticator()
|
authenticator := sdkAuth.NewGeminiAuthenticator()
|
||||||
record, errLogin := authenticator.Login(ctx, cfg, loginOpts)
|
record, errLogin := authenticator.Login(ctx, cfg, loginOpts)
|
||||||
if errLogin != nil {
|
if errLogin != nil {
|
||||||
log.Fatalf("Gemini authentication failed: %v", errLogin)
|
log.Errorf("Gemini authentication failed: %v", errLogin)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
storage, okStorage := record.Storage.(*gemini.GeminiTokenStorage)
|
storage, okStorage := record.Storage.(*gemini.GeminiTokenStorage)
|
||||||
if !okStorage || storage == nil {
|
if !okStorage || storage == nil {
|
||||||
log.Fatal("Gemini authentication failed: unsupported token storage")
|
log.Error("Gemini authentication failed: unsupported token storage")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
geminiAuth := gemini.NewGeminiAuth()
|
geminiAuth := gemini.NewGeminiAuth()
|
||||||
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser)
|
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser)
|
||||||
if errClient != nil {
|
if errClient != nil {
|
||||||
log.Fatalf("Gemini authentication failed: %v", errClient)
|
log.Errorf("Gemini authentication failed: %v", errClient)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,7 +86,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
|||||||
|
|
||||||
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
||||||
if errProjects != nil {
|
if errProjects != nil {
|
||||||
log.Fatalf("Failed to get project list: %v", errProjects)
|
log.Errorf("Failed to get project list: %v", errProjects)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,11 +98,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
|||||||
selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn)
|
selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn)
|
||||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||||
if errSelection != nil {
|
if errSelection != nil {
|
||||||
log.Fatalf("Invalid project selection: %v", errSelection)
|
log.Errorf("Invalid project selection: %v", errSelection)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(projectSelections) == 0 {
|
if len(projectSelections) == 0 {
|
||||||
log.Fatal("No project selected; aborting login.")
|
log.Error("No project selected; aborting login.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,7 +116,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
|||||||
showProjectSelectionHelp(storage.Email, projects)
|
showProjectSelectionHelp(storage.Email, projects)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Fatalf("Failed to complete user setup: %v", errSetup)
|
log.Errorf("Failed to complete user setup: %v", errSetup)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
finalID := strings.TrimSpace(storage.ProjectID)
|
finalID := strings.TrimSpace(storage.ProjectID)
|
||||||
@@ -133,11 +133,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
|||||||
for _, pid := range activatedProjects {
|
for _, pid := range activatedProjects {
|
||||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, pid)
|
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, pid)
|
||||||
if errCheck != nil {
|
if errCheck != nil {
|
||||||
log.Fatalf("Failed to check if Cloud AI API is enabled for %s: %v", pid, errCheck)
|
log.Errorf("Failed to check if Cloud AI API is enabled for %s: %v", pid, errCheck)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !isChecked {
|
if !isChecked {
|
||||||
log.Fatalf("Failed to check if Cloud AI API is enabled for project %s. If you encounter an error message, please create an issue.", pid)
|
log.Errorf("Failed to check if Cloud AI API is enabled for project %s. If you encounter an error message, please create an issue.", pid)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -153,7 +153,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
|||||||
|
|
||||||
savedPath, errSave := store.Save(ctx, record)
|
savedPath, errSave := store.Save(ctx, record)
|
||||||
if errSave != nil {
|
if errSave != nil {
|
||||||
log.Fatalf("Failed to save token to file: %v", errSave)
|
log.Errorf("Failed to save token to file: %v", errSave)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -555,6 +555,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
return false, fmt.Errorf("project activation required: %s", errMessage)
|
return false, fmt.Errorf("project activation required: %s", errMessage)
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
|
|||||||
@@ -45,12 +45,13 @@ func StartService(cfg *config.Config, configPath string, localPassword string) {
|
|||||||
|
|
||||||
service, err := builder.Build()
|
service, err := builder.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to build proxy service: %v", err)
|
log.Errorf("failed to build proxy service: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.Run(runCtx)
|
err = service.Run(runCtx)
|
||||||
if err != nil && !errors.Is(err, context.Canceled) {
|
if err != nil && !errors.Is(err, context.Canceled) {
|
||||||
log.Fatalf("proxy service exited with error: %v", err)
|
log.Errorf("proxy service exited with error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -29,30 +29,30 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
|
|||||||
}
|
}
|
||||||
rawPath := strings.TrimSpace(keyPath)
|
rawPath := strings.TrimSpace(keyPath)
|
||||||
if rawPath == "" {
|
if rawPath == "" {
|
||||||
log.Fatalf("vertex-import: missing service account key path")
|
log.Errorf("vertex-import: missing service account key path")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
data, errRead := os.ReadFile(rawPath)
|
data, errRead := os.ReadFile(rawPath)
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
log.Fatalf("vertex-import: read file failed: %v", errRead)
|
log.Errorf("vertex-import: read file failed: %v", errRead)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var sa map[string]any
|
var sa map[string]any
|
||||||
if errUnmarshal := json.Unmarshal(data, &sa); errUnmarshal != nil {
|
if errUnmarshal := json.Unmarshal(data, &sa); errUnmarshal != nil {
|
||||||
log.Fatalf("vertex-import: invalid service account json: %v", errUnmarshal)
|
log.Errorf("vertex-import: invalid service account json: %v", errUnmarshal)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Validate and normalize private_key before saving
|
// Validate and normalize private_key before saving
|
||||||
normalizedSA, errFix := vertex.NormalizeServiceAccountMap(sa)
|
normalizedSA, errFix := vertex.NormalizeServiceAccountMap(sa)
|
||||||
if errFix != nil {
|
if errFix != nil {
|
||||||
log.Fatalf("vertex-import: %v", errFix)
|
log.Errorf("vertex-import: %v", errFix)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sa = normalizedSA
|
sa = normalizedSA
|
||||||
email, _ := sa["client_email"].(string)
|
email, _ := sa["client_email"].(string)
|
||||||
projectID, _ := sa["project_id"].(string)
|
projectID, _ := sa["project_id"].(string)
|
||||||
if strings.TrimSpace(projectID) == "" {
|
if strings.TrimSpace(projectID) == "" {
|
||||||
log.Fatalf("vertex-import: project_id missing in service account json")
|
log.Errorf("vertex-import: project_id missing in service account json")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(email) == "" {
|
if strings.TrimSpace(email) == "" {
|
||||||
@@ -92,7 +92,7 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
|
|||||||
}
|
}
|
||||||
path, errSave := store.Save(context.Background(), record)
|
path, errSave := store.Save(context.Background(), record)
|
||||||
if errSave != nil {
|
if errSave != nil {
|
||||||
log.Fatalf("vertex-import: save credential failed: %v", errSave)
|
log.Errorf("vertex-import: save credential failed: %v", errSave)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fmt.Printf("Vertex credentials imported: %s\n", path)
|
fmt.Printf("Vertex credentials imported: %s\n", path)
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ import (
|
|||||||
// Config represents the application's configuration, loaded from a YAML file.
|
// Config represents the application's configuration, loaded from a YAML file.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
config.SDKConfig `yaml:",inline"`
|
config.SDKConfig `yaml:",inline"`
|
||||||
|
// Host is the network host/interface on which the API server will bind.
|
||||||
|
// Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access.
|
||||||
|
Host string `yaml:"host" json:"-"`
|
||||||
// Port is the network port on which the API server will listen.
|
// Port is the network port on which the API server will listen.
|
||||||
Port int `yaml:"port" json:"-"`
|
Port int `yaml:"port" json:"-"`
|
||||||
|
|
||||||
@@ -58,6 +61,9 @@ type Config struct {
|
|||||||
// GeminiKey defines Gemini API key configurations with optional routing overrides.
|
// GeminiKey defines Gemini API key configurations with optional routing overrides.
|
||||||
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
|
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
|
||||||
|
|
||||||
|
// KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations.
|
||||||
|
KiroKey []KiroKey `yaml:"kiro" json:"kiro"`
|
||||||
|
|
||||||
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
|
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
|
||||||
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
|
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
|
||||||
|
|
||||||
@@ -80,6 +86,11 @@ type Config struct {
|
|||||||
// Payload defines default and override rules for provider payload parameters.
|
// Payload defines default and override rules for provider payload parameters.
|
||||||
Payload PayloadConfig `yaml:"payload" json:"payload"`
|
Payload PayloadConfig `yaml:"payload" json:"payload"`
|
||||||
|
|
||||||
|
// IncognitoBrowser enables opening OAuth URLs in incognito/private browsing mode.
|
||||||
|
// This is useful when you want to login with a different account without logging out
|
||||||
|
// from your current session. Default: false.
|
||||||
|
IncognitoBrowser bool `yaml:"incognito-browser" json:"incognito-browser"`
|
||||||
|
|
||||||
legacyMigrationPending bool `yaml:"-" json:"-"`
|
legacyMigrationPending bool `yaml:"-" json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,6 +154,10 @@ type AmpCode struct {
|
|||||||
// When Amp requests a model that isn't available locally, these mappings
|
// When Amp requests a model that isn't available locally, these mappings
|
||||||
// allow routing to an alternative model that IS available.
|
// allow routing to an alternative model that IS available.
|
||||||
ModelMappings []AmpModelMapping `yaml:"model-mappings" json:"model-mappings"`
|
ModelMappings []AmpModelMapping `yaml:"model-mappings" json:"model-mappings"`
|
||||||
|
|
||||||
|
// ForceModelMappings when true, model mappings take precedence over local API keys.
|
||||||
|
// When false (default), local API keys are used first if available.
|
||||||
|
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PayloadConfig defines default and override parameter rules applied to provider payloads.
|
// PayloadConfig defines default and override parameter rules applied to provider payloads.
|
||||||
@@ -240,6 +255,31 @@ type GeminiKey struct {
|
|||||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication.
|
||||||
|
type KiroKey struct {
|
||||||
|
// TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json)
|
||||||
|
TokenFile string `yaml:"token-file,omitempty" json:"token-file,omitempty"`
|
||||||
|
|
||||||
|
// AccessToken is the OAuth access token for direct configuration.
|
||||||
|
AccessToken string `yaml:"access-token,omitempty" json:"access-token,omitempty"`
|
||||||
|
|
||||||
|
// RefreshToken is the OAuth refresh token for token renewal.
|
||||||
|
RefreshToken string `yaml:"refresh-token,omitempty" json:"refresh-token,omitempty"`
|
||||||
|
|
||||||
|
// ProfileArn is the AWS CodeWhisperer profile ARN.
|
||||||
|
ProfileArn string `yaml:"profile-arn,omitempty" json:"profile-arn,omitempty"`
|
||||||
|
|
||||||
|
// Region is the AWS region (default: us-east-1).
|
||||||
|
Region string `yaml:"region,omitempty" json:"region,omitempty"`
|
||||||
|
|
||||||
|
// ProxyURL optionally overrides the global proxy for this configuration.
|
||||||
|
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
|
||||||
|
|
||||||
|
// AgentTaskType sets the Kiro API task type. Known values: "vibe", "dev", "chat".
|
||||||
|
// Leave empty to let API use defaults. Different values may inject different system prompts.
|
||||||
|
AgentTaskType string `yaml:"agent-task-type,omitempty" json:"agent-task-type,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// OpenAICompatibility represents the configuration for OpenAI API compatibility
|
// OpenAICompatibility represents the configuration for OpenAI API compatibility
|
||||||
// with external providers, allowing model aliases to be routed through OpenAI API format.
|
// with external providers, allowing model aliases to be routed through OpenAI API format.
|
||||||
type OpenAICompatibility struct {
|
type OpenAICompatibility struct {
|
||||||
@@ -316,10 +356,12 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
// Unmarshal the YAML data into the Config struct.
|
// Unmarshal the YAML data into the Config struct.
|
||||||
var cfg Config
|
var cfg Config
|
||||||
// Set defaults before unmarshal so that absent keys keep defaults.
|
// Set defaults before unmarshal so that absent keys keep defaults.
|
||||||
|
cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
|
||||||
cfg.LoggingToFile = false
|
cfg.LoggingToFile = false
|
||||||
cfg.UsageStatisticsEnabled = false
|
cfg.UsageStatisticsEnabled = false
|
||||||
cfg.DisableCooling = false
|
cfg.DisableCooling = false
|
||||||
cfg.AmpCode.RestrictManagementToLocalhost = true // Default to secure: only localhost access
|
cfg.AmpCode.RestrictManagementToLocalhost = true // Default to secure: only localhost access
|
||||||
|
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
|
||||||
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
||||||
if optional {
|
if optional {
|
||||||
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
|
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
|
||||||
@@ -370,6 +412,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
// Sanitize Claude key headers
|
// Sanitize Claude key headers
|
||||||
cfg.SanitizeClaudeKeys()
|
cfg.SanitizeClaudeKeys()
|
||||||
|
|
||||||
|
// Sanitize Kiro keys: trim whitespace from credential fields
|
||||||
|
cfg.SanitizeKiroKeys()
|
||||||
|
|
||||||
// Sanitize OpenAI compatibility providers: drop entries without base-url
|
// Sanitize OpenAI compatibility providers: drop entries without base-url
|
||||||
cfg.SanitizeOpenAICompatibility()
|
cfg.SanitizeOpenAICompatibility()
|
||||||
|
|
||||||
@@ -446,6 +491,22 @@ func (cfg *Config) SanitizeClaudeKeys() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SanitizeKiroKeys trims whitespace from Kiro credential fields.
|
||||||
|
func (cfg *Config) SanitizeKiroKeys() {
|
||||||
|
if cfg == nil || len(cfg.KiroKey) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := range cfg.KiroKey {
|
||||||
|
entry := &cfg.KiroKey[i]
|
||||||
|
entry.TokenFile = strings.TrimSpace(entry.TokenFile)
|
||||||
|
entry.AccessToken = strings.TrimSpace(entry.AccessToken)
|
||||||
|
entry.RefreshToken = strings.TrimSpace(entry.RefreshToken)
|
||||||
|
entry.ProfileArn = strings.TrimSpace(entry.ProfileArn)
|
||||||
|
entry.Region = strings.TrimSpace(entry.Region)
|
||||||
|
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SanitizeGeminiKeys deduplicates and normalizes Gemini credentials.
|
// SanitizeGeminiKeys deduplicates and normalizes Gemini credentials.
|
||||||
func (cfg *Config) SanitizeGeminiKeys() {
|
func (cfg *Config) SanitizeGeminiKeys() {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
|
|||||||
@@ -24,4 +24,7 @@ const (
|
|||||||
|
|
||||||
// Antigravity represents the Antigravity response format identifier.
|
// Antigravity represents the Antigravity response format identifier.
|
||||||
Antigravity = "antigravity"
|
Antigravity = "antigravity"
|
||||||
|
|
||||||
|
// Kiro represents the AWS CodeWhisperer (Kiro) provider identifier.
|
||||||
|
Kiro = "kiro"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -56,6 +56,8 @@ type Content struct {
|
|||||||
// Part represents a distinct piece of content within a message.
|
// Part represents a distinct piece of content within a message.
|
||||||
// A part can be text, inline data (like an image), a function call, or a function response.
|
// A part can be text, inline data (like an image), a function call, or a function response.
|
||||||
type Part struct {
|
type Part struct {
|
||||||
|
Thought bool `json:"thought,omitempty"`
|
||||||
|
|
||||||
// Text contains plain text content.
|
// Text contains plain text content.
|
||||||
Text string `json:"text,omitempty"`
|
Text string `json:"text,omitempty"`
|
||||||
|
|
||||||
@@ -85,6 +87,9 @@ type InlineData struct {
|
|||||||
// FunctionCall represents a tool call requested by the model.
|
// FunctionCall represents a tool call requested by the model.
|
||||||
// It includes the function name and its arguments that the model wants to execute.
|
// It includes the function name and its arguments that the model wants to execute.
|
||||||
type FunctionCall struct {
|
type FunctionCall struct {
|
||||||
|
// ID is the identifier of the function to be called.
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
|
||||||
// Name is the identifier of the function to be called.
|
// Name is the identifier of the function to be called.
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|
||||||
@@ -95,6 +100,9 @@ type FunctionCall struct {
|
|||||||
// FunctionResponse represents the result of a tool execution.
|
// FunctionResponse represents the result of a tool execution.
|
||||||
// This is sent back to the model after a tool call has been processed.
|
// This is sent back to the model after a tool call has been processed.
|
||||||
type FunctionResponse struct {
|
type FunctionResponse struct {
|
||||||
|
// ID is the identifier of the function to be called.
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
|
||||||
// Name is the identifier of the function that was called.
|
// Name is the identifier of the function that was called.
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const skipGinLogKey = "__gin_skip_request_logging__"
|
||||||
|
|
||||||
// GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses
|
// GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses
|
||||||
// using logrus. It captures request details including method, path, status code, latency,
|
// using logrus. It captures request details including method, path, status code, latency,
|
||||||
// client IP, and any error messages, formatting them in a Gin-style log format.
|
// client IP, and any error messages, formatting them in a Gin-style log format.
|
||||||
@@ -28,6 +30,10 @@ func GinLogrusLogger() gin.HandlerFunc {
|
|||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|
||||||
|
if shouldSkipGinRequestLogging(c) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if raw != "" {
|
if raw != "" {
|
||||||
path = path + "?" + raw
|
path = path + "?" + raw
|
||||||
}
|
}
|
||||||
@@ -77,3 +83,24 @@ func GinLogrusRecovery() gin.HandlerFunc {
|
|||||||
c.AbortWithStatus(http.StatusInternalServerError)
|
c.AbortWithStatus(http.StatusInternalServerError)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SkipGinRequestLogging marks the provided Gin context so that GinLogrusLogger
|
||||||
|
// will skip emitting a log line for the associated request.
|
||||||
|
func SkipGinRequestLogging(c *gin.Context) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(skipGinLogKey, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldSkipGinRequestLogging(c *gin.Context) bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
val, exists := c.Get(skipGinLogKey)
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
flag, ok := val.(bool)
|
||||||
|
return ok && flag
|
||||||
|
}
|
||||||
|
|||||||
@@ -38,13 +38,16 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
|||||||
|
|
||||||
timestamp := entry.Time.Format("2006-01-02 15:04:05")
|
timestamp := entry.Time.Format("2006-01-02 15:04:05")
|
||||||
message := strings.TrimRight(entry.Message, "\r\n")
|
message := strings.TrimRight(entry.Message, "\r\n")
|
||||||
|
|
||||||
var formatted string
|
// Handle nil Caller (can happen with some log entries)
|
||||||
|
callerFile := "unknown"
|
||||||
|
callerLine := 0
|
||||||
if entry.Caller != nil {
|
if entry.Caller != nil {
|
||||||
formatted = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
|
callerFile = filepath.Base(entry.Caller.File)
|
||||||
} else {
|
callerLine = entry.Caller.Line
|
||||||
formatted = fmt.Sprintf("[%s] [%s] %s\n", timestamp, entry.Level, message)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
formatted := fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, callerFile, callerLine, message)
|
||||||
buffer.WriteString(formatted)
|
buffer.WriteString(formatted)
|
||||||
|
|
||||||
return buffer.Bytes(), nil
|
return buffer.Bytes(), nil
|
||||||
@@ -55,6 +58,7 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
|||||||
func SetupBaseLogger() {
|
func SetupBaseLogger() {
|
||||||
setupOnce.Do(func() {
|
setupOnce.Do(func() {
|
||||||
log.SetOutput(os.Stdout)
|
log.SetOutput(os.Stdout)
|
||||||
|
log.SetLevel(log.InfoLevel)
|
||||||
log.SetReportCaller(true)
|
log.SetReportCaller(true)
|
||||||
log.SetFormatter(&LogFormatter{})
|
log.SetFormatter(&LogFormatter{})
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/klauspost/compress/zstd"
|
"github.com/klauspost/compress/zstd"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
)
|
)
|
||||||
@@ -83,6 +84,26 @@ type StreamingLogWriter interface {
|
|||||||
// - error: An error if writing fails, nil otherwise
|
// - error: An error if writing fails, nil otherwise
|
||||||
WriteStatus(status int, headers map[string][]string) error
|
WriteStatus(status int, headers map[string][]string) error
|
||||||
|
|
||||||
|
// WriteAPIRequest writes the upstream API request details to the log.
|
||||||
|
// This should be called before WriteStatus to maintain proper log ordering.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiRequest: The API request data (typically includes URL, headers, body sent upstream)
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: An error if writing fails, nil otherwise
|
||||||
|
WriteAPIRequest(apiRequest []byte) error
|
||||||
|
|
||||||
|
// WriteAPIResponse writes the upstream API response details to the log.
|
||||||
|
// This should be called after the streaming response is complete.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiResponse: The API response data
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: An error if writing fails, nil otherwise
|
||||||
|
WriteAPIResponse(apiResponse []byte) error
|
||||||
|
|
||||||
// Close finalizes the log file and cleans up resources.
|
// Close finalizes the log file and cleans up resources.
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
@@ -247,10 +268,11 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
|
|||||||
|
|
||||||
// Create streaming writer
|
// Create streaming writer
|
||||||
writer := &FileStreamingLogWriter{
|
writer := &FileStreamingLogWriter{
|
||||||
file: file,
|
file: file,
|
||||||
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
|
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
|
||||||
closeChan: make(chan struct{}),
|
closeChan: make(chan struct{}),
|
||||||
errorChan: make(chan error, 1),
|
errorChan: make(chan error, 1),
|
||||||
|
bufferedChunks: &bytes.Buffer{},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start async writer goroutine
|
// Start async writer goroutine
|
||||||
@@ -603,6 +625,7 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
|
|||||||
var content strings.Builder
|
var content strings.Builder
|
||||||
|
|
||||||
content.WriteString("=== REQUEST INFO ===\n")
|
content.WriteString("=== REQUEST INFO ===\n")
|
||||||
|
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
|
||||||
content.WriteString(fmt.Sprintf("URL: %s\n", url))
|
content.WriteString(fmt.Sprintf("URL: %s\n", url))
|
||||||
content.WriteString(fmt.Sprintf("Method: %s\n", method))
|
content.WriteString(fmt.Sprintf("Method: %s\n", method))
|
||||||
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||||
@@ -626,11 +649,12 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
|
|||||||
|
|
||||||
// FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs.
|
// FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs.
|
||||||
// It handles asynchronous writing of streaming response chunks to a file.
|
// It handles asynchronous writing of streaming response chunks to a file.
|
||||||
|
// All data is buffered and written in the correct order when Close is called.
|
||||||
type FileStreamingLogWriter struct {
|
type FileStreamingLogWriter struct {
|
||||||
// file is the file where log data is written.
|
// file is the file where log data is written.
|
||||||
file *os.File
|
file *os.File
|
||||||
|
|
||||||
// chunkChan is a channel for receiving response chunks to write.
|
// chunkChan is a channel for receiving response chunks to buffer.
|
||||||
chunkChan chan []byte
|
chunkChan chan []byte
|
||||||
|
|
||||||
// closeChan is a channel for signaling when the writer is closed.
|
// closeChan is a channel for signaling when the writer is closed.
|
||||||
@@ -639,8 +663,23 @@ type FileStreamingLogWriter struct {
|
|||||||
// errorChan is a channel for reporting errors during writing.
|
// errorChan is a channel for reporting errors during writing.
|
||||||
errorChan chan error
|
errorChan chan error
|
||||||
|
|
||||||
// statusWritten indicates whether the response status has been written.
|
// bufferedChunks stores the response chunks in order.
|
||||||
|
bufferedChunks *bytes.Buffer
|
||||||
|
|
||||||
|
// responseStatus stores the HTTP status code.
|
||||||
|
responseStatus int
|
||||||
|
|
||||||
|
// statusWritten indicates whether a non-zero status was recorded.
|
||||||
statusWritten bool
|
statusWritten bool
|
||||||
|
|
||||||
|
// responseHeaders stores the response headers.
|
||||||
|
responseHeaders map[string][]string
|
||||||
|
|
||||||
|
// apiRequest stores the upstream API request data.
|
||||||
|
apiRequest []byte
|
||||||
|
|
||||||
|
// apiResponse stores the upstream API response data.
|
||||||
|
apiResponse []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteChunkAsync writes a response chunk asynchronously (non-blocking).
|
// WriteChunkAsync writes a response chunk asynchronously (non-blocking).
|
||||||
@@ -664,39 +703,65 @@ func (w *FileStreamingLogWriter) WriteChunkAsync(chunk []byte) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteStatus writes the response status and headers to the log.
|
// WriteStatus buffers the response status and headers for later writing.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - status: The response status code
|
// - status: The response status code
|
||||||
// - headers: The response headers
|
// - headers: The response headers
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if writing fails, nil otherwise
|
// - error: Always returns nil (buffering cannot fail)
|
||||||
func (w *FileStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error {
|
func (w *FileStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error {
|
||||||
if w.file == nil || w.statusWritten {
|
if status == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var content strings.Builder
|
w.responseStatus = status
|
||||||
content.WriteString("========================================\n")
|
if headers != nil {
|
||||||
content.WriteString("=== RESPONSE ===\n")
|
w.responseHeaders = make(map[string][]string, len(headers))
|
||||||
content.WriteString(fmt.Sprintf("Status: %d\n", status))
|
for key, values := range headers {
|
||||||
|
headerValues := make([]string, len(values))
|
||||||
for key, values := range headers {
|
copy(headerValues, values)
|
||||||
for _, value := range values {
|
w.responseHeaders[key] = headerValues
|
||||||
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
content.WriteString("\n")
|
w.statusWritten = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
_, err := w.file.WriteString(content.String())
|
// WriteAPIRequest buffers the upstream API request details for later writing.
|
||||||
if err == nil {
|
//
|
||||||
w.statusWritten = true
|
// Parameters:
|
||||||
|
// - apiRequest: The API request data (typically includes URL, headers, body sent upstream)
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: Always returns nil (buffering cannot fail)
|
||||||
|
func (w *FileStreamingLogWriter) WriteAPIRequest(apiRequest []byte) error {
|
||||||
|
if len(apiRequest) == 0 {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return err
|
w.apiRequest = bytes.Clone(apiRequest)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteAPIResponse buffers the upstream API response details for later writing.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiResponse: The API response data
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: Always returns nil (buffering cannot fail)
|
||||||
|
func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
|
||||||
|
if len(apiResponse) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
w.apiResponse = bytes.Clone(apiResponse)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close finalizes the log file and cleans up resources.
|
// Close finalizes the log file and cleans up resources.
|
||||||
|
// It writes all buffered data to the file in the correct order:
|
||||||
|
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - error: An error if closing fails, nil otherwise
|
// - error: An error if closing fails, nil otherwise
|
||||||
@@ -705,27 +770,84 @@ func (w *FileStreamingLogWriter) Close() error {
|
|||||||
close(w.chunkChan)
|
close(w.chunkChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for async writer to finish
|
// Wait for async writer to finish buffering chunks
|
||||||
if w.closeChan != nil {
|
if w.closeChan != nil {
|
||||||
<-w.closeChan
|
<-w.closeChan
|
||||||
w.chunkChan = nil
|
w.chunkChan = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if w.file != nil {
|
if w.file == nil {
|
||||||
return w.file.Close()
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
// Write all content in the correct order
|
||||||
|
var content strings.Builder
|
||||||
|
|
||||||
|
// 1. Write API REQUEST section
|
||||||
|
if len(w.apiRequest) > 0 {
|
||||||
|
if bytes.HasPrefix(w.apiRequest, []byte("=== API REQUEST")) {
|
||||||
|
content.Write(w.apiRequest)
|
||||||
|
if !bytes.HasSuffix(w.apiRequest, []byte("\n")) {
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content.WriteString("=== API REQUEST ===\n")
|
||||||
|
content.Write(w.apiRequest)
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Write API RESPONSE section
|
||||||
|
if len(w.apiResponse) > 0 {
|
||||||
|
if bytes.HasPrefix(w.apiResponse, []byte("=== API RESPONSE")) {
|
||||||
|
content.Write(w.apiResponse)
|
||||||
|
if !bytes.HasSuffix(w.apiResponse, []byte("\n")) {
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content.WriteString("=== API RESPONSE ===\n")
|
||||||
|
content.Write(w.apiResponse)
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
content.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Write RESPONSE section (status, headers, buffered chunks)
|
||||||
|
content.WriteString("=== RESPONSE ===\n")
|
||||||
|
if w.statusWritten {
|
||||||
|
content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus))
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, values := range w.responseHeaders {
|
||||||
|
for _, value := range values {
|
||||||
|
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
content.WriteString("\n")
|
||||||
|
|
||||||
|
// Write buffered response body chunks
|
||||||
|
if w.bufferedChunks != nil && w.bufferedChunks.Len() > 0 {
|
||||||
|
content.Write(w.bufferedChunks.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the complete content to file
|
||||||
|
if _, err := w.file.WriteString(content.String()); err != nil {
|
||||||
|
_ = w.file.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.file.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// asyncWriter runs in a goroutine to handle async chunk writing.
|
// asyncWriter runs in a goroutine to buffer chunks from the channel.
|
||||||
// It continuously reads chunks from the channel and writes them to the file.
|
// It continuously reads chunks from the channel and buffers them for later writing.
|
||||||
func (w *FileStreamingLogWriter) asyncWriter() {
|
func (w *FileStreamingLogWriter) asyncWriter() {
|
||||||
defer close(w.closeChan)
|
defer close(w.closeChan)
|
||||||
|
|
||||||
for chunk := range w.chunkChan {
|
for chunk := range w.chunkChan {
|
||||||
if w.file != nil {
|
if w.bufferedChunks != nil {
|
||||||
_, _ = w.file.Write(chunk)
|
w.bufferedChunks.Write(chunk)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -752,6 +874,28 @@ func (w *NoOpStreamingLogWriter) WriteStatus(_ int, _ map[string][]string) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteAPIRequest is a no-op implementation that does nothing and always returns nil.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiRequest: The API request data (ignored)
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: Always returns nil
|
||||||
|
func (w *NoOpStreamingLogWriter) WriteAPIRequest(_ []byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteAPIResponse is a no-op implementation that does nothing and always returns nil.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiResponse: The API response data (ignored)
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: Always returns nil
|
||||||
|
func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Close is a no-op implementation that does nothing and always returns nil.
|
// Close is a no-op implementation that does nothing and always returns nil.
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
|
|||||||
@@ -693,8 +693,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-2025-11-12",
|
Version: "gpt-5.1-2025-11-12",
|
||||||
DisplayName: "GPT 5 Low",
|
DisplayName: "GPT 5.1 Nothink",
|
||||||
Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
Description: "Stable version of GPT 5.1, The best model for coding and agentic tasks across domains.",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -719,8 +719,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-2025-11-12",
|
Version: "gpt-5.1-2025-11-12",
|
||||||
DisplayName: "GPT 5 Medium",
|
DisplayName: "GPT 5.1 Medium",
|
||||||
Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
Description: "Stable version of GPT 5.1, The best model for coding and agentic tasks across domains.",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -732,8 +732,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-2025-11-12",
|
Version: "gpt-5.1-2025-11-12",
|
||||||
DisplayName: "GPT 5 High",
|
DisplayName: "GPT 5.1 High",
|
||||||
Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
Description: "Stable version of GPT 5.1, The best model for coding and agentic tasks across domains.",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -745,8 +745,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-2025-11-12",
|
Version: "gpt-5.1-2025-11-12",
|
||||||
DisplayName: "GPT 5 Codex",
|
DisplayName: "GPT 5.1 Codex",
|
||||||
Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
|
Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -758,8 +758,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-2025-11-12",
|
Version: "gpt-5.1-2025-11-12",
|
||||||
DisplayName: "GPT 5 Codex Low",
|
DisplayName: "GPT 5.1 Codex Low",
|
||||||
Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
|
Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -771,8 +771,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-2025-11-12",
|
Version: "gpt-5.1-2025-11-12",
|
||||||
DisplayName: "GPT 5 Codex Medium",
|
DisplayName: "GPT 5.1 Codex Medium",
|
||||||
Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
|
Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -784,8 +784,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-2025-11-12",
|
Version: "gpt-5.1-2025-11-12",
|
||||||
DisplayName: "GPT 5 Codex High",
|
DisplayName: "GPT 5.1 Codex High",
|
||||||
Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
|
Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -797,8 +797,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-2025-11-12",
|
Version: "gpt-5.1-2025-11-12",
|
||||||
DisplayName: "GPT 5 Codex Mini",
|
DisplayName: "GPT 5.1 Codex Mini",
|
||||||
Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
|
Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -810,8 +810,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-2025-11-12",
|
Version: "gpt-5.1-2025-11-12",
|
||||||
DisplayName: "GPT 5 Codex Mini Medium",
|
DisplayName: "GPT 5.1 Codex Mini Medium",
|
||||||
Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
|
Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -823,8 +823,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-2025-11-12",
|
Version: "gpt-5.1-2025-11-12",
|
||||||
DisplayName: "GPT 5 Codex Mini High",
|
DisplayName: "GPT 5.1 Codex Mini High",
|
||||||
Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
|
Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -837,8 +837,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-max",
|
Version: "gpt-5.1-max",
|
||||||
DisplayName: "GPT 5 Codex Max",
|
DisplayName: "GPT 5.1 Codex Max",
|
||||||
Description: "Stable version of GPT 5 Codex Max",
|
Description: "Stable version of GPT 5.1 Codex Max",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -850,8 +850,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-max",
|
Version: "gpt-5.1-max",
|
||||||
DisplayName: "GPT 5 Codex Max Low",
|
DisplayName: "GPT 5.1 Codex Max Low",
|
||||||
Description: "Stable version of GPT 5 Codex Max Low",
|
Description: "Stable version of GPT 5.1 Codex Max Low",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -863,8 +863,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-max",
|
Version: "gpt-5.1-max",
|
||||||
DisplayName: "GPT 5 Codex Max Medium",
|
DisplayName: "GPT 5.1 Codex Max Medium",
|
||||||
Description: "Stable version of GPT 5 Codex Max Medium",
|
Description: "Stable version of GPT 5.1 Codex Max Medium",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -876,8 +876,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-max",
|
Version: "gpt-5.1-max",
|
||||||
DisplayName: "GPT 5 Codex Max High",
|
DisplayName: "GPT 5.1 Codex Max High",
|
||||||
Description: "Stable version of GPT 5 Codex Max High",
|
Description: "Stable version of GPT 5.1 Codex Max High",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -889,8 +889,8 @@ func GetOpenAIModels() []*ModelInfo {
|
|||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Type: "openai",
|
Type: "openai",
|
||||||
Version: "gpt-5.1-max",
|
Version: "gpt-5.1-max",
|
||||||
DisplayName: "GPT 5 Codex Max XHigh",
|
DisplayName: "GPT 5.1 Codex Max XHigh",
|
||||||
Description: "Stable version of GPT 5 Codex Max XHigh",
|
Description: "Stable version of GPT 5.1 Codex Max XHigh",
|
||||||
ContextLength: 400000,
|
ContextLength: 400000,
|
||||||
MaxCompletionTokens: 128000,
|
MaxCompletionTokens: 128000,
|
||||||
SupportedParameters: []string{"tools"},
|
SupportedParameters: []string{"tools"},
|
||||||
@@ -944,7 +944,6 @@ func GetQwenModels() []*ModelInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetIFlowModels returns supported models for iFlow OAuth accounts.
|
// GetIFlowModels returns supported models for iFlow OAuth accounts.
|
||||||
|
|
||||||
func GetIFlowModels() []*ModelInfo {
|
func GetIFlowModels() []*ModelInfo {
|
||||||
entries := []struct {
|
entries := []struct {
|
||||||
ID string
|
ID string
|
||||||
@@ -987,6 +986,28 @@ func GetIFlowModels() []*ModelInfo {
|
|||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AntigravityModelConfig captures static antigravity model overrides, including
|
||||||
|
// Thinking budget limits and provider max completion tokens.
|
||||||
|
type AntigravityModelConfig struct {
|
||||||
|
Thinking *ThinkingSupport
|
||||||
|
MaxCompletionTokens int
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAntigravityModelConfig returns static configuration for antigravity models.
|
||||||
|
// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup.
|
||||||
|
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||||
|
return map[string]*AntigravityModelConfig{
|
||||||
|
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
|
||||||
|
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
|
||||||
|
"gemini-2.5-computer-use-preview-10-2025": {Name: "models/gemini-2.5-computer-use-preview-10-2025"},
|
||||||
|
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-preview"},
|
||||||
|
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-image-preview"},
|
||||||
|
"gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
|
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetGitHubCopilotModels returns the available models for GitHub Copilot.
|
// GetGitHubCopilotModels returns the available models for GitHub Copilot.
|
||||||
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
||||||
func GetGitHubCopilotModels() []*ModelInfo {
|
func GetGitHubCopilotModels() []*ModelInfo {
|
||||||
@@ -1170,3 +1191,150 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
|
||||||
|
func GetKiroModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-opus-4.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Opus 4.5",
|
||||||
|
Description: "Claude Opus 4.5 via Kiro (2.2x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-sonnet-4.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Sonnet 4.5",
|
||||||
|
Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-sonnet-4",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Sonnet 4",
|
||||||
|
Description: "Claude Sonnet 4 via Kiro (1.3x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-haiku-4.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Haiku 4.5",
|
||||||
|
Description: "Claude Haiku 4.5 via Kiro (0.4x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
// --- Chat Variant (No tool calling, for pure conversation) ---
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-opus-4.5-chat",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Opus 4.5 (Chat)",
|
||||||
|
Description: "Claude Opus 4.5 for chat only (no tool calling)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-opus-4.5-agentic",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Opus 4.5 (Agentic)",
|
||||||
|
Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-sonnet-4.5-agentic",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Sonnet 4.5 (Agentic)",
|
||||||
|
Description: "Claude Sonnet 4.5 optimized for coding agents (chunked writes)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAmazonQModels returns the Amazon Q (AWS CodeWhisperer) model definitions.
|
||||||
|
// These models use the same API as Kiro and share the same executor.
|
||||||
|
func GetAmazonQModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "amazonq-auto",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro", // Uses Kiro executor - same API
|
||||||
|
DisplayName: "Amazon Q Auto",
|
||||||
|
Description: "Automatic model selection by Amazon Q",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "amazonq-claude-opus-4.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Amazon Q Claude Opus 4.5",
|
||||||
|
Description: "Claude Opus 4.5 via Amazon Q (2.2x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "amazonq-claude-sonnet-4.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Amazon Q Claude Sonnet 4.5",
|
||||||
|
Description: "Claude Sonnet 4.5 via Amazon Q (1.3x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "amazonq-claude-sonnet-4",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Amazon Q Claude Sonnet 4",
|
||||||
|
Description: "Claude Sonnet 4 via Amazon Q (1.3x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "amazonq-claude-haiku-4.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Amazon Q Claude Haiku 4.5",
|
||||||
|
Description: "Claude Haiku 4.5 via Amazon Q (0.4x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -309,7 +309,9 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
|||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||||
payload = applyThinkingMetadata(payload, req.Metadata, req.Model)
|
payload = applyThinkingMetadata(payload, req.Metadata, req.Model)
|
||||||
|
payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload)
|
||||||
payload = util.ConvertThinkingLevelToBudget(payload)
|
payload = util.ConvertThinkingLevelToBudget(payload)
|
||||||
|
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload)
|
||||||
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
||||||
payload = fixGeminiImageAspectRatio(req.Model, payload)
|
payload = fixGeminiImageAspectRatio(req.Model, payload)
|
||||||
payload = applyPayloadConfig(e.cfg, req.Model, payload)
|
payload = applyPayloadConfig(e.cfg, req.Model, payload)
|
||||||
|
|||||||
@@ -12,11 +12,13 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
@@ -26,21 +28,24 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||||
antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com"
|
// antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com"
|
||||||
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||||
antigravityGeneratePath = "/v1internal:generateContent"
|
antigravityGeneratePath = "/v1internal:generateContent"
|
||||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||||
antigravityClientSecret = "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"
|
antigravityClientSecret = "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"
|
||||||
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
||||||
antigravityAuthType = "antigravity"
|
antigravityAuthType = "antigravity"
|
||||||
refreshSkew = 3000 * time.Second
|
refreshSkew = 3000 * time.Second
|
||||||
streamScannerBuffer int = 20_971_520
|
streamScannerBuffer int = 20_971_520
|
||||||
)
|
)
|
||||||
|
|
||||||
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
var (
|
||||||
|
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
randSourceMutex sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||||
type AntigravityExecutor struct {
|
type AntigravityExecutor struct {
|
||||||
@@ -76,6 +81,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||||
|
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||||
|
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
@@ -169,6 +176,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
|||||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||||
|
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||||
|
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||||
|
|
||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
@@ -365,28 +374,34 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
|
modelConfig := registry.GetAntigravityModelConfig()
|
||||||
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
||||||
for id := range result.Map() {
|
for originalName := range result.Map() {
|
||||||
id = modelName2Alias(id)
|
aliasName := modelName2Alias(originalName)
|
||||||
if id != "" {
|
if aliasName != "" {
|
||||||
|
cfg := modelConfig[aliasName]
|
||||||
|
modelName := aliasName
|
||||||
|
if cfg != nil && cfg.Name != "" {
|
||||||
|
modelName = cfg.Name
|
||||||
|
}
|
||||||
modelInfo := ®istry.ModelInfo{
|
modelInfo := ®istry.ModelInfo{
|
||||||
ID: id,
|
ID: aliasName,
|
||||||
Name: id,
|
Name: modelName,
|
||||||
Description: id,
|
Description: aliasName,
|
||||||
DisplayName: id,
|
DisplayName: aliasName,
|
||||||
Version: id,
|
Version: aliasName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: now,
|
Created: now,
|
||||||
OwnedBy: antigravityAuthType,
|
OwnedBy: antigravityAuthType,
|
||||||
Type: antigravityAuthType,
|
Type: antigravityAuthType,
|
||||||
}
|
}
|
||||||
// Add Thinking support for thinking models
|
// Look up Thinking support from static config using alias name
|
||||||
if strings.HasSuffix(id, "-thinking") || strings.Contains(id, "-thinking-") {
|
if cfg != nil {
|
||||||
modelInfo.Thinking = ®istry.ThinkingSupport{
|
if cfg.Thinking != nil {
|
||||||
Min: 1024,
|
modelInfo.Thinking = cfg.Thinking
|
||||||
Max: 100000,
|
}
|
||||||
ZeroAllowed: false,
|
if cfg.MaxCompletionTokens > 0 {
|
||||||
DynamicAllowed: true,
|
modelInfo.MaxCompletionTokens = cfg.MaxCompletionTokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
models = append(models, modelInfo)
|
models = append(models, modelInfo)
|
||||||
@@ -508,8 +523,49 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
|||||||
requestURL.WriteString(url.QueryEscape(alt))
|
requestURL.WriteString(url.QueryEscape(alt))
|
||||||
}
|
}
|
||||||
|
|
||||||
payload = geminiToAntigravity(modelName, payload)
|
// Extract project_id from auth metadata if available
|
||||||
|
projectID := ""
|
||||||
|
if auth != nil && auth.Metadata != nil {
|
||||||
|
if pid, ok := auth.Metadata["project_id"].(string); ok {
|
||||||
|
projectID = strings.TrimSpace(pid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
payload = geminiToAntigravity(modelName, payload, projectID)
|
||||||
payload, _ = sjson.SetBytes(payload, "model", alias2ModelName(modelName))
|
payload, _ = sjson.SetBytes(payload, "model", alias2ModelName(modelName))
|
||||||
|
|
||||||
|
if strings.Contains(modelName, "claude") {
|
||||||
|
strJSON := string(payload)
|
||||||
|
paths := make([]string, 0)
|
||||||
|
util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths)
|
||||||
|
for _, p := range paths {
|
||||||
|
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||||
|
}
|
||||||
|
|
||||||
|
strJSON = util.DeleteKey(strJSON, "$schema")
|
||||||
|
strJSON = util.DeleteKey(strJSON, "maxItems")
|
||||||
|
strJSON = util.DeleteKey(strJSON, "minItems")
|
||||||
|
strJSON = util.DeleteKey(strJSON, "minLength")
|
||||||
|
strJSON = util.DeleteKey(strJSON, "maxLength")
|
||||||
|
strJSON = util.DeleteKey(strJSON, "exclusiveMinimum")
|
||||||
|
strJSON = util.DeleteKey(strJSON, "exclusiveMaximum")
|
||||||
|
strJSON = util.DeleteKey(strJSON, "$ref")
|
||||||
|
strJSON = util.DeleteKey(strJSON, "$defs")
|
||||||
|
|
||||||
|
paths = make([]string, 0)
|
||||||
|
util.Walk(gjson.Parse(strJSON), "", "anyOf", &paths)
|
||||||
|
for _, p := range paths {
|
||||||
|
anyOf := gjson.Get(strJSON, p)
|
||||||
|
if anyOf.IsArray() {
|
||||||
|
anyOfItems := anyOf.Array()
|
||||||
|
if len(anyOfItems) > 0 {
|
||||||
|
strJSON, _ = sjson.SetRaw(strJSON, p[:len(p)-len(".anyOf")], anyOfItems[0].Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = []byte(strJSON)
|
||||||
|
}
|
||||||
|
|
||||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||||
if errReq != nil {
|
if errReq != nil {
|
||||||
return nil, errReq
|
return nil, errReq
|
||||||
@@ -609,7 +665,7 @@ func buildBaseURL(auth *cliproxyauth.Auth) string {
|
|||||||
if baseURLs := antigravityBaseURLFallbackOrder(auth); len(baseURLs) > 0 {
|
if baseURLs := antigravityBaseURLFallbackOrder(auth); len(baseURLs) > 0 {
|
||||||
return baseURLs[0]
|
return baseURLs[0]
|
||||||
}
|
}
|
||||||
return antigravityBaseURLAutopush
|
return antigravityBaseURLDaily
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveHost(base string) string {
|
func resolveHost(base string) string {
|
||||||
@@ -645,7 +701,7 @@ func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
|
|||||||
}
|
}
|
||||||
return []string{
|
return []string{
|
||||||
antigravityBaseURLDaily,
|
antigravityBaseURLDaily,
|
||||||
antigravityBaseURLAutopush,
|
// antigravityBaseURLAutopush,
|
||||||
antigravityBaseURLProd,
|
antigravityBaseURLProd,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -670,16 +726,22 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func geminiToAntigravity(modelName string, payload []byte) []byte {
|
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
|
||||||
template, _ := sjson.Set(string(payload), "model", modelName)
|
template, _ := sjson.Set(string(payload), "model", modelName)
|
||||||
template, _ = sjson.Set(template, "userAgent", "antigravity")
|
template, _ = sjson.Set(template, "userAgent", "antigravity")
|
||||||
template, _ = sjson.Set(template, "project", generateProjectID())
|
|
||||||
|
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
|
||||||
|
if projectID != "" {
|
||||||
|
template, _ = sjson.Set(template, "project", projectID)
|
||||||
|
} else {
|
||||||
|
template, _ = sjson.Set(template, "project", generateProjectID())
|
||||||
|
}
|
||||||
template, _ = sjson.Set(template, "requestId", generateRequestID())
|
template, _ = sjson.Set(template, "requestId", generateRequestID())
|
||||||
template, _ = sjson.Set(template, "request.sessionId", generateSessionID())
|
template, _ = sjson.Set(template, "request.sessionId", generateSessionID())
|
||||||
|
|
||||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||||
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||||
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
|
|
||||||
if !strings.HasPrefix(modelName, "gemini-3-") {
|
if !strings.HasPrefix(modelName, "gemini-3-") {
|
||||||
if thinkingLevel := gjson.Get(template, "request.generationConfig.thinkingConfig.thinkingLevel"); thinkingLevel.Exists() {
|
if thinkingLevel := gjson.Get(template, "request.generationConfig.thinkingConfig.thinkingLevel"); thinkingLevel.Exists() {
|
||||||
template, _ = sjson.Delete(template, "request.generationConfig.thinkingConfig.thinkingLevel")
|
template, _ = sjson.Delete(template, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||||
@@ -687,7 +749,7 @@ func geminiToAntigravity(modelName string, payload []byte) []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(modelName, "claude-sonnet-") {
|
if strings.Contains(modelName, "claude") {
|
||||||
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
|
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
|
||||||
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
|
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
|
||||||
if funcDecl.Get("parametersJsonSchema").Exists() {
|
if funcDecl.Get("parametersJsonSchema").Exists() {
|
||||||
@@ -699,6 +761,8 @@ func geminiToAntigravity(modelName string, payload []byte) []byte {
|
|||||||
})
|
})
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
} else {
|
||||||
|
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
|
||||||
}
|
}
|
||||||
|
|
||||||
return []byte(template)
|
return []byte(template)
|
||||||
@@ -709,15 +773,19 @@ func generateRequestID() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func generateSessionID() string {
|
func generateSessionID() string {
|
||||||
|
randSourceMutex.Lock()
|
||||||
n := randSource.Int63n(9_000_000_000_000_000_000)
|
n := randSource.Int63n(9_000_000_000_000_000_000)
|
||||||
|
randSourceMutex.Unlock()
|
||||||
return "-" + strconv.FormatInt(n, 10)
|
return "-" + strconv.FormatInt(n, 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateProjectID() string {
|
func generateProjectID() string {
|
||||||
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
|
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
|
||||||
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
|
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
|
||||||
|
randSourceMutex.Lock()
|
||||||
adj := adjectives[randSource.Intn(len(adjectives))]
|
adj := adjectives[randSource.Intn(len(adjectives))]
|
||||||
noun := nouns[randSource.Intn(len(nouns))]
|
noun := nouns[randSource.Intn(len(nouns))]
|
||||||
|
randSourceMutex.Unlock()
|
||||||
randomPart := strings.ToLower(uuid.NewString())[:5]
|
randomPart := strings.ToLower(uuid.NewString())[:5]
|
||||||
return adj + "-" + noun + "-" + randomPart
|
return adj + "-" + noun + "-" + randomPart
|
||||||
}
|
}
|
||||||
@@ -761,3 +829,65 @@ func alias2ModelName(modelName string) string {
|
|||||||
return modelName
|
return modelName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeAntigravityThinking clamps or removes thinking config based on model support.
|
||||||
|
// For Claude models, it additionally ensures thinking budget < max_tokens.
|
||||||
|
func normalizeAntigravityThinking(model string, payload []byte) []byte {
|
||||||
|
payload = util.StripThinkingConfigIfUnsupported(model, payload)
|
||||||
|
if !util.ModelSupportsThinking(model) {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
budget := gjson.GetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||||
|
if !budget.Exists() {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
raw := int(budget.Int())
|
||||||
|
normalized := util.NormalizeThinkingBudget(model, raw)
|
||||||
|
|
||||||
|
isClaude := strings.Contains(strings.ToLower(model), "claude")
|
||||||
|
if isClaude {
|
||||||
|
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
|
||||||
|
if effectiveMax > 0 && normalized >= effectiveMax {
|
||||||
|
normalized = effectiveMax - 1
|
||||||
|
}
|
||||||
|
minBudget := antigravityMinThinkingBudget(model)
|
||||||
|
if minBudget > 0 && normalized >= 0 && normalized < minBudget {
|
||||||
|
// Budget is below minimum, remove thinking config entirely
|
||||||
|
payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.thinkingConfig")
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
if setDefaultMax {
|
||||||
|
if res, errSet := sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", effectiveMax); errSet == nil {
|
||||||
|
payload = res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := sjson.SetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
|
||||||
|
if err != nil {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|
||||||
|
// antigravityEffectiveMaxTokens returns the max tokens to cap thinking:
|
||||||
|
// prefer request-provided maxOutputTokens; otherwise fall back to model default.
|
||||||
|
// The boolean indicates whether the value came from the model default (and thus should be written back).
|
||||||
|
func antigravityEffectiveMaxTokens(model string, payload []byte) (max int, fromModel bool) {
|
||||||
|
if maxTok := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens"); maxTok.Exists() && maxTok.Int() > 0 {
|
||||||
|
return int(maxTok.Int()), false
|
||||||
|
}
|
||||||
|
if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.MaxCompletionTokens > 0 {
|
||||||
|
return modelInfo.MaxCompletionTokens, true
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// antigravityMinThinkingBudget returns the minimum thinking budget for a model.
|
||||||
|
// Falls back to -1 if no model info is found.
|
||||||
|
func antigravityMinThinkingBudget(model string) int {
|
||||||
|
if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.Thinking != nil {
|
||||||
|
return modelInfo.Thinking.Min
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,10 +1,38 @@
|
|||||||
package executor
|
package executor
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
type codexCache struct {
|
type codexCache struct {
|
||||||
ID string
|
ID string
|
||||||
Expire time.Time
|
Expire time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
var codexCacheMap = map[string]codexCache{}
|
var (
|
||||||
|
codexCacheMap = map[string]codexCache{}
|
||||||
|
codexCacheMutex sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// getCodexCache safely retrieves a cache entry
|
||||||
|
func getCodexCache(key string) (codexCache, bool) {
|
||||||
|
codexCacheMutex.RLock()
|
||||||
|
defer codexCacheMutex.RUnlock()
|
||||||
|
cache, ok := codexCacheMap[key]
|
||||||
|
return cache, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// setCodexCache safely sets a cache entry
|
||||||
|
func setCodexCache(key string, cache codexCache) {
|
||||||
|
codexCacheMutex.Lock()
|
||||||
|
defer codexCacheMutex.Unlock()
|
||||||
|
codexCacheMap[key] = cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteCodexCache safely deletes a cache entry
|
||||||
|
func deleteCodexCache(key string) {
|
||||||
|
codexCacheMutex.Lock()
|
||||||
|
defer codexCacheMutex.Unlock()
|
||||||
|
delete(codexCacheMap, key)
|
||||||
|
}
|
||||||
|
|||||||
@@ -506,12 +506,12 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
|||||||
if userIDResult.Exists() {
|
if userIDResult.Exists() {
|
||||||
var hasKey bool
|
var hasKey bool
|
||||||
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
|
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
|
||||||
if cache, hasKey = codexCacheMap[key]; !hasKey || cache.Expire.Before(time.Now()) {
|
if cache, hasKey = getCodexCache(key); !hasKey || cache.Expire.Before(time.Now()) {
|
||||||
cache = codexCache{
|
cache = codexCache{
|
||||||
ID: uuid.New().String(),
|
ID: uuid.New().String(),
|
||||||
Expire: time.Now().Add(1 * time.Hour),
|
Expire: time.Now().Add(1 * time.Hour),
|
||||||
}
|
}
|
||||||
codexCacheMap[key] = cache
|
setCodexCache(key, cache)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if from == "openai-response" {
|
} else if from == "openai-response" {
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
to := sdktranslator.FromString("gemini-cli")
|
to := sdktranslator.FromString("gemini-cli")
|
||||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||||
|
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
|
||||||
|
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
||||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||||
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
||||||
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
|
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
|
||||||
@@ -199,6 +201,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
to := sdktranslator.FromString("gemini-cli")
|
to := sdktranslator.FromString("gemini-cli")
|
||||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||||
|
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
|
||||||
|
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
||||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||||
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
||||||
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
|
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
|
||||||
|
|||||||
@@ -80,6 +80,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||||
body = applyThinkingMetadata(body, req.Metadata, req.Model)
|
body = applyThinkingMetadata(body, req.Metadata, req.Model)
|
||||||
|
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||||
|
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||||
@@ -169,6 +171,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||||
body = applyThinkingMetadata(body, req.Metadata, req.Model)
|
body = applyThinkingMetadata(body, req.Metadata, req.Model)
|
||||||
|
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||||
|
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||||
|
|||||||
@@ -296,6 +296,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
}
|
}
|
||||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||||
}
|
}
|
||||||
|
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||||
|
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||||
@@ -391,6 +393,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
}
|
}
|
||||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||||
}
|
}
|
||||||
|
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||||
|
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||||
@@ -487,6 +491,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
}
|
}
|
||||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||||
}
|
}
|
||||||
|
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||||
|
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||||
@@ -599,6 +605,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
}
|
}
|
||||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||||
}
|
}
|
||||||
|
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||||
|
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||||
|
|||||||
2391
internal/runtime/executor/kiro_executor.go
Normal file
2391
internal/runtime/executor/kiro_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
@@ -14,11 +15,19 @@ import (
|
|||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// httpClientCache caches HTTP clients by proxy URL to enable connection reuse
|
||||||
|
var (
|
||||||
|
httpClientCache = make(map[string]*http.Client)
|
||||||
|
httpClientCacheMutex sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
|
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
|
||||||
// 1. Use auth.ProxyURL if configured (highest priority)
|
// 1. Use auth.ProxyURL if configured (highest priority)
|
||||||
// 2. Use cfg.ProxyURL if auth proxy is not configured
|
// 2. Use cfg.ProxyURL if auth proxy is not configured
|
||||||
// 3. Use RoundTripper from context if neither are configured
|
// 3. Use RoundTripper from context if neither are configured
|
||||||
//
|
//
|
||||||
|
// This function caches HTTP clients by proxy URL to enable TCP/TLS connection reuse.
|
||||||
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - ctx: The context containing optional RoundTripper
|
// - ctx: The context containing optional RoundTripper
|
||||||
// - cfg: The application configuration
|
// - cfg: The application configuration
|
||||||
@@ -28,11 +37,6 @@ import (
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - *http.Client: An HTTP client with configured proxy or transport
|
// - *http.Client: An HTTP client with configured proxy or transport
|
||||||
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||||
httpClient := &http.Client{}
|
|
||||||
if timeout > 0 {
|
|
||||||
httpClient.Timeout = timeout
|
|
||||||
}
|
|
||||||
|
|
||||||
// Priority 1: Use auth.ProxyURL if configured
|
// Priority 1: Use auth.ProxyURL if configured
|
||||||
var proxyURL string
|
var proxyURL string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
@@ -44,11 +48,39 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
|||||||
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build cache key from proxy URL (empty string for no proxy)
|
||||||
|
cacheKey := proxyURL
|
||||||
|
|
||||||
|
// Check cache first
|
||||||
|
httpClientCacheMutex.RLock()
|
||||||
|
if cachedClient, ok := httpClientCache[cacheKey]; ok {
|
||||||
|
httpClientCacheMutex.RUnlock()
|
||||||
|
// Return a wrapper with the requested timeout but shared transport
|
||||||
|
if timeout > 0 {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: cachedClient.Transport,
|
||||||
|
Timeout: timeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cachedClient
|
||||||
|
}
|
||||||
|
httpClientCacheMutex.RUnlock()
|
||||||
|
|
||||||
|
// Create new client
|
||||||
|
httpClient := &http.Client{}
|
||||||
|
if timeout > 0 {
|
||||||
|
httpClient.Timeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
// If we have a proxy URL configured, set up the transport
|
// If we have a proxy URL configured, set up the transport
|
||||||
if proxyURL != "" {
|
if proxyURL != "" {
|
||||||
transport := buildProxyTransport(proxyURL)
|
transport := buildProxyTransport(proxyURL)
|
||||||
if transport != nil {
|
if transport != nil {
|
||||||
httpClient.Transport = transport
|
httpClient.Transport = transport
|
||||||
|
// Cache the client
|
||||||
|
httpClientCacheMutex.Lock()
|
||||||
|
httpClientCache[cacheKey] = httpClient
|
||||||
|
httpClientCacheMutex.Unlock()
|
||||||
return httpClient
|
return httpClient
|
||||||
}
|
}
|
||||||
// If proxy setup failed, log and fall through to context RoundTripper
|
// If proxy setup failed, log and fall through to context RoundTripper
|
||||||
@@ -60,6 +92,13 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
|||||||
httpClient.Transport = rt
|
httpClient.Transport = rt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cache the client for no-proxy case
|
||||||
|
if proxyURL == "" {
|
||||||
|
httpClientCacheMutex.Lock()
|
||||||
|
httpClientCache[cacheKey] = httpClient
|
||||||
|
httpClientCacheMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
return httpClient
|
return httpClient
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -83,18 +83,33 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
for j := 0; j < len(contentResults); j++ {
|
for j := 0; j < len(contentResults); j++ {
|
||||||
contentResult := contentResults[j]
|
contentResult := contentResults[j]
|
||||||
contentTypeResult := contentResult.Get("type")
|
contentTypeResult := contentResult.Get("type")
|
||||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
|
||||||
|
prompt := contentResult.Get("thinking").String()
|
||||||
|
signatureResult := contentResult.Get("signature")
|
||||||
|
signature := geminiCLIClaudeThoughtSignature
|
||||||
|
if signatureResult.Exists() {
|
||||||
|
signature = signatureResult.String()
|
||||||
|
}
|
||||||
|
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt, Thought: true, ThoughtSignature: signature})
|
||||||
|
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||||
prompt := contentResult.Get("text").String()
|
prompt := contentResult.Get("text").String()
|
||||||
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
|
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
|
||||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||||
functionName := contentResult.Get("name").String()
|
functionName := contentResult.Get("name").String()
|
||||||
functionArgs := contentResult.Get("input").String()
|
functionArgs := contentResult.Get("input").String()
|
||||||
|
functionID := contentResult.Get("id").String()
|
||||||
var args map[string]any
|
var args map[string]any
|
||||||
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||||
clientContent.Parts = append(clientContent.Parts, client.Part{
|
if strings.Contains(modelName, "claude") {
|
||||||
FunctionCall: &client.FunctionCall{Name: functionName, Args: args},
|
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||||
ThoughtSignature: geminiCLIClaudeThoughtSignature,
|
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
|
||||||
})
|
})
|
||||||
|
} else {
|
||||||
|
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||||
|
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
|
||||||
|
ThoughtSignature: geminiCLIClaudeThoughtSignature,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||||
toolCallID := contentResult.Get("tool_use_id").String()
|
toolCallID := contentResult.Get("tool_use_id").String()
|
||||||
@@ -105,9 +120,18 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||||
}
|
}
|
||||||
responseData := contentResult.Get("content").Raw
|
responseData := contentResult.Get("content").Raw
|
||||||
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
|
functionResponse := client.FunctionResponse{ID: toolCallID, Name: funcName, Response: map[string]interface{}{"result": responseData}}
|
||||||
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
|
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
|
||||||
}
|
}
|
||||||
|
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" {
|
||||||
|
sourceResult := contentResult.Get("source")
|
||||||
|
if sourceResult.Get("type").String() == "base64" {
|
||||||
|
inlineData := &client.InlineData{
|
||||||
|
MimeType: sourceResult.Get("media_type").String(),
|
||||||
|
Data: sourceResult.Get("data").String(),
|
||||||
|
}
|
||||||
|
clientContent.Parts = append(clientContent.Parts, client.Part{InlineData: inlineData})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
contents = append(contents, clientContent)
|
contents = append(contents, clientContent)
|
||||||
@@ -165,7 +189,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if t.Get("type").String() == "enabled" {
|
if t.Get("type").String() == "enabled" {
|
||||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||||
budget := int(b.Int())
|
budget := int(b.Int())
|
||||||
budget = util.NormalizeThinkingBudget(modelName, budget)
|
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
}
|
}
|
||||||
@@ -180,6 +203,9 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
|
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num)
|
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num)
|
||||||
}
|
}
|
||||||
|
if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() && v.Type == gjson.Number {
|
||||||
|
out, _ = sjson.Set(out, "request.generationConfig.maxOutputTokens", v.Num)
|
||||||
|
}
|
||||||
|
|
||||||
outBytes := []byte(out)
|
outBytes := []byte(out)
|
||||||
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")
|
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -36,6 +37,9 @@ type Params struct {
|
|||||||
HasToolUse bool // Indicates if tool use was observed in the stream
|
HasToolUse bool // Indicates if tool use was observed in the stream
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
||||||
|
var toolUseIDCounter uint64
|
||||||
|
|
||||||
// ConvertAntigravityResponseToClaude performs sophisticated streaming response format conversion.
|
// ConvertAntigravityResponseToClaude performs sophisticated streaming response format conversion.
|
||||||
// This function implements a complex state machine that translates backend client responses
|
// This function implements a complex state machine that translates backend client responses
|
||||||
// into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types
|
// into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types
|
||||||
@@ -111,8 +115,11 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
if partTextResult.Exists() {
|
if partTextResult.Exists() {
|
||||||
// Process thinking content (internal reasoning)
|
// Process thinking content (internal reasoning)
|
||||||
if partResult.Get("thought").Bool() {
|
if partResult.Get("thought").Bool() {
|
||||||
// Continue existing thinking block if already in thinking state
|
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
|
||||||
if params.ResponseType == 2 {
|
output = output + "event: content_block_delta\n"
|
||||||
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String())
|
||||||
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
|
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
|
||||||
output = output + "event: content_block_delta\n"
|
output = output + "event: content_block_delta\n"
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
@@ -141,35 +148,39 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
params.ResponseType = 2 // Set state to thinking
|
params.ResponseType = 2 // Set state to thinking
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Process regular text content (user-visible output)
|
finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason")
|
||||||
// Continue existing text block if already in content state
|
if partTextResult.String() != "" || !finishReasonResult.Exists() {
|
||||||
if params.ResponseType == 1 {
|
// Process regular text content (user-visible output)
|
||||||
output = output + "event: content_block_delta\n"
|
// Continue existing text block if already in content state
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
|
if params.ResponseType == 1 {
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
output = output + "event: content_block_delta\n"
|
||||||
} else {
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
|
||||||
// Transition from another state to text content
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
// First, close any existing content block
|
} else {
|
||||||
if params.ResponseType != 0 {
|
// Transition from another state to text content
|
||||||
if params.ResponseType == 2 {
|
// First, close any existing content block
|
||||||
// output = output + "event: content_block_delta\n"
|
if params.ResponseType != 0 {
|
||||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
|
if params.ResponseType == 2 {
|
||||||
// output = output + "\n\n\n"
|
// output = output + "event: content_block_delta\n"
|
||||||
|
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
|
||||||
|
// output = output + "\n\n\n"
|
||||||
|
}
|
||||||
|
output = output + "event: content_block_stop\n"
|
||||||
|
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
||||||
|
output = output + "\n\n\n"
|
||||||
|
params.ResponseIndex++
|
||||||
|
}
|
||||||
|
if partTextResult.String() != "" {
|
||||||
|
// Start a new text content block
|
||||||
|
output = output + "event: content_block_start\n"
|
||||||
|
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
|
||||||
|
output = output + "\n\n\n"
|
||||||
|
output = output + "event: content_block_delta\n"
|
||||||
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
|
||||||
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
|
params.ResponseType = 1 // Set state to content
|
||||||
}
|
}
|
||||||
output = output + "event: content_block_stop\n"
|
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
|
||||||
output = output + "\n\n\n"
|
|
||||||
params.ResponseIndex++
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a new text content block
|
|
||||||
output = output + "event: content_block_start\n"
|
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
|
|
||||||
output = output + "\n\n\n"
|
|
||||||
output = output + "event: content_block_delta\n"
|
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
|
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
|
||||||
params.ResponseType = 1 // Set state to content
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if functionCallResult.Exists() {
|
} else if functionCallResult.Exists() {
|
||||||
@@ -209,7 +220,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
|
|
||||||
// Create the tool use block with unique ID and function details
|
// Create the tool use block with unique ID and function details
|
||||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
|
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
|
||||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
|
||||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
|
|
||||||
|
|||||||
@@ -48,13 +48,13 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
case "low":
|
case "low":
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024))
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
case "medium":
|
case "medium":
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192))
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
case "high":
|
case "high":
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768))
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
default:
|
default:
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
@@ -66,15 +66,15 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||||
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||||
var setBudget bool
|
var setBudget bool
|
||||||
var normalized int
|
var budget int
|
||||||
|
|
||||||
if v := tc.Get("thinkingBudget"); v.Exists() {
|
if v := tc.Get("thinkingBudget"); v.Exists() {
|
||||||
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
|
budget = int(v.Int())
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
setBudget = true
|
setBudget = true
|
||||||
} else if v := tc.Get("thinking_budget"); v.Exists() {
|
} else if v := tc.Get("thinking_budget"); v.Exists() {
|
||||||
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
|
budget = int(v.Int())
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
setBudget = true
|
setBudget = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,22 +82,27 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
|
||||||
} else if v := tc.Get("include_thoughts"); v.Exists() {
|
} else if v := tc.Get("include_thoughts"); v.Exists() {
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
|
||||||
} else if setBudget && normalized != 0 {
|
} else if setBudget && budget != 0 {
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// For gemini-3-pro-preview, always send default thinkingConfig when none specified.
|
// Claude/Anthropic API format: thinking.type == "enabled" with budget_tokens
|
||||||
// This matches the official Gemini CLI behavior which always sends:
|
// This allows Claude Code and other Claude API clients to pass thinking configuration
|
||||||
// { thinkingBudget: -1, includeThoughts: true }
|
if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && util.ModelSupportsThinking(modelName) {
|
||||||
// See: ai-gemini-cli/packages/core/src/config/defaultModelConfigs.ts
|
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
||||||
if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && modelName == "gemini-3-pro-preview" {
|
if t.Get("type").String() == "enabled" {
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
budget := int(b.Int())
|
||||||
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Temperature/top_p/top_k
|
// Temperature/top_p/top_k/max_tokens
|
||||||
if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number {
|
if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number {
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num)
|
||||||
}
|
}
|
||||||
@@ -107,6 +112,9 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number {
|
if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number {
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num)
|
||||||
}
|
}
|
||||||
|
if maxTok := gjson.GetBytes(rawJSON, "max_tokens"); maxTok.Exists() && maxTok.Type == gjson.Number {
|
||||||
|
out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num)
|
||||||
|
}
|
||||||
|
|
||||||
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
|
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
|
||||||
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
|
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
|
||||||
@@ -251,6 +259,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
fid := tc.Get("id").String()
|
fid := tc.Get("id").String()
|
||||||
fname := tc.Get("function.name").String()
|
fname := tc.Get("function.name").String()
|
||||||
fargs := tc.Get("function.arguments").String()
|
fargs := tc.Get("function.arguments").String()
|
||||||
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||||
@@ -262,10 +271,11 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||||
|
|
||||||
// Append a single tool content combining name + response per function
|
// Append a single tool content combining name + response per function
|
||||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
toolNode := []byte(`{"role":"user","parts":[]}`)
|
||||||
pp := 0
|
pp := 0
|
||||||
for _, fid := range fIDs {
|
for _, fid := range fIDs {
|
||||||
if name, ok := tcID2Name[fid]; ok {
|
if name, ok := tcID2Name[fid]; ok {
|
||||||
|
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
||||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||||
resp := toolResponses[fid]
|
resp := toolResponses[fid]
|
||||||
if resp == "" {
|
if resp == "" {
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||||
@@ -23,6 +25,9 @@ type convertCliResponseToOpenAIChatParams struct {
|
|||||||
FunctionIndex int
|
FunctionIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||||
|
var functionCallIDCounter uint64
|
||||||
|
|
||||||
// ConvertAntigravityResponseToOpenAI translates a single chunk of a streaming response from the
|
// ConvertAntigravityResponseToOpenAI translates a single chunk of a streaming response from the
|
||||||
// Gemini CLI API format to the OpenAI Chat Completions streaming format.
|
// Gemini CLI API format to the OpenAI Chat Completions streaming format.
|
||||||
// It processes various Gemini CLI event types and transforms them into OpenAI-compatible JSON responses.
|
// It processes various Gemini CLI event types and transforms them into OpenAI-compatible JSON responses.
|
||||||
@@ -75,8 +80,8 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
|
|
||||||
// Extract and set the finish reason.
|
// Extract and set the finish reason.
|
||||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set usage metadata (token counts).
|
// Extract and set usage metadata (token counts).
|
||||||
@@ -145,7 +150,7 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
|
|
||||||
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
|
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
|
||||||
fcName := functionCallResult.Get("name").String()
|
fcName := functionCallResult.Get("name").String()
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
||||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
|
|||||||
@@ -331,8 +331,9 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
|
|||||||
streamingEvents := make([][]byte, 0)
|
streamingEvents := make([][]byte, 0)
|
||||||
|
|
||||||
scanner := bufio.NewScanner(bytes.NewReader(rawJSON))
|
scanner := bufio.NewScanner(bytes.NewReader(rawJSON))
|
||||||
buffer := make([]byte, 20_971_520)
|
// Use a smaller initial buffer (64KB) that can grow up to 20MB if needed
|
||||||
scanner.Buffer(buffer, 20_971_520)
|
// This prevents allocating 20MB for every request regardless of size
|
||||||
|
scanner.Buffer(make([]byte, 64*1024), 20_971_520)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
// log.Debug(string(line))
|
// log.Debug(string(line))
|
||||||
|
|||||||
@@ -50,6 +50,10 @@ type ToolCallAccumulator struct {
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
||||||
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||||
|
var localParam any
|
||||||
|
if param == nil {
|
||||||
|
param = &localParam
|
||||||
|
}
|
||||||
if *param == nil {
|
if *param == nil {
|
||||||
*param = &ConvertAnthropicResponseToOpenAIParams{
|
*param = &ConvertAnthropicResponseToOpenAIParams{
|
||||||
CreatedAt: 0,
|
CreatedAt: 0,
|
||||||
|
|||||||
@@ -327,7 +327,7 @@ func buildReverseMapFromGeminiOriginal(original []byte) map[string]string {
|
|||||||
func mustMarshalJSON(v interface{}) string {
|
func mustMarshalJSON(v interface{}) string {
|
||||||
data, err := json.Marshal(v)
|
data, err := json.Marshal(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
return ""
|
||||||
}
|
}
|
||||||
return string(data)
|
return string(data)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -165,7 +165,6 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
|||||||
if t.Get("type").String() == "enabled" {
|
if t.Get("type").String() == "enabled" {
|
||||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||||
budget := int(b.Int())
|
budget := int(b.Int())
|
||||||
budget = util.NormalizeThinkingBudget(modelName, budget)
|
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -27,6 +28,9 @@ type Params struct {
|
|||||||
ResponseIndex int // Index counter for content blocks in the streaming response
|
ResponseIndex int // Index counter for content blocks in the streaming response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
||||||
|
var toolUseIDCounter uint64
|
||||||
|
|
||||||
// ConvertGeminiCLIResponseToClaude performs sophisticated streaming response format conversion.
|
// ConvertGeminiCLIResponseToClaude performs sophisticated streaming response format conversion.
|
||||||
// This function implements a complex state machine that translates backend client responses
|
// This function implements a complex state machine that translates backend client responses
|
||||||
// into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types
|
// into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types
|
||||||
@@ -60,12 +64,12 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
|||||||
|
|
||||||
// Track whether tools are being used in this response chunk
|
// Track whether tools are being used in this response chunk
|
||||||
usedTool := false
|
usedTool := false
|
||||||
output := ""
|
var sb strings.Builder
|
||||||
|
|
||||||
// Initialize the streaming session with a message_start event
|
// Initialize the streaming session with a message_start event
|
||||||
// This is only sent for the very first response chunk to establish the streaming session
|
// This is only sent for the very first response chunk to establish the streaming session
|
||||||
if !(*param).(*Params).HasFirstResponse {
|
if !(*param).(*Params).HasFirstResponse {
|
||||||
output = "event: message_start\n"
|
sb.WriteString("event: message_start\n")
|
||||||
|
|
||||||
// Create the initial message structure with default values according to Claude Code API specification
|
// Create the initial message structure with default values according to Claude Code API specification
|
||||||
// This follows the Claude Code API specification for streaming message initialization
|
// This follows the Claude Code API specification for streaming message initialization
|
||||||
@@ -78,7 +82,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
|||||||
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
||||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
|
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
|
||||||
}
|
}
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
|
sb.WriteString(fmt.Sprintf("data: %s\n\n\n", messageStartTemplate))
|
||||||
|
|
||||||
(*param).(*Params).HasFirstResponse = true
|
(*param).(*Params).HasFirstResponse = true
|
||||||
}
|
}
|
||||||
@@ -101,62 +105,52 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
|||||||
if partResult.Get("thought").Bool() {
|
if partResult.Get("thought").Bool() {
|
||||||
// Continue existing thinking block if already in thinking state
|
// Continue existing thinking block if already in thinking state
|
||||||
if (*param).(*Params).ResponseType == 2 {
|
if (*param).(*Params).ResponseType == 2 {
|
||||||
output = output + "event: content_block_delta\n"
|
sb.WriteString("event: content_block_delta\n")
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String())
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String())
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data))
|
||||||
} else {
|
} else {
|
||||||
// Transition from another state to thinking
|
// Transition from another state to thinking
|
||||||
// First, close any existing content block
|
// First, close any existing content block
|
||||||
if (*param).(*Params).ResponseType != 0 {
|
if (*param).(*Params).ResponseType != 0 {
|
||||||
if (*param).(*Params).ResponseType == 2 {
|
sb.WriteString("event: content_block_stop\n")
|
||||||
// output = output + "event: content_block_delta\n"
|
sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
|
||||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex)
|
sb.WriteString("\n\n\n")
|
||||||
// output = output + "\n\n\n"
|
|
||||||
}
|
|
||||||
output = output + "event: content_block_stop\n"
|
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
|
|
||||||
output = output + "\n\n\n"
|
|
||||||
(*param).(*Params).ResponseIndex++
|
(*param).(*Params).ResponseIndex++
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a new thinking content block
|
// Start a new thinking content block
|
||||||
output = output + "event: content_block_start\n"
|
sb.WriteString("event: content_block_start\n")
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex)
|
sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, (*param).(*Params).ResponseIndex))
|
||||||
output = output + "\n\n\n"
|
sb.WriteString("\n\n\n")
|
||||||
output = output + "event: content_block_delta\n"
|
sb.WriteString("event: content_block_delta\n")
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String())
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, (*param).(*Params).ResponseIndex), "delta.thinking", partTextResult.String())
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data))
|
||||||
(*param).(*Params).ResponseType = 2 // Set state to thinking
|
(*param).(*Params).ResponseType = 2 // Set state to thinking
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Process regular text content (user-visible output)
|
// Process regular text content (user-visible output)
|
||||||
// Continue existing text block if already in content state
|
// Continue existing text block if already in content state
|
||||||
if (*param).(*Params).ResponseType == 1 {
|
if (*param).(*Params).ResponseType == 1 {
|
||||||
output = output + "event: content_block_delta\n"
|
sb.WriteString("event: content_block_delta\n")
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String())
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String())
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data))
|
||||||
} else {
|
} else {
|
||||||
// Transition from another state to text content
|
// Transition from another state to text content
|
||||||
// First, close any existing content block
|
// First, close any existing content block
|
||||||
if (*param).(*Params).ResponseType != 0 {
|
if (*param).(*Params).ResponseType != 0 {
|
||||||
if (*param).(*Params).ResponseType == 2 {
|
sb.WriteString("event: content_block_stop\n")
|
||||||
// output = output + "event: content_block_delta\n"
|
sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
|
||||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex)
|
sb.WriteString("\n\n\n")
|
||||||
// output = output + "\n\n\n"
|
|
||||||
}
|
|
||||||
output = output + "event: content_block_stop\n"
|
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
|
|
||||||
output = output + "\n\n\n"
|
|
||||||
(*param).(*Params).ResponseIndex++
|
(*param).(*Params).ResponseIndex++
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a new text content block
|
// Start a new text content block
|
||||||
output = output + "event: content_block_start\n"
|
sb.WriteString("event: content_block_start\n")
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex)
|
sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, (*param).(*Params).ResponseIndex))
|
||||||
output = output + "\n\n\n"
|
sb.WriteString("\n\n\n")
|
||||||
output = output + "event: content_block_delta\n"
|
sb.WriteString("event: content_block_delta\n")
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String())
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, (*param).(*Params).ResponseIndex), "delta.text", partTextResult.String())
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data))
|
||||||
(*param).(*Params).ResponseType = 1 // Set state to content
|
(*param).(*Params).ResponseType = 1 // Set state to content
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -169,42 +163,35 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
|||||||
// Handle state transitions when switching to function calls
|
// Handle state transitions when switching to function calls
|
||||||
// Close any existing function call block first
|
// Close any existing function call block first
|
||||||
if (*param).(*Params).ResponseType == 3 {
|
if (*param).(*Params).ResponseType == 3 {
|
||||||
output = output + "event: content_block_stop\n"
|
sb.WriteString("event: content_block_stop\n")
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
|
sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
|
||||||
output = output + "\n\n\n"
|
sb.WriteString("\n\n\n")
|
||||||
(*param).(*Params).ResponseIndex++
|
(*param).(*Params).ResponseIndex++
|
||||||
(*param).(*Params).ResponseType = 0
|
(*param).(*Params).ResponseType = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Special handling for thinking state transition
|
|
||||||
if (*param).(*Params).ResponseType == 2 {
|
|
||||||
// output = output + "event: content_block_delta\n"
|
|
||||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, (*param).(*Params).ResponseIndex)
|
|
||||||
// output = output + "\n\n\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close any other existing content block
|
// Close any other existing content block
|
||||||
if (*param).(*Params).ResponseType != 0 {
|
if (*param).(*Params).ResponseType != 0 {
|
||||||
output = output + "event: content_block_stop\n"
|
sb.WriteString("event: content_block_stop\n")
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
|
sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
|
||||||
output = output + "\n\n\n"
|
sb.WriteString("\n\n\n")
|
||||||
(*param).(*Params).ResponseIndex++
|
(*param).(*Params).ResponseIndex++
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a new tool use content block
|
// Start a new tool use content block
|
||||||
// This creates the structure for a function call in Claude Code format
|
// This creates the structure for a function call in Claude Code format
|
||||||
output = output + "event: content_block_start\n"
|
sb.WriteString("event: content_block_start\n")
|
||||||
|
|
||||||
// Create the tool use block with unique ID and function details
|
// Create the tool use block with unique ID and function details
|
||||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
|
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
|
||||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
|
||||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data))
|
||||||
|
|
||||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
output = output + "event: content_block_delta\n"
|
sb.WriteString("event: content_block_delta\n")
|
||||||
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
|
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
sb.WriteString(fmt.Sprintf("data: %s\n\n\n", data))
|
||||||
}
|
}
|
||||||
(*param).(*Params).ResponseType = 3
|
(*param).(*Params).ResponseType = 3
|
||||||
}
|
}
|
||||||
@@ -216,13 +203,13 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
|||||||
if usageResult.Exists() && bytes.Contains(rawJSON, []byte(`"finishReason"`)) {
|
if usageResult.Exists() && bytes.Contains(rawJSON, []byte(`"finishReason"`)) {
|
||||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||||
// Close the final content block
|
// Close the final content block
|
||||||
output = output + "event: content_block_stop\n"
|
sb.WriteString("event: content_block_stop\n")
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex)
|
sb.WriteString(fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, (*param).(*Params).ResponseIndex))
|
||||||
output = output + "\n\n\n"
|
sb.WriteString("\n\n\n")
|
||||||
|
|
||||||
// Send the final message delta with usage information and stop reason
|
// Send the final message delta with usage information and stop reason
|
||||||
output = output + "event: message_delta\n"
|
sb.WriteString("event: message_delta\n")
|
||||||
output = output + `data: `
|
sb.WriteString(`data: `)
|
||||||
|
|
||||||
// Create the message delta template with appropriate stop reason
|
// Create the message delta template with appropriate stop reason
|
||||||
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||||
@@ -236,11 +223,11 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
|||||||
template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
|
||||||
template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
|
template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
|
||||||
|
|
||||||
output = output + template + "\n\n\n"
|
sb.WriteString(template + "\n\n\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{output}
|
return []string{sb.String()}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertGeminiCLIResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response.
|
// ConvertGeminiCLIResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response.
|
||||||
|
|||||||
@@ -48,13 +48,13 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
|||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
case "low":
|
case "low":
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024))
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
case "medium":
|
case "medium":
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192))
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
case "high":
|
case "high":
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768))
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
default:
|
default:
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
@@ -66,15 +66,15 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
|||||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||||
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||||
var setBudget bool
|
var setBudget bool
|
||||||
var normalized int
|
var budget int
|
||||||
|
|
||||||
if v := tc.Get("thinkingBudget"); v.Exists() {
|
if v := tc.Get("thinkingBudget"); v.Exists() {
|
||||||
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
|
budget = int(v.Int())
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
setBudget = true
|
setBudget = true
|
||||||
} else if v := tc.Get("thinking_budget"); v.Exists() {
|
} else if v := tc.Get("thinking_budget"); v.Exists() {
|
||||||
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
|
budget = int(v.Int())
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
setBudget = true
|
setBudget = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,21 +82,12 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
|||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
|
||||||
} else if v := tc.Get("include_thoughts"); v.Exists() {
|
} else if v := tc.Get("include_thoughts"); v.Exists() {
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", v.Bool())
|
||||||
} else if setBudget && normalized != 0 {
|
} else if setBudget && budget != 0 {
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// For gemini-3-pro-preview, always send default thinkingConfig when none specified.
|
|
||||||
// This matches the official Gemini CLI behavior which always sends:
|
|
||||||
// { thinkingBudget: -1, includeThoughts: true }
|
|
||||||
// See: ai-gemini-cli/packages/core/src/config/defaultModelConfigs.ts
|
|
||||||
if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && modelName == "gemini-3-pro-preview" {
|
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Temperature/top_p/top_k
|
// Temperature/top_p/top_k
|
||||||
if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number {
|
if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number {
|
||||||
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num)
|
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num)
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||||
@@ -23,6 +25,9 @@ type convertCliResponseToOpenAIChatParams struct {
|
|||||||
FunctionIndex int
|
FunctionIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||||
|
var functionCallIDCounter uint64
|
||||||
|
|
||||||
// ConvertCliResponseToOpenAI translates a single chunk of a streaming response from the
|
// ConvertCliResponseToOpenAI translates a single chunk of a streaming response from the
|
||||||
// Gemini CLI API format to the OpenAI Chat Completions streaming format.
|
// Gemini CLI API format to the OpenAI Chat Completions streaming format.
|
||||||
// It processes various Gemini CLI event types and transforms them into OpenAI-compatible JSON responses.
|
// It processes various Gemini CLI event types and transforms them into OpenAI-compatible JSON responses.
|
||||||
@@ -75,8 +80,8 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
|||||||
|
|
||||||
// Extract and set the finish reason.
|
// Extract and set the finish reason.
|
||||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set usage metadata (token counts).
|
// Extract and set usage metadata (token counts).
|
||||||
@@ -145,7 +150,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
|||||||
|
|
||||||
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
|
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
|
||||||
fcName := functionCallResult.Get("name").String()
|
fcName := functionCallResult.Get("name").String()
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
||||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
|
|||||||
@@ -158,7 +158,6 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
if t.Get("type").String() == "enabled" {
|
if t.Get("type").String() == "enabled" {
|
||||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||||
budget := int(b.Int())
|
budget := int(b.Int())
|
||||||
budget = util.NormalizeThinkingBudget(modelName, budget)
|
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -26,6 +27,9 @@ type Params struct {
|
|||||||
ResponseIndex int
|
ResponseIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
||||||
|
var toolUseIDCounter uint64
|
||||||
|
|
||||||
// ConvertGeminiResponseToClaude performs sophisticated streaming response format conversion.
|
// ConvertGeminiResponseToClaude performs sophisticated streaming response format conversion.
|
||||||
// This function implements a complex state machine that translates backend client responses
|
// This function implements a complex state machine that translates backend client responses
|
||||||
// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types
|
// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types
|
||||||
@@ -197,7 +201,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
|||||||
|
|
||||||
// Create the tool use block with unique ID and function details
|
// Create the tool use block with unique ID and function details
|
||||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
|
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
|
||||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
|
||||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
|
|
||||||
|
|||||||
@@ -48,13 +48,13 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
case "low":
|
case "low":
|
||||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024))
|
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
case "medium":
|
case "medium":
|
||||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192))
|
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
case "high":
|
case "high":
|
||||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768))
|
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
default:
|
default:
|
||||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
@@ -66,15 +66,15 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||||
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||||
var setBudget bool
|
var setBudget bool
|
||||||
var normalized int
|
var budget int
|
||||||
|
|
||||||
if v := tc.Get("thinkingBudget"); v.Exists() {
|
if v := tc.Get("thinkingBudget"); v.Exists() {
|
||||||
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
|
budget = int(v.Int())
|
||||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", normalized)
|
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
setBudget = true
|
setBudget = true
|
||||||
} else if v := tc.Get("thinking_budget"); v.Exists() {
|
} else if v := tc.Get("thinking_budget"); v.Exists() {
|
||||||
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
|
budget = int(v.Int())
|
||||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", normalized)
|
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
setBudget = true
|
setBudget = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,7 +82,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool())
|
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool())
|
||||||
} else if v := tc.Get("include_thoughts"); v.Exists() {
|
} else if v := tc.Get("include_thoughts"); v.Exists() {
|
||||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool())
|
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool())
|
||||||
} else if setBudget && normalized != 0 {
|
} else if setBudget && budget != 0 {
|
||||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -22,6 +24,9 @@ type convertGeminiResponseToOpenAIChatParams struct {
|
|||||||
FunctionIndex int
|
FunctionIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||||
|
var functionCallIDCounter uint64
|
||||||
|
|
||||||
// ConvertGeminiResponseToOpenAI translates a single chunk of a streaming response from the
|
// ConvertGeminiResponseToOpenAI translates a single chunk of a streaming response from the
|
||||||
// Gemini API format to the OpenAI Chat Completions streaming format.
|
// Gemini API format to the OpenAI Chat Completions streaming format.
|
||||||
// It processes various Gemini event types and transforms them into OpenAI-compatible JSON responses.
|
// It processes various Gemini event types and transforms them into OpenAI-compatible JSON responses.
|
||||||
@@ -78,8 +83,8 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
|||||||
|
|
||||||
// Extract and set the finish reason.
|
// Extract and set the finish reason.
|
||||||
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
|
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set usage metadata (token counts).
|
// Extract and set usage metadata (token counts).
|
||||||
@@ -147,7 +152,7 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
|||||||
|
|
||||||
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
|
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
|
||||||
fcName := functionCallResult.Get("name").String()
|
fcName := functionCallResult.Get("name").String()
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
||||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
@@ -230,8 +235,8 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
}
|
}
|
||||||
|
|
||||||
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
|
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
|
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
|
||||||
@@ -280,7 +285,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
}
|
}
|
||||||
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||||
fcName := functionCallResult.Get("name").String()
|
fcName := functionCallResult.Get("name").String()
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
|
||||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
||||||
|
|||||||
@@ -249,6 +249,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
|||||||
functionCall := `{"functionCall":{"name":"","args":{}}}`
|
functionCall := `{"functionCall":{"name":"","args":{}}}`
|
||||||
functionCall, _ = sjson.Set(functionCall, "functionCall.name", name)
|
functionCall, _ = sjson.Set(functionCall, "functionCall.name", name)
|
||||||
functionCall, _ = sjson.Set(functionCall, "thoughtSignature", geminiResponsesThoughtSignature)
|
functionCall, _ = sjson.Set(functionCall, "thoughtSignature", geminiResponsesThoughtSignature)
|
||||||
|
functionCall, _ = sjson.Set(functionCall, "functionCall.id", item.Get("call_id").String())
|
||||||
|
|
||||||
// Parse arguments JSON string and set as args object
|
// Parse arguments JSON string and set as args object
|
||||||
if arguments != "" {
|
if arguments != "" {
|
||||||
@@ -285,6 +286,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
|||||||
}
|
}
|
||||||
|
|
||||||
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.name", functionName)
|
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.name", functionName)
|
||||||
|
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.id", callID)
|
||||||
|
|
||||||
// Set the raw JSON output directly (preserves string encoding)
|
// Set the raw JSON output directly (preserves string encoding)
|
||||||
if outputRaw != "" && outputRaw != "null" {
|
if outputRaw != "" && outputRaw != "null" {
|
||||||
@@ -398,16 +400,16 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
|||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
case "minimal":
|
case "minimal":
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 1024))
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
case "low":
|
case "low":
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 4096))
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 4096)
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
case "medium":
|
case "medium":
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 8192))
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
case "high":
|
case "high":
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", util.NormalizeThinkingBudget(modelName, 32768))
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
default:
|
default:
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
@@ -419,32 +421,22 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
|||||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||||
if tc := root.Get("extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
if tc := root.Get("extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||||
var setBudget bool
|
var setBudget bool
|
||||||
var normalized int
|
var budget int
|
||||||
if v := tc.Get("thinking_budget"); v.Exists() {
|
if v := tc.Get("thinking_budget"); v.Exists() {
|
||||||
normalized = util.NormalizeThinkingBudget(modelName, int(v.Int()))
|
budget = int(v.Int())
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", normalized)
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
setBudget = true
|
setBudget = true
|
||||||
}
|
}
|
||||||
if v := tc.Get("include_thoughts"); v.Exists() {
|
if v := tc.Get("include_thoughts"); v.Exists() {
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool())
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", v.Bool())
|
||||||
} else if setBudget {
|
} else if setBudget {
|
||||||
if normalized != 0 {
|
if budget != 0 {
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// For gemini-3-pro-preview, always send default thinkingConfig when none specified.
|
|
||||||
// This matches the official Gemini CLI behavior which always sends:
|
|
||||||
// { thinkingBudget: -1, includeThoughts: true }
|
|
||||||
// See: ai-gemini-cli/packages/core/src/config/defaultModelConfigs.ts
|
|
||||||
if !gjson.Get(out, "generationConfig.thinkingConfig").Exists() && modelName == "gemini-3-pro-preview" {
|
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
|
||||||
// log.Debugf("Applied default thinkingConfig for gemini-3-pro-preview (matches Gemini CLI): thinkingBudget=-1, include_thoughts=true")
|
|
||||||
}
|
|
||||||
|
|
||||||
result := []byte(out)
|
result := []byte(out)
|
||||||
result = common.AttachDefaultSafetySettings(result, "safetySettings")
|
result = common.AttachDefaultSafetySettings(result, "safetySettings")
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -37,6 +38,12 @@ type geminiToResponsesState struct {
|
|||||||
FuncCallIDs map[int]string
|
FuncCallIDs map[int]string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// responseIDCounter provides a process-wide unique counter for synthesized response identifiers.
|
||||||
|
var responseIDCounter uint64
|
||||||
|
|
||||||
|
// funcCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||||
|
var funcCallIDCounter uint64
|
||||||
|
|
||||||
func emitEvent(event string, payload string) string {
|
func emitEvent(event string, payload string) string {
|
||||||
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
|
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
|
||||||
}
|
}
|
||||||
@@ -205,7 +212,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
st.FuncArgsBuf[idx] = &strings.Builder{}
|
st.FuncArgsBuf[idx] = &strings.Builder{}
|
||||||
}
|
}
|
||||||
if st.FuncCallIDs[idx] == "" {
|
if st.FuncCallIDs[idx] == "" {
|
||||||
st.FuncCallIDs[idx] = fmt.Sprintf("call_%d", time.Now().UnixNano())
|
st.FuncCallIDs[idx] = fmt.Sprintf("call_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1))
|
||||||
}
|
}
|
||||||
st.FuncNames[idx] = name
|
st.FuncNames[idx] = name
|
||||||
|
|
||||||
@@ -464,7 +471,7 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
|||||||
// id: prefer provider responseId, otherwise synthesize
|
// id: prefer provider responseId, otherwise synthesize
|
||||||
id := root.Get("responseId").String()
|
id := root.Get("responseId").String()
|
||||||
if id == "" {
|
if id == "" {
|
||||||
id = fmt.Sprintf("resp_%x", time.Now().UnixNano())
|
id = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1))
|
||||||
}
|
}
|
||||||
// Normalize to response-style id (prefix resp_ if missing)
|
// Normalize to response-style id (prefix resp_ if missing)
|
||||||
if !strings.HasPrefix(id, "resp_") {
|
if !strings.HasPrefix(id, "resp_") {
|
||||||
@@ -575,7 +582,7 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
|||||||
if fc := p.Get("functionCall"); fc.Exists() {
|
if fc := p.Get("functionCall"); fc.Exists() {
|
||||||
name := fc.Get("name").String()
|
name := fc.Get("name").String()
|
||||||
args := fc.Get("args")
|
args := fc.Get("args")
|
||||||
callID := fmt.Sprintf("call_%x", time.Now().UnixNano())
|
callID := fmt.Sprintf("call_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1))
|
||||||
outputs = append(outputs, map[string]interface{}{
|
outputs = append(outputs, map[string]interface{}{
|
||||||
"id": fmt.Sprintf("fc_%s", callID),
|
"id": fmt.Sprintf("fc_%s", callID),
|
||||||
"type": "function_call",
|
"type": "function_call",
|
||||||
|
|||||||
@@ -33,4 +33,7 @@ import (
|
|||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/chat-completions"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/chat-completions"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses"
|
||||||
|
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude"
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai/chat-completions"
|
||||||
)
|
)
|
||||||
|
|||||||
19
internal/translator/kiro/claude/init.go
Normal file
19
internal/translator/kiro/claude/init.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
translator.Register(
|
||||||
|
Claude,
|
||||||
|
Kiro,
|
||||||
|
ConvertClaudeRequestToKiro,
|
||||||
|
interfaces.TranslateResponse{
|
||||||
|
Stream: ConvertKiroResponseToClaude,
|
||||||
|
NonStream: ConvertKiroResponseToClaudeNonStream,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
27
internal/translator/kiro/claude/kiro_claude.go
Normal file
27
internal/translator/kiro/claude/kiro_claude.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
// Package claude provides translation between Kiro and Claude formats.
|
||||||
|
// Since Kiro executor generates Claude-compatible SSE format internally (with event: prefix),
|
||||||
|
// translations are pass-through.
|
||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertClaudeRequestToKiro converts Claude request to Kiro format.
|
||||||
|
// Since Kiro uses Claude format internally, this is mostly a pass-through.
|
||||||
|
func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||||
|
return bytes.Clone(inputRawJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertKiroResponseToClaude converts Kiro streaming response to Claude format.
|
||||||
|
// Kiro executor already generates complete SSE format with "event:" prefix,
|
||||||
|
// so this is a simple pass-through.
|
||||||
|
func ConvertKiroResponseToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
|
||||||
|
return []string{string(rawResponse)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertKiroResponseToClaudeNonStream converts Kiro non-streaming response to Claude format.
|
||||||
|
func ConvertKiroResponseToClaudeNonStream(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
|
||||||
|
return string(rawResponse)
|
||||||
|
}
|
||||||
19
internal/translator/kiro/openai/chat-completions/init.go
Normal file
19
internal/translator/kiro/openai/chat-completions/init.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package chat_completions
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
translator.Register(
|
||||||
|
OpenAI,
|
||||||
|
Kiro,
|
||||||
|
ConvertOpenAIRequestToKiro,
|
||||||
|
interfaces.TranslateResponse{
|
||||||
|
Stream: ConvertKiroResponseToOpenAI,
|
||||||
|
NonStream: ConvertKiroResponseToOpenAINonStream,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -0,0 +1,319 @@
|
|||||||
|
// Package chat_completions provides request translation from OpenAI to Kiro format.
|
||||||
|
package chat_completions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertOpenAIRequestToKiro transforms an OpenAI Chat Completions API request into Kiro (Claude) format.
|
||||||
|
// Kiro uses Claude-compatible format internally, so we primarily pass through to Claude format.
|
||||||
|
// Supports tool calling: OpenAI tools -> Claude tools, tool_calls -> tool_use, tool messages -> tool_result.
|
||||||
|
func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||||
|
rawJSON := bytes.Clone(inputRawJSON)
|
||||||
|
root := gjson.ParseBytes(rawJSON)
|
||||||
|
|
||||||
|
// Build Claude-compatible request
|
||||||
|
out := `{"model":"","max_tokens":32000,"messages":[]}`
|
||||||
|
|
||||||
|
// Set model
|
||||||
|
out, _ = sjson.Set(out, "model", modelName)
|
||||||
|
|
||||||
|
// Copy max_tokens if present
|
||||||
|
if v := root.Get("max_tokens"); v.Exists() {
|
||||||
|
out, _ = sjson.Set(out, "max_tokens", v.Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy temperature if present
|
||||||
|
if v := root.Get("temperature"); v.Exists() {
|
||||||
|
out, _ = sjson.Set(out, "temperature", v.Float())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy top_p if present
|
||||||
|
if v := root.Get("top_p"); v.Exists() {
|
||||||
|
out, _ = sjson.Set(out, "top_p", v.Float())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert OpenAI tools to Claude tools format
|
||||||
|
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
|
||||||
|
claudeTools := make([]interface{}, 0)
|
||||||
|
for _, tool := range tools.Array() {
|
||||||
|
if tool.Get("type").String() == "function" {
|
||||||
|
fn := tool.Get("function")
|
||||||
|
claudeTool := map[string]interface{}{
|
||||||
|
"name": fn.Get("name").String(),
|
||||||
|
"description": fn.Get("description").String(),
|
||||||
|
}
|
||||||
|
// Convert parameters to input_schema
|
||||||
|
if params := fn.Get("parameters"); params.Exists() {
|
||||||
|
claudeTool["input_schema"] = params.Value()
|
||||||
|
} else {
|
||||||
|
claudeTool["input_schema"] = map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
claudeTools = append(claudeTools, claudeTool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(claudeTools) > 0 {
|
||||||
|
out, _ = sjson.Set(out, "tools", claudeTools)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
messages := root.Get("messages")
|
||||||
|
if messages.Exists() && messages.IsArray() {
|
||||||
|
claudeMessages := make([]interface{}, 0)
|
||||||
|
var systemPrompt string
|
||||||
|
|
||||||
|
// Track pending tool results to merge with next user message
|
||||||
|
var pendingToolResults []map[string]interface{}
|
||||||
|
|
||||||
|
for _, msg := range messages.Array() {
|
||||||
|
role := msg.Get("role").String()
|
||||||
|
content := msg.Get("content")
|
||||||
|
|
||||||
|
if role == "system" {
|
||||||
|
// Extract system message
|
||||||
|
if content.IsArray() {
|
||||||
|
for _, part := range content.Array() {
|
||||||
|
if part.Get("type").String() == "text" {
|
||||||
|
systemPrompt += part.Get("text").String() + "\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
systemPrompt = content.String()
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if role == "tool" {
|
||||||
|
// OpenAI tool message -> Claude tool_result content block
|
||||||
|
toolCallID := msg.Get("tool_call_id").String()
|
||||||
|
toolContent := content.String()
|
||||||
|
|
||||||
|
toolResult := map[string]interface{}{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": toolCallID,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle content - can be string or structured
|
||||||
|
if content.IsArray() {
|
||||||
|
contentParts := make([]interface{}, 0)
|
||||||
|
for _, part := range content.Array() {
|
||||||
|
if part.Get("type").String() == "text" {
|
||||||
|
contentParts = append(contentParts, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": part.Get("text").String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
toolResult["content"] = contentParts
|
||||||
|
} else {
|
||||||
|
toolResult["content"] = toolContent
|
||||||
|
}
|
||||||
|
|
||||||
|
pendingToolResults = append(pendingToolResults, toolResult)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeMsg := map[string]interface{}{
|
||||||
|
"role": role,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle assistant messages with tool_calls
|
||||||
|
if role == "assistant" && msg.Get("tool_calls").Exists() {
|
||||||
|
contentParts := make([]interface{}, 0)
|
||||||
|
|
||||||
|
// Add text content if present
|
||||||
|
if content.Exists() && content.String() != "" {
|
||||||
|
contentParts = append(contentParts, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": content.String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert tool_calls to tool_use blocks
|
||||||
|
for _, toolCall := range msg.Get("tool_calls").Array() {
|
||||||
|
toolUseID := toolCall.Get("id").String()
|
||||||
|
fnName := toolCall.Get("function.name").String()
|
||||||
|
fnArgs := toolCall.Get("function.arguments").String()
|
||||||
|
|
||||||
|
// Parse arguments JSON
|
||||||
|
var argsMap map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(fnArgs), &argsMap); err != nil {
|
||||||
|
argsMap = map[string]interface{}{"raw": fnArgs}
|
||||||
|
}
|
||||||
|
|
||||||
|
contentParts = append(contentParts, map[string]interface{}{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": toolUseID,
|
||||||
|
"name": fnName,
|
||||||
|
"input": argsMap,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeMsg["content"] = contentParts
|
||||||
|
claudeMessages = append(claudeMessages, claudeMsg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle user messages - may need to include pending tool results
|
||||||
|
if role == "user" && len(pendingToolResults) > 0 {
|
||||||
|
contentParts := make([]interface{}, 0)
|
||||||
|
|
||||||
|
// Add pending tool results first
|
||||||
|
for _, tr := range pendingToolResults {
|
||||||
|
contentParts = append(contentParts, tr)
|
||||||
|
}
|
||||||
|
pendingToolResults = nil
|
||||||
|
|
||||||
|
// Add user content
|
||||||
|
if content.IsArray() {
|
||||||
|
for _, part := range content.Array() {
|
||||||
|
partType := part.Get("type").String()
|
||||||
|
if partType == "text" {
|
||||||
|
contentParts = append(contentParts, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": part.Get("text").String(),
|
||||||
|
})
|
||||||
|
} else if partType == "image_url" {
|
||||||
|
imageURL := part.Get("image_url.url").String()
|
||||||
|
|
||||||
|
// Check if it's base64 format (data:image/png;base64,xxxxx)
|
||||||
|
if strings.HasPrefix(imageURL, "data:") {
|
||||||
|
// Parse data URL format
|
||||||
|
// Format: data:image/png;base64,xxxxx
|
||||||
|
commaIdx := strings.Index(imageURL, ",")
|
||||||
|
if commaIdx != -1 {
|
||||||
|
// Extract media_type (e.g., "image/png")
|
||||||
|
header := imageURL[5:commaIdx] // Remove "data:" prefix
|
||||||
|
mediaType := header
|
||||||
|
if semiIdx := strings.Index(header, ";"); semiIdx != -1 {
|
||||||
|
mediaType = header[:semiIdx]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract base64 data
|
||||||
|
base64Data := imageURL[commaIdx+1:]
|
||||||
|
|
||||||
|
contentParts = append(contentParts, map[string]interface{}{
|
||||||
|
"type": "image",
|
||||||
|
"source": map[string]interface{}{
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": mediaType,
|
||||||
|
"data": base64Data,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Regular URL format - keep original logic
|
||||||
|
contentParts = append(contentParts, map[string]interface{}{
|
||||||
|
"type": "image",
|
||||||
|
"source": map[string]interface{}{
|
||||||
|
"type": "url",
|
||||||
|
"url": imageURL,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if content.String() != "" {
|
||||||
|
contentParts = append(contentParts, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": content.String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeMsg["content"] = contentParts
|
||||||
|
claudeMessages = append(claudeMessages, claudeMsg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle regular content
|
||||||
|
if content.IsArray() {
|
||||||
|
contentParts := make([]interface{}, 0)
|
||||||
|
for _, part := range content.Array() {
|
||||||
|
partType := part.Get("type").String()
|
||||||
|
if partType == "text" {
|
||||||
|
contentParts = append(contentParts, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": part.Get("text").String(),
|
||||||
|
})
|
||||||
|
} else if partType == "image_url" {
|
||||||
|
imageURL := part.Get("image_url.url").String()
|
||||||
|
|
||||||
|
// Check if it's base64 format (data:image/png;base64,xxxxx)
|
||||||
|
if strings.HasPrefix(imageURL, "data:") {
|
||||||
|
// Parse data URL format
|
||||||
|
// Format: data:image/png;base64,xxxxx
|
||||||
|
commaIdx := strings.Index(imageURL, ",")
|
||||||
|
if commaIdx != -1 {
|
||||||
|
// Extract media_type (e.g., "image/png")
|
||||||
|
header := imageURL[5:commaIdx] // Remove "data:" prefix
|
||||||
|
mediaType := header
|
||||||
|
if semiIdx := strings.Index(header, ";"); semiIdx != -1 {
|
||||||
|
mediaType = header[:semiIdx]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract base64 data
|
||||||
|
base64Data := imageURL[commaIdx+1:]
|
||||||
|
|
||||||
|
contentParts = append(contentParts, map[string]interface{}{
|
||||||
|
"type": "image",
|
||||||
|
"source": map[string]interface{}{
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": mediaType,
|
||||||
|
"data": base64Data,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Regular URL format - keep original logic
|
||||||
|
contentParts = append(contentParts, map[string]interface{}{
|
||||||
|
"type": "image",
|
||||||
|
"source": map[string]interface{}{
|
||||||
|
"type": "url",
|
||||||
|
"url": imageURL,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
claudeMsg["content"] = contentParts
|
||||||
|
} else {
|
||||||
|
claudeMsg["content"] = content.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeMessages = append(claudeMessages, claudeMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there are pending tool results without a following user message,
|
||||||
|
// create a user message with just the tool results
|
||||||
|
if len(pendingToolResults) > 0 {
|
||||||
|
contentParts := make([]interface{}, 0)
|
||||||
|
for _, tr := range pendingToolResults {
|
||||||
|
contentParts = append(contentParts, tr)
|
||||||
|
}
|
||||||
|
claudeMessages = append(claudeMessages, map[string]interface{}{
|
||||||
|
"role": "user",
|
||||||
|
"content": contentParts,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
out, _ = sjson.Set(out, "messages", claudeMessages)
|
||||||
|
|
||||||
|
if systemPrompt != "" {
|
||||||
|
out, _ = sjson.Set(out, "system", systemPrompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set stream
|
||||||
|
out, _ = sjson.Set(out, "stream", stream)
|
||||||
|
|
||||||
|
return []byte(out)
|
||||||
|
}
|
||||||
@@ -0,0 +1,360 @@
|
|||||||
|
// Package chat_completions provides response translation from Kiro to OpenAI format.
|
||||||
|
package chat_completions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertKiroResponseToOpenAI converts Kiro streaming response to OpenAI SSE format.
|
||||||
|
// Handles Claude SSE events: content_block_start, content_block_delta, input_json_delta,
|
||||||
|
// content_block_stop, message_delta, and message_stop.
|
||||||
|
// Input may be in SSE format: "event: xxx\ndata: {...}" or raw JSON.
|
||||||
|
func ConvertKiroResponseToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
|
||||||
|
raw := string(rawResponse)
|
||||||
|
var results []string
|
||||||
|
|
||||||
|
// Handle SSE format: extract JSON from "data: " lines
|
||||||
|
// Input format: "event: message_start\ndata: {...}"
|
||||||
|
lines := strings.Split(raw, "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if strings.HasPrefix(line, "data: ") {
|
||||||
|
jsonPart := strings.TrimPrefix(line, "data: ")
|
||||||
|
chunks := convertClaudeEventToOpenAI(jsonPart, model)
|
||||||
|
results = append(results, chunks...)
|
||||||
|
} else if strings.HasPrefix(line, "{") {
|
||||||
|
// Raw JSON (backward compatibility)
|
||||||
|
chunks := convertClaudeEventToOpenAI(line, model)
|
||||||
|
results = append(results, chunks...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertClaudeEventToOpenAI converts a single Claude JSON event to OpenAI format
|
||||||
|
func convertClaudeEventToOpenAI(jsonStr string, model string) []string {
|
||||||
|
root := gjson.Parse(jsonStr)
|
||||||
|
var results []string
|
||||||
|
|
||||||
|
eventType := root.Get("type").String()
|
||||||
|
|
||||||
|
switch eventType {
|
||||||
|
case "message_start":
|
||||||
|
// Initial message event - emit initial chunk with role
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
},
|
||||||
|
"finish_reason": nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(response)
|
||||||
|
results = append(results, string(result))
|
||||||
|
return results
|
||||||
|
|
||||||
|
case "content_block_start":
|
||||||
|
// Start of a content block (text or tool_use)
|
||||||
|
blockType := root.Get("content_block.type").String()
|
||||||
|
index := int(root.Get("index").Int())
|
||||||
|
|
||||||
|
if blockType == "tool_use" {
|
||||||
|
// Start of tool_use block
|
||||||
|
toolUseID := root.Get("content_block.id").String()
|
||||||
|
toolName := root.Get("content_block.name").String()
|
||||||
|
|
||||||
|
toolCall := map[string]interface{}{
|
||||||
|
"index": index,
|
||||||
|
"id": toolUseID,
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]interface{}{
|
||||||
|
"name": toolName,
|
||||||
|
"arguments": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"tool_calls": []map[string]interface{}{toolCall},
|
||||||
|
},
|
||||||
|
"finish_reason": nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(response)
|
||||||
|
results = append(results, string(result))
|
||||||
|
}
|
||||||
|
return results
|
||||||
|
|
||||||
|
case "content_block_delta":
|
||||||
|
index := int(root.Get("index").Int())
|
||||||
|
deltaType := root.Get("delta.type").String()
|
||||||
|
|
||||||
|
if deltaType == "text_delta" {
|
||||||
|
// Text content delta
|
||||||
|
contentDelta := root.Get("delta.text").String()
|
||||||
|
if contentDelta != "" {
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"content": contentDelta,
|
||||||
|
},
|
||||||
|
"finish_reason": nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(response)
|
||||||
|
results = append(results, string(result))
|
||||||
|
}
|
||||||
|
} else if deltaType == "input_json_delta" {
|
||||||
|
// Tool input delta (streaming arguments)
|
||||||
|
partialJSON := root.Get("delta.partial_json").String()
|
||||||
|
if partialJSON != "" {
|
||||||
|
toolCall := map[string]interface{}{
|
||||||
|
"index": index,
|
||||||
|
"function": map[string]interface{}{
|
||||||
|
"arguments": partialJSON,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"tool_calls": []map[string]interface{}{toolCall},
|
||||||
|
},
|
||||||
|
"finish_reason": nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(response)
|
||||||
|
results = append(results, string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return results
|
||||||
|
|
||||||
|
case "content_block_stop":
|
||||||
|
// End of content block - no output needed for OpenAI format
|
||||||
|
return results
|
||||||
|
|
||||||
|
case "message_delta":
|
||||||
|
// Final message delta with stop_reason
|
||||||
|
stopReason := root.Get("delta.stop_reason").String()
|
||||||
|
if stopReason != "" {
|
||||||
|
finishReason := "stop"
|
||||||
|
if stopReason == "tool_use" {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
} else if stopReason == "end_turn" {
|
||||||
|
finishReason = "stop"
|
||||||
|
} else if stopReason == "max_tokens" {
|
||||||
|
finishReason = "length"
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]interface{}{},
|
||||||
|
"finish_reason": finishReason,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(response)
|
||||||
|
results = append(results, string(result))
|
||||||
|
}
|
||||||
|
return results
|
||||||
|
|
||||||
|
case "message_stop":
|
||||||
|
// End of message - could emit [DONE] marker
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: handle raw content for backward compatibility
|
||||||
|
var contentDelta string
|
||||||
|
if delta := root.Get("delta.text"); delta.Exists() {
|
||||||
|
contentDelta = delta.String()
|
||||||
|
} else if content := root.Get("content"); content.Exists() && root.Get("type").String() == "" {
|
||||||
|
contentDelta = content.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if contentDelta != "" {
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"content": contentDelta,
|
||||||
|
},
|
||||||
|
"finish_reason": nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(response)
|
||||||
|
results = append(results, string(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle tool_use content blocks (Claude format) - fallback
|
||||||
|
toolUses := root.Get("delta.tool_use")
|
||||||
|
if !toolUses.Exists() {
|
||||||
|
toolUses = root.Get("tool_use")
|
||||||
|
}
|
||||||
|
if toolUses.Exists() && toolUses.IsObject() {
|
||||||
|
inputJSON := toolUses.Get("input").String()
|
||||||
|
if inputJSON == "" {
|
||||||
|
if inputObj := toolUses.Get("input"); inputObj.Exists() {
|
||||||
|
inputBytes, _ := json.Marshal(inputObj.Value())
|
||||||
|
inputJSON = string(inputBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCall := map[string]interface{}{
|
||||||
|
"index": 0,
|
||||||
|
"id": toolUses.Get("id").String(),
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]interface{}{
|
||||||
|
"name": toolUses.Get("name").String(),
|
||||||
|
"arguments": inputJSON,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"tool_calls": []map[string]interface{}{toolCall},
|
||||||
|
},
|
||||||
|
"finish_reason": nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(response)
|
||||||
|
results = append(results, string(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertKiroResponseToOpenAINonStream converts Kiro non-streaming response to OpenAI format.
|
||||||
|
func ConvertKiroResponseToOpenAINonStream(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
|
||||||
|
root := gjson.ParseBytes(rawResponse)
|
||||||
|
|
||||||
|
var content string
|
||||||
|
var toolCalls []map[string]interface{}
|
||||||
|
|
||||||
|
contentArray := root.Get("content")
|
||||||
|
if contentArray.IsArray() {
|
||||||
|
for _, item := range contentArray.Array() {
|
||||||
|
itemType := item.Get("type").String()
|
||||||
|
if itemType == "text" {
|
||||||
|
content += item.Get("text").String()
|
||||||
|
} else if itemType == "tool_use" {
|
||||||
|
// Convert Claude tool_use to OpenAI tool_calls format
|
||||||
|
inputJSON := item.Get("input").String()
|
||||||
|
if inputJSON == "" {
|
||||||
|
// If input is an object, marshal it
|
||||||
|
if inputObj := item.Get("input"); inputObj.Exists() {
|
||||||
|
inputBytes, _ := json.Marshal(inputObj.Value())
|
||||||
|
inputJSON = string(inputBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
toolCall := map[string]interface{}{
|
||||||
|
"id": item.Get("id").String(),
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]interface{}{
|
||||||
|
"name": item.Get("name").String(),
|
||||||
|
"arguments": inputJSON,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, toolCall)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content = root.Get("content").String()
|
||||||
|
}
|
||||||
|
|
||||||
|
inputTokens := root.Get("usage.input_tokens").Int()
|
||||||
|
outputTokens := root.Get("usage.output_tokens").Int()
|
||||||
|
|
||||||
|
message := map[string]interface{}{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tool_calls if present
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
message["tool_calls"] = toolCalls
|
||||||
|
}
|
||||||
|
|
||||||
|
finishReason := "stop"
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": message,
|
||||||
|
"finish_reason": finishReason,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"prompt_tokens": inputTokens,
|
||||||
|
"completion_tokens": outputTokens,
|
||||||
|
"total_tokens": inputTokens + outputTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := json.Marshal(response)
|
||||||
|
return string(result)
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ package claude
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
@@ -242,11 +243,12 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
|
|||||||
|
|
||||||
switch partType {
|
switch partType {
|
||||||
case "text":
|
case "text":
|
||||||
if !part.Get("text").Exists() {
|
text := part.Get("text").String()
|
||||||
|
if strings.TrimSpace(text) == "" {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
textContent := `{"type":"text","text":""}`
|
textContent := `{"type":"text","text":""}`
|
||||||
textContent, _ = sjson.Set(textContent, "text", part.Get("text").String())
|
textContent, _ = sjson.Set(textContent, "text", text)
|
||||||
return textContent, true
|
return textContent, true
|
||||||
|
|
||||||
case "image":
|
case "image":
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -41,6 +42,9 @@ type oaiToResponsesState struct {
|
|||||||
UsageSeen bool
|
UsageSeen bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// responseIDCounter provides a process-wide unique counter for synthesized response identifiers.
|
||||||
|
var responseIDCounter uint64
|
||||||
|
|
||||||
func emitRespEvent(event string, payload string) string {
|
func emitRespEvent(event string, payload string) string {
|
||||||
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
|
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
|
||||||
}
|
}
|
||||||
@@ -590,7 +594,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
|
|||||||
// id: use provider id if present, otherwise synthesize
|
// id: use provider id if present, otherwise synthesize
|
||||||
id := root.Get("id").String()
|
id := root.Get("id").String()
|
||||||
if id == "" {
|
if id == "" {
|
||||||
id = fmt.Sprintf("resp_%x", time.Now().UnixNano())
|
id = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1))
|
||||||
}
|
}
|
||||||
resp, _ = sjson.Set(resp, "id", id)
|
resp, _ = sjson.Set(resp, "id", id)
|
||||||
|
|
||||||
|
|||||||
@@ -207,6 +207,47 @@ func GeminiThinkingFromMetadata(metadata map[string]any) (*int, *bool, bool) {
|
|||||||
return budgetPtr, includePtr, matched
|
return budgetPtr, includePtr, matched
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// modelsWithDefaultThinking lists models that should have thinking enabled by default
|
||||||
|
// when no explicit thinkingConfig is provided.
|
||||||
|
var modelsWithDefaultThinking = map[string]bool{
|
||||||
|
"gemini-3-pro-preview": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelHasDefaultThinking returns true if the model should have thinking enabled by default.
|
||||||
|
func ModelHasDefaultThinking(model string) bool {
|
||||||
|
return modelsWithDefaultThinking[model]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyDefaultThinkingIfNeeded injects default thinkingConfig for models that require it.
|
||||||
|
// For standard Gemini API format (generationConfig.thinkingConfig path).
|
||||||
|
// Returns the modified body if thinkingConfig was added, otherwise returns the original.
|
||||||
|
func ApplyDefaultThinkingIfNeeded(model string, body []byte) []byte {
|
||||||
|
if !ModelHasDefaultThinking(model) {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(body, "generationConfig.thinkingConfig").Exists() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
updated, _ := sjson.SetBytes(body, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
|
updated, _ = sjson.SetBytes(updated, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyDefaultThinkingIfNeededCLI injects default thinkingConfig for models that require it.
|
||||||
|
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||||
|
// Returns the modified body if thinkingConfig was added, otherwise returns the original.
|
||||||
|
func ApplyDefaultThinkingIfNeededCLI(model string, body []byte) []byte {
|
||||||
|
if !ModelHasDefaultThinking(model) {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(body, "request.generationConfig.thinkingConfig").Exists() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
updated, _ := sjson.SetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
|
updated, _ = sjson.SetBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|
||||||
// StripThinkingConfigIfUnsupported removes thinkingConfig from the request body
|
// StripThinkingConfigIfUnsupported removes thinkingConfig from the request body
|
||||||
// when the target model does not advertise Thinking capability. It cleans both
|
// when the target model does not advertise Thinking capability. It cleans both
|
||||||
// standard Gemini and Gemini CLI JSON envelopes. This acts as a final safety net
|
// standard Gemini and Gemini CLI JSON envelopes. This acts as a final safety net
|
||||||
@@ -223,6 +264,32 @@ func StripThinkingConfigIfUnsupported(model string, body []byte) []byte {
|
|||||||
return updated
|
return updated
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeGeminiThinkingBudget normalizes the thinkingBudget value in a standard Gemini
|
||||||
|
// request body (generationConfig.thinkingConfig.thinkingBudget path).
|
||||||
|
func NormalizeGeminiThinkingBudget(model string, body []byte) []byte {
|
||||||
|
const budgetPath = "generationConfig.thinkingConfig.thinkingBudget"
|
||||||
|
budget := gjson.GetBytes(body, budgetPath)
|
||||||
|
if !budget.Exists() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
normalized := NormalizeThinkingBudget(model, int(budget.Int()))
|
||||||
|
updated, _ := sjson.SetBytes(body, budgetPath, normalized)
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeGeminiCLIThinkingBudget normalizes the thinkingBudget value in a Gemini CLI
|
||||||
|
// request body (request.generationConfig.thinkingConfig.thinkingBudget path).
|
||||||
|
func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte {
|
||||||
|
const budgetPath = "request.generationConfig.thinkingConfig.thinkingBudget"
|
||||||
|
budget := gjson.GetBytes(body, budgetPath)
|
||||||
|
if !budget.Exists() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
normalized := NormalizeThinkingBudget(model, int(budget.Int()))
|
||||||
|
updated, _ := sjson.SetBytes(body, budgetPath, normalized)
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|
||||||
// ConvertThinkingLevelToBudget checks for "generationConfig.thinkingConfig.thinkingLevel"
|
// ConvertThinkingLevelToBudget checks for "generationConfig.thinkingConfig.thinkingLevel"
|
||||||
// and converts it to "thinkingBudget".
|
// and converts it to "thinkingBudget".
|
||||||
// "high" -> 32768
|
// "high" -> 32768
|
||||||
|
|||||||
@@ -79,6 +79,15 @@ func RenameKey(jsonStr, oldKeyPath, newKeyPath string) (string, error) {
|
|||||||
return finalJson, nil
|
return finalJson, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DeleteKey(jsonStr, keyName string) string {
|
||||||
|
paths := make([]string, 0)
|
||||||
|
Walk(gjson.Parse(jsonStr), "", keyName, &paths)
|
||||||
|
for _, p := range paths {
|
||||||
|
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||||
|
}
|
||||||
|
return jsonStr
|
||||||
|
}
|
||||||
|
|
||||||
// FixJSON converts non-standard JSON that uses single quotes for strings into
|
// FixJSON converts non-standard JSON that uses single quotes for strings into
|
||||||
// RFC 8259-compliant JSON by converting those single-quoted strings to
|
// RFC 8259-compliant JSON by converting those single-quoted strings to
|
||||||
// double-quoted strings with proper escaping.
|
// double-quoted strings with proper escaping.
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
@@ -176,6 +177,9 @@ func (w *Watcher) Start(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
log.Debugf("watching auth directory: %s", w.authDir)
|
log.Debugf("watching auth directory: %s", w.authDir)
|
||||||
|
|
||||||
|
// Watch Kiro IDE token file directory for automatic token updates
|
||||||
|
w.watchKiroIDETokenFile()
|
||||||
|
|
||||||
// Start the event processing goroutine
|
// Start the event processing goroutine
|
||||||
go w.processEvents(ctx)
|
go w.processEvents(ctx)
|
||||||
|
|
||||||
@@ -184,6 +188,31 @@ func (w *Watcher) Start(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// watchKiroIDETokenFile adds the Kiro IDE token file directory to the watcher.
|
||||||
|
// This enables automatic detection of token updates from Kiro IDE.
|
||||||
|
func (w *Watcher) watchKiroIDETokenFile() {
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get home directory for Kiro IDE token watch: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kiro IDE stores tokens in ~/.aws/sso/cache/
|
||||||
|
kiroTokenDir := filepath.Join(homeDir, ".aws", "sso", "cache")
|
||||||
|
|
||||||
|
// Check if directory exists
|
||||||
|
if _, statErr := os.Stat(kiroTokenDir); os.IsNotExist(statErr) {
|
||||||
|
log.Debugf("Kiro IDE token directory does not exist: %s", kiroTokenDir)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if errAdd := w.watcher.Add(kiroTokenDir); errAdd != nil {
|
||||||
|
log.Debugf("failed to watch Kiro IDE token directory %s: %v", kiroTokenDir, errAdd)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debugf("watching Kiro IDE token directory: %s", kiroTokenDir)
|
||||||
|
}
|
||||||
|
|
||||||
// Stop stops the file watcher
|
// Stop stops the file watcher
|
||||||
func (w *Watcher) Stop() error {
|
func (w *Watcher) Stop() error {
|
||||||
w.stopDispatch()
|
w.stopDispatch()
|
||||||
@@ -744,10 +773,20 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
|
|||||||
isConfigEvent := event.Name == w.configPath && event.Op&configOps != 0
|
isConfigEvent := event.Name == w.configPath && event.Op&configOps != 0
|
||||||
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
|
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
|
||||||
isAuthJSON := strings.HasPrefix(event.Name, w.authDir) && strings.HasSuffix(event.Name, ".json") && event.Op&authOps != 0
|
isAuthJSON := strings.HasPrefix(event.Name, w.authDir) && strings.HasSuffix(event.Name, ".json") && event.Op&authOps != 0
|
||||||
if !isConfigEvent && !isAuthJSON {
|
|
||||||
|
// Check for Kiro IDE token file changes
|
||||||
|
isKiroIDEToken := w.isKiroIDETokenFile(event.Name) && event.Op&authOps != 0
|
||||||
|
|
||||||
|
if !isConfigEvent && !isAuthJSON && !isKiroIDEToken {
|
||||||
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
|
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle Kiro IDE token file changes
|
||||||
|
if isKiroIDEToken {
|
||||||
|
w.handleKiroIDETokenChange(event)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
|
log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
|
||||||
@@ -805,6 +844,51 @@ func (w *Watcher) scheduleConfigReload() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isKiroIDETokenFile checks if the given path is the Kiro IDE token file.
|
||||||
|
func (w *Watcher) isKiroIDETokenFile(path string) bool {
|
||||||
|
// Check if it's the kiro-auth-token.json file in ~/.aws/sso/cache/
|
||||||
|
// Use filepath.ToSlash to ensure consistent separators across platforms (Windows uses backslashes)
|
||||||
|
normalized := filepath.ToSlash(path)
|
||||||
|
return strings.HasSuffix(normalized, "kiro-auth-token.json") && strings.Contains(normalized, ".aws/sso/cache")
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleKiroIDETokenChange processes changes to the Kiro IDE token file.
|
||||||
|
// When the token file is updated by Kiro IDE, this triggers a reload of Kiro auth.
|
||||||
|
func (w *Watcher) handleKiroIDETokenChange(event fsnotify.Event) {
|
||||||
|
log.Debugf("Kiro IDE token file event detected: %s %s", event.Op.String(), event.Name)
|
||||||
|
|
||||||
|
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
|
||||||
|
// Token file removed - wait briefly for potential atomic replace
|
||||||
|
time.Sleep(replaceCheckDelay)
|
||||||
|
if _, statErr := os.Stat(event.Name); statErr != nil {
|
||||||
|
log.Debugf("Kiro IDE token file removed: %s", event.Name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to load the updated token
|
||||||
|
tokenData, err := kiroauth.LoadKiroIDEToken()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to load Kiro IDE token after change: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Kiro IDE token file updated, access token refreshed (provider: %s)", tokenData.Provider)
|
||||||
|
|
||||||
|
// Trigger auth state refresh to pick up the new token
|
||||||
|
w.refreshAuthState()
|
||||||
|
|
||||||
|
// Notify callback if set
|
||||||
|
w.clientsMutex.RLock()
|
||||||
|
cfg := w.config
|
||||||
|
w.clientsMutex.RUnlock()
|
||||||
|
|
||||||
|
if w.reloadCallback != nil && cfg != nil {
|
||||||
|
log.Debugf("triggering server update callback after Kiro IDE token change")
|
||||||
|
w.reloadCallback(cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (w *Watcher) reloadConfigIfChanged() {
|
func (w *Watcher) reloadConfigIfChanged() {
|
||||||
data, err := os.ReadFile(w.configPath)
|
data, err := os.ReadFile(w.configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1181,6 +1265,82 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
applyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
|
applyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
|
||||||
out = append(out, a)
|
out = append(out, a)
|
||||||
}
|
}
|
||||||
|
// Kiro (AWS CodeWhisperer) -> synthesize auths
|
||||||
|
var kAuth *kiroauth.KiroAuth
|
||||||
|
if len(cfg.KiroKey) > 0 {
|
||||||
|
kAuth = kiroauth.NewKiroAuth(cfg)
|
||||||
|
}
|
||||||
|
for i := range cfg.KiroKey {
|
||||||
|
kk := cfg.KiroKey[i]
|
||||||
|
var accessToken, profileArn, refreshToken string
|
||||||
|
|
||||||
|
// Try to load from token file first
|
||||||
|
if kk.TokenFile != "" && kAuth != nil {
|
||||||
|
tokenData, err := kAuth.LoadTokenFromFile(kk.TokenFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to load kiro token file %s: %v", kk.TokenFile, err)
|
||||||
|
} else {
|
||||||
|
accessToken = tokenData.AccessToken
|
||||||
|
profileArn = tokenData.ProfileArn
|
||||||
|
refreshToken = tokenData.RefreshToken
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override with direct config values if provided
|
||||||
|
if kk.AccessToken != "" {
|
||||||
|
accessToken = kk.AccessToken
|
||||||
|
}
|
||||||
|
if kk.ProfileArn != "" {
|
||||||
|
profileArn = kk.ProfileArn
|
||||||
|
}
|
||||||
|
if kk.RefreshToken != "" {
|
||||||
|
refreshToken = kk.RefreshToken
|
||||||
|
}
|
||||||
|
|
||||||
|
if accessToken == "" {
|
||||||
|
log.Warnf("kiro config[%d] missing access_token, skipping", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// profileArn is optional for AWS Builder ID users
|
||||||
|
id, token := idGen.next("kiro:token", accessToken, profileArn)
|
||||||
|
attrs := map[string]string{
|
||||||
|
"source": fmt.Sprintf("config:kiro[%s]", token),
|
||||||
|
"access_token": accessToken,
|
||||||
|
}
|
||||||
|
if profileArn != "" {
|
||||||
|
attrs["profile_arn"] = profileArn
|
||||||
|
}
|
||||||
|
if kk.Region != "" {
|
||||||
|
attrs["region"] = kk.Region
|
||||||
|
}
|
||||||
|
if kk.AgentTaskType != "" {
|
||||||
|
attrs["agent_task_type"] = kk.AgentTaskType
|
||||||
|
}
|
||||||
|
if refreshToken != "" {
|
||||||
|
attrs["refresh_token"] = refreshToken
|
||||||
|
}
|
||||||
|
proxyURL := strings.TrimSpace(kk.ProxyURL)
|
||||||
|
a := &coreauth.Auth{
|
||||||
|
ID: id,
|
||||||
|
Provider: "kiro",
|
||||||
|
Label: "kiro-token",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Attributes: attrs,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
if refreshToken != "" {
|
||||||
|
if a.Metadata == nil {
|
||||||
|
a.Metadata = make(map[string]any)
|
||||||
|
}
|
||||||
|
a.Metadata["refresh_token"] = refreshToken
|
||||||
|
}
|
||||||
|
|
||||||
|
out = append(out, a)
|
||||||
|
}
|
||||||
for i := range cfg.OpenAICompatibility {
|
for i := range cfg.OpenAICompatibility {
|
||||||
compat := &cfg.OpenAICompatibility[i]
|
compat := &cfg.OpenAICompatibility[i]
|
||||||
providerName := strings.ToLower(strings.TrimSpace(compat.Name))
|
providerName := strings.ToLower(strings.TrimSpace(compat.Name))
|
||||||
@@ -1287,7 +1447,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Also synthesize auth entries directly from auth files (for OAuth/file-backed providers)
|
// Also synthesize auth entries directly from auth files (for OAuth/file-backed providers)
|
||||||
entries, _ := os.ReadDir(w.authDir)
|
log.Debugf("SnapshotCoreAuths: scanning auth directory: %s", w.authDir)
|
||||||
|
entries, readErr := os.ReadDir(w.authDir)
|
||||||
|
if readErr != nil {
|
||||||
|
log.Errorf("SnapshotCoreAuths: failed to read auth directory %s: %v", w.authDir, readErr)
|
||||||
|
}
|
||||||
|
log.Debugf("SnapshotCoreAuths: found %d entries in auth directory", len(entries))
|
||||||
for _, e := range entries {
|
for _, e := range entries {
|
||||||
if e.IsDir() {
|
if e.IsDir() {
|
||||||
continue
|
continue
|
||||||
@@ -1306,9 +1471,20 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
t, _ := metadata["type"].(string)
|
t, _ := metadata["type"].(string)
|
||||||
|
|
||||||
|
// Detect Kiro auth files by auth_method field (they don't have "type" field)
|
||||||
if t == "" {
|
if t == "" {
|
||||||
|
if authMethod, _ := metadata["auth_method"].(string); authMethod == "builder-id" || authMethod == "social" {
|
||||||
|
t = "kiro"
|
||||||
|
log.Debugf("SnapshotCoreAuths: detected Kiro auth by auth_method: %s", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if t == "" {
|
||||||
|
log.Debugf("SnapshotCoreAuths: skipping file without type: %s", name)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
log.Debugf("SnapshotCoreAuths: processing auth file: %s (type=%s)", name, t)
|
||||||
provider := strings.ToLower(t)
|
provider := strings.ToLower(t)
|
||||||
if provider == "gemini" {
|
if provider == "gemini" {
|
||||||
provider = "gemini-cli"
|
provider = "gemini-cli"
|
||||||
@@ -1317,6 +1493,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
if email, _ := metadata["email"].(string); email != "" {
|
if email, _ := metadata["email"].(string); email != "" {
|
||||||
label = email
|
label = email
|
||||||
}
|
}
|
||||||
|
// For Kiro, use provider field as label if available
|
||||||
|
if provider == "kiro" {
|
||||||
|
if kiroProvider, _ := metadata["provider"].(string); kiroProvider != "" {
|
||||||
|
label = fmt.Sprintf("kiro-%s", strings.ToLower(kiroProvider))
|
||||||
|
}
|
||||||
|
}
|
||||||
// Use relative path under authDir as ID to stay consistent with the file-based token store
|
// Use relative path under authDir as ID to stay consistent with the file-based token store
|
||||||
id := full
|
id := full
|
||||||
if rel, errRel := filepath.Rel(w.authDir, full); errRel == nil && rel != "" {
|
if rel, errRel := filepath.Rel(w.authDir, full); errRel == nil && rel != "" {
|
||||||
@@ -1342,6 +1524,16 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
UpdatedAt: now,
|
UpdatedAt: now,
|
||||||
}
|
}
|
||||||
|
// Set NextRefreshAfter for Kiro auth based on expires_at
|
||||||
|
if provider == "kiro" {
|
||||||
|
if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" {
|
||||||
|
if expiresAt, parseErr := time.Parse(time.RFC3339, expiresAtStr); parseErr == nil {
|
||||||
|
// Refresh 30 minutes before expiry
|
||||||
|
a.NextRefreshAfter = expiresAt.Add(-30 * time.Minute)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
applyAuthExcludedModelsMeta(a, cfg, nil, "oauth")
|
applyAuthExcludedModelsMeta(a, cfg, nil, "oauth")
|
||||||
if provider == "gemini-cli" {
|
if provider == "gemini-cli" {
|
||||||
if virtuals := synthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
if virtuals := synthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||||
|
|||||||
@@ -48,8 +48,24 @@ func (h *GeminiAPIHandler) Models() []map[string]any {
|
|||||||
// GeminiModels handles the Gemini models listing endpoint.
|
// GeminiModels handles the Gemini models listing endpoint.
|
||||||
// It returns a JSON response containing available Gemini models and their specifications.
|
// It returns a JSON response containing available Gemini models and their specifications.
|
||||||
func (h *GeminiAPIHandler) GeminiModels(c *gin.Context) {
|
func (h *GeminiAPIHandler) GeminiModels(c *gin.Context) {
|
||||||
|
rawModels := h.Models()
|
||||||
|
normalizedModels := make([]map[string]any, 0, len(rawModels))
|
||||||
|
defaultMethods := []string{"generateContent"}
|
||||||
|
for _, model := range rawModels {
|
||||||
|
normalizedModel := make(map[string]any, len(model))
|
||||||
|
for k, v := range model {
|
||||||
|
normalizedModel[k] = v
|
||||||
|
}
|
||||||
|
if name, ok := normalizedModel["name"].(string); ok && name != "" && !strings.HasPrefix(name, "models/") {
|
||||||
|
normalizedModel["name"] = "models/" + name
|
||||||
|
}
|
||||||
|
if _, ok := normalizedModel["supportedGenerationMethods"]; !ok {
|
||||||
|
normalizedModel["supportedGenerationMethods"] = defaultMethods
|
||||||
|
}
|
||||||
|
normalizedModels = append(normalizedModels, normalizedModel)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"models": h.Models(),
|
"models": normalizedModels,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
@@ -50,7 +51,25 @@ type BaseAPIHandler struct {
|
|||||||
Cfg *config.SDKConfig
|
Cfg *config.SDKConfig
|
||||||
|
|
||||||
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
|
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
|
||||||
OpenAICompatProviders []string
|
openAICompatProviders []string
|
||||||
|
openAICompatMutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAICompatProviders safely returns a copy of the provider names
|
||||||
|
func (h *BaseAPIHandler) GetOpenAICompatProviders() []string {
|
||||||
|
h.openAICompatMutex.RLock()
|
||||||
|
defer h.openAICompatMutex.RUnlock()
|
||||||
|
result := make([]string, len(h.openAICompatProviders))
|
||||||
|
copy(result, h.openAICompatProviders)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetOpenAICompatProviders safely sets the provider names
|
||||||
|
func (h *BaseAPIHandler) SetOpenAICompatProviders(providers []string) {
|
||||||
|
h.openAICompatMutex.Lock()
|
||||||
|
defer h.openAICompatMutex.Unlock()
|
||||||
|
h.openAICompatProviders = make([]string, len(providers))
|
||||||
|
copy(h.openAICompatProviders, providers)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBaseAPIHandlers creates a new API handlers instance.
|
// NewBaseAPIHandlers creates a new API handlers instance.
|
||||||
@@ -63,11 +82,12 @@ type BaseAPIHandler struct {
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - *BaseAPIHandler: A new API handlers instance
|
// - *BaseAPIHandler: A new API handlers instance
|
||||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
|
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
|
||||||
return &BaseAPIHandler{
|
h := &BaseAPIHandler{
|
||||||
Cfg: cfg,
|
Cfg: cfg,
|
||||||
AuthManager: authManager,
|
AuthManager: authManager,
|
||||||
OpenAICompatProviders: openAICompatProviders,
|
|
||||||
}
|
}
|
||||||
|
h.SetOpenAICompatProviders(openAICompatProviders)
|
||||||
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateClients updates the handlers' client list and configuration.
|
// UpdateClients updates the handlers' client list and configuration.
|
||||||
@@ -363,7 +383,7 @@ func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, mode
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if the provider is a configured openai-compatibility provider
|
// Check if the provider is a configured openai-compatibility provider
|
||||||
for _, pName := range h.OpenAICompatProviders {
|
for _, pName := range h.GetOpenAICompatProviders() {
|
||||||
if pName == providerPart {
|
if pName == providerPart {
|
||||||
return providerPart, modelPart, true
|
return providerPart, modelPart, true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -127,6 +128,18 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fetch project ID via loadCodeAssist (same approach as Gemini CLI)
|
||||||
|
projectID := ""
|
||||||
|
if tokenResp.AccessToken != "" {
|
||||||
|
fetchedProjectID, errProject := fetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
|
||||||
|
if errProject != nil {
|
||||||
|
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
||||||
|
} else {
|
||||||
|
projectID = fetchedProjectID
|
||||||
|
log.Infof("antigravity: obtained project ID %s", projectID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
metadata := map[string]any{
|
metadata := map[string]any{
|
||||||
"type": "antigravity",
|
"type": "antigravity",
|
||||||
@@ -139,6 +152,9 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
|||||||
if email != "" {
|
if email != "" {
|
||||||
metadata["email"] = email
|
metadata["email"] = email
|
||||||
}
|
}
|
||||||
|
if projectID != "" {
|
||||||
|
metadata["project_id"] = projectID
|
||||||
|
}
|
||||||
|
|
||||||
fileName := sanitizeAntigravityFileName(email)
|
fileName := sanitizeAntigravityFileName(email)
|
||||||
label := email
|
label := email
|
||||||
@@ -147,6 +163,9 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
|||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("Antigravity authentication successful")
|
fmt.Println("Antigravity authentication successful")
|
||||||
|
if projectID != "" {
|
||||||
|
fmt.Printf("Using GCP project: %s\n", projectID)
|
||||||
|
}
|
||||||
return &coreauth.Auth{
|
return &coreauth.Auth{
|
||||||
ID: fileName,
|
ID: fileName,
|
||||||
Provider: "antigravity",
|
Provider: "antigravity",
|
||||||
@@ -291,3 +310,89 @@ func sanitizeAntigravityFileName(email string) string {
|
|||||||
replacer := strings.NewReplacer("@", "_", ".", "_")
|
replacer := strings.NewReplacer("@", "_", ".", "_")
|
||||||
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
|
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Antigravity API constants for project discovery
|
||||||
|
const (
|
||||||
|
antigravityAPIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||||
|
antigravityAPIVersion = "v1internal"
|
||||||
|
antigravityAPIUserAgent = "google-api-nodejs-client/9.15.1"
|
||||||
|
antigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
|
||||||
|
antigravityClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
|
||||||
|
)
|
||||||
|
|
||||||
|
// FetchAntigravityProjectID exposes project discovery for external callers.
|
||||||
|
func FetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
|
||||||
|
return fetchAntigravityProjectID(ctx, accessToken, httpClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchAntigravityProjectID retrieves the project ID for the authenticated user via loadCodeAssist.
|
||||||
|
// This uses the same approach as Gemini CLI to get the cloudaicompanionProject.
|
||||||
|
func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
|
||||||
|
// Call loadCodeAssist to get the project
|
||||||
|
loadReqBody := map[string]any{
|
||||||
|
"metadata": map[string]string{
|
||||||
|
"ideType": "IDE_UNSPECIFIED",
|
||||||
|
"platform": "PLATFORM_UNSPECIFIED",
|
||||||
|
"pluginType": "GEMINI",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rawBody, errMarshal := json.Marshal(loadReqBody)
|
||||||
|
if errMarshal != nil {
|
||||||
|
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", antigravityAPIEndpoint, antigravityAPIVersion)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", antigravityAPIUserAgent)
|
||||||
|
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
|
||||||
|
req.Header.Set("Client-Metadata", antigravityClientMetadata)
|
||||||
|
|
||||||
|
resp, errDo := httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
return "", fmt.Errorf("execute request: %w", errDo)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
return "", fmt.Errorf("read response: %w", errRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var loadResp map[string]any
|
||||||
|
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
|
||||||
|
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract projectID from response
|
||||||
|
projectID := ""
|
||||||
|
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
|
||||||
|
projectID = strings.TrimSpace(id)
|
||||||
|
}
|
||||||
|
if projectID == "" {
|
||||||
|
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
|
||||||
|
if id, okID := projectMap["id"].(string); okID {
|
||||||
|
projectID = strings.TrimSpace(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if projectID == "" {
|
||||||
|
return "", fmt.Errorf("no cloudaicompanionProject in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
return projectID, nil
|
||||||
|
}
|
||||||
|
|||||||
357
sdk/auth/kiro.go
Normal file
357
sdk/auth/kiro.go
Normal file
@@ -0,0 +1,357 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// extractKiroIdentifier extracts a meaningful identifier for file naming.
|
||||||
|
// Returns account name if provided, otherwise profile ARN ID.
|
||||||
|
// All extracted values are sanitized to prevent path injection attacks.
|
||||||
|
func extractKiroIdentifier(accountName, profileArn string) string {
|
||||||
|
// Priority 1: Use account name if provided
|
||||||
|
if accountName != "" {
|
||||||
|
return kiroauth.SanitizeEmailForFilename(accountName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority 2: Use profile ARN ID part (sanitized to prevent path injection)
|
||||||
|
if profileArn != "" {
|
||||||
|
parts := strings.Split(profileArn, "/")
|
||||||
|
if len(parts) >= 2 {
|
||||||
|
// Sanitize the ARN component to prevent path traversal
|
||||||
|
return kiroauth.SanitizeEmailForFilename(parts[len(parts)-1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: timestamp
|
||||||
|
return fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroAuthenticator implements OAuth authentication for Kiro with Google login.
|
||||||
|
type KiroAuthenticator struct{}
|
||||||
|
|
||||||
|
// NewKiroAuthenticator constructs a Kiro authenticator.
|
||||||
|
func NewKiroAuthenticator() *KiroAuthenticator {
|
||||||
|
return &KiroAuthenticator{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider returns the provider key for the authenticator.
|
||||||
|
func (a *KiroAuthenticator) Provider() string {
|
||||||
|
return "kiro"
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshLead indicates how soon before expiry a refresh should be attempted.
|
||||||
|
func (a *KiroAuthenticator) RefreshLead() *time.Duration {
|
||||||
|
d := 30 * time.Minute
|
||||||
|
return &d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login performs OAuth login for Kiro with AWS Builder ID.
|
||||||
|
func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||||
|
|
||||||
|
// Use AWS Builder ID device code flow
|
||||||
|
tokenData, err := oauth.LoginWithBuilderID(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("login failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse expires_at
|
||||||
|
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
expiresAt = time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract identifier for file naming
|
||||||
|
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
||||||
|
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: "kiro",
|
||||||
|
FileName: fileName,
|
||||||
|
Label: "kiro-aws",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "kiro",
|
||||||
|
"access_token": tokenData.AccessToken,
|
||||||
|
"refresh_token": tokenData.RefreshToken,
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"expires_at": tokenData.ExpiresAt,
|
||||||
|
"auth_method": tokenData.AuthMethod,
|
||||||
|
"provider": tokenData.Provider,
|
||||||
|
"client_id": tokenData.ClientID,
|
||||||
|
"client_secret": tokenData.ClientSecret,
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"source": "aws-builder-id",
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenData.Email != "" {
|
||||||
|
fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||||
|
} else {
|
||||||
|
fmt.Println("\n✓ Kiro authentication completed successfully!")
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithGoogle performs OAuth login for Kiro with Google.
|
||||||
|
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||||
|
func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||||
|
|
||||||
|
// Use Google OAuth flow with protocol handler
|
||||||
|
tokenData, err := oauth.LoginWithGoogle(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("google login failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse expires_at
|
||||||
|
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
expiresAt = time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract identifier for file naming
|
||||||
|
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fileName := fmt.Sprintf("kiro-google-%s.json", idPart)
|
||||||
|
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: "kiro",
|
||||||
|
FileName: fileName,
|
||||||
|
Label: "kiro-google",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "kiro",
|
||||||
|
"access_token": tokenData.AccessToken,
|
||||||
|
"refresh_token": tokenData.RefreshToken,
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"expires_at": tokenData.ExpiresAt,
|
||||||
|
"auth_method": tokenData.AuthMethod,
|
||||||
|
"provider": tokenData.Provider,
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"source": "google-oauth",
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenData.Email != "" {
|
||||||
|
fmt.Printf("\n✓ Kiro Google authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||||
|
} else {
|
||||||
|
fmt.Println("\n✓ Kiro Google authentication completed successfully!")
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithGitHub performs OAuth login for Kiro with GitHub.
|
||||||
|
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||||
|
func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||||
|
|
||||||
|
// Use GitHub OAuth flow with protocol handler
|
||||||
|
tokenData, err := oauth.LoginWithGitHub(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("github login failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse expires_at
|
||||||
|
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
expiresAt = time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract identifier for file naming
|
||||||
|
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fileName := fmt.Sprintf("kiro-github-%s.json", idPart)
|
||||||
|
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: "kiro",
|
||||||
|
FileName: fileName,
|
||||||
|
Label: "kiro-github",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "kiro",
|
||||||
|
"access_token": tokenData.AccessToken,
|
||||||
|
"refresh_token": tokenData.RefreshToken,
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"expires_at": tokenData.ExpiresAt,
|
||||||
|
"auth_method": tokenData.AuthMethod,
|
||||||
|
"provider": tokenData.Provider,
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"source": "github-oauth",
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenData.Email != "" {
|
||||||
|
fmt.Printf("\n✓ Kiro GitHub authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||||
|
} else {
|
||||||
|
fmt.Println("\n✓ Kiro GitHub authentication completed successfully!")
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImportFromKiroIDE imports token from Kiro IDE's token file.
|
||||||
|
func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.Config) (*coreauth.Auth, error) {
|
||||||
|
tokenData, err := kiroauth.LoadKiroIDEToken()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load Kiro IDE token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse expires_at
|
||||||
|
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
expiresAt = time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract email from JWT if not already set (for imported tokens)
|
||||||
|
if tokenData.Email == "" {
|
||||||
|
tokenData.Email = kiroauth.ExtractEmailFromJWT(tokenData.AccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract identifier for file naming
|
||||||
|
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||||
|
// Sanitize provider to prevent path traversal (defense-in-depth)
|
||||||
|
provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider)))
|
||||||
|
if provider == "" {
|
||||||
|
provider = "imported" // Fallback for legacy tokens without provider
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fileName := fmt.Sprintf("kiro-%s-%s.json", provider, idPart)
|
||||||
|
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: "kiro",
|
||||||
|
FileName: fileName,
|
||||||
|
Label: fmt.Sprintf("kiro-%s", provider),
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "kiro",
|
||||||
|
"access_token": tokenData.AccessToken,
|
||||||
|
"refresh_token": tokenData.RefreshToken,
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"expires_at": tokenData.ExpiresAt,
|
||||||
|
"auth_method": tokenData.AuthMethod,
|
||||||
|
"provider": tokenData.Provider,
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"source": "kiro-ide-import",
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display the email if extracted
|
||||||
|
if tokenData.Email != "" {
|
||||||
|
fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s, Account: %s)\n", tokenData.Provider, tokenData.Email)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s)\n", tokenData.Provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh refreshes an expired Kiro token using AWS SSO OIDC.
|
||||||
|
func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
if auth == nil || auth.Metadata == nil {
|
||||||
|
return nil, fmt.Errorf("invalid auth record")
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshToken, ok := auth.Metadata["refresh_token"].(string)
|
||||||
|
if !ok || refreshToken == "" {
|
||||||
|
return nil, fmt.Errorf("refresh token not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
clientID, _ := auth.Metadata["client_id"].(string)
|
||||||
|
clientSecret, _ := auth.Metadata["client_secret"].(string)
|
||||||
|
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||||
|
|
||||||
|
var tokenData *kiroauth.KiroTokenData
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint
|
||||||
|
if clientID != "" && clientSecret != "" && authMethod == "builder-id" {
|
||||||
|
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||||
|
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
||||||
|
} else {
|
||||||
|
// Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
|
||||||
|
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||||
|
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("token refresh failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse expires_at
|
||||||
|
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
expiresAt = time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone auth to avoid mutating the input parameter
|
||||||
|
updated := auth.Clone()
|
||||||
|
now := time.Now()
|
||||||
|
updated.UpdatedAt = now
|
||||||
|
updated.LastRefreshedAt = now
|
||||||
|
updated.Metadata["access_token"] = tokenData.AccessToken
|
||||||
|
updated.Metadata["refresh_token"] = tokenData.RefreshToken
|
||||||
|
updated.Metadata["expires_at"] = tokenData.ExpiresAt
|
||||||
|
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
|
||||||
|
updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute)
|
||||||
|
|
||||||
|
return updated, nil
|
||||||
|
}
|
||||||
@@ -74,3 +74,16 @@ func (m *Manager) Login(ctx context.Context, provider string, cfg *config.Config
|
|||||||
}
|
}
|
||||||
return record, savedPath, nil
|
return record, savedPath, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveAuth persists an auth record directly without going through the login flow.
|
||||||
|
func (m *Manager) SaveAuth(record *coreauth.Auth, cfg *config.Config) (string, error) {
|
||||||
|
if m.store == nil {
|
||||||
|
return "", fmt.Errorf("no store configured")
|
||||||
|
}
|
||||||
|
if cfg != nil {
|
||||||
|
if dirSetter, ok := m.store.(interface{ SetBaseDir(string) }); ok {
|
||||||
|
dirSetter.SetBaseDir(cfg.AuthDir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m.store.Save(context.Background(), record)
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ func init() {
|
|||||||
registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() })
|
registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() })
|
||||||
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
|
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
|
||||||
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
|
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
|
||||||
|
registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
|
||||||
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
|
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -379,6 +379,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
|||||||
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
||||||
case "iflow":
|
case "iflow":
|
||||||
s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg))
|
||||||
|
case "kiro":
|
||||||
|
s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg))
|
||||||
case "github-copilot":
|
case "github-copilot":
|
||||||
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
|
||||||
default:
|
default:
|
||||||
@@ -500,7 +502,7 @@ func (s *Service) Run(ctx context.Context) error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
fmt.Printf("API server started successfully on: %d\n", s.cfg.Port)
|
fmt.Printf("API server started successfully on: %s:%d\n", s.cfg.Host, s.cfg.Port)
|
||||||
|
|
||||||
if s.hooks.OnAfterStart != nil {
|
if s.hooks.OnAfterStart != nil {
|
||||||
s.hooks.OnAfterStart(s)
|
s.hooks.OnAfterStart(s)
|
||||||
@@ -725,6 +727,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
case "github-copilot":
|
case "github-copilot":
|
||||||
models = registry.GetGitHubCopilotModels()
|
models = registry.GetGitHubCopilotModels()
|
||||||
models = applyExcludedModels(models, excluded)
|
models = applyExcludedModels(models, excluded)
|
||||||
|
case "kiro":
|
||||||
|
models = registry.GetKiroModels()
|
||||||
|
models = applyExcludedModels(models, excluded)
|
||||||
default:
|
default:
|
||||||
// Handle OpenAI-compatibility providers by name using config
|
// Handle OpenAI-compatibility providers by name using config
|
||||||
if s.cfg != nil {
|
if s.cfg != nil {
|
||||||
@@ -783,7 +788,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
Created: time.Now().Unix(),
|
Created: time.Now().Unix(),
|
||||||
OwnedBy: compat.Name,
|
OwnedBy: compat.Name,
|
||||||
Type: "openai-compatibility",
|
Type: "openai-compatibility",
|
||||||
DisplayName: m.Name,
|
DisplayName: modelID,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
// Register and return
|
// Register and return
|
||||||
|
|||||||
827
test/amp_management_test.go
Normal file
827
test/amp_management_test.go
Normal file
@@ -0,0 +1,827 @@
|
|||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newAmpTestHandler creates a test handler with default ampcode configuration.
|
||||||
|
func newAmpTestHandler(t *testing.T) (*management.Handler, string) {
|
||||||
|
t.Helper()
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
AmpCode: config.AmpCode{
|
||||||
|
UpstreamURL: "https://example.com",
|
||||||
|
UpstreamAPIKey: "test-api-key-12345",
|
||||||
|
RestrictManagementToLocalhost: true,
|
||||||
|
ForceModelMappings: false,
|
||||||
|
ModelMappings: []config.AmpModelMapping{
|
||||||
|
{From: "gpt-4", To: "gemini-pro"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write config file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := management.NewHandler(cfg, configPath, nil)
|
||||||
|
return h, configPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupAmpRouter creates a test router with all ampcode management endpoints.
|
||||||
|
func setupAmpRouter(h *management.Handler) *gin.Engine {
|
||||||
|
r := gin.New()
|
||||||
|
mgmt := r.Group("/v0/management")
|
||||||
|
{
|
||||||
|
mgmt.GET("/ampcode", h.GetAmpCode)
|
||||||
|
mgmt.GET("/ampcode/upstream-url", h.GetAmpUpstreamURL)
|
||||||
|
mgmt.PUT("/ampcode/upstream-url", h.PutAmpUpstreamURL)
|
||||||
|
mgmt.DELETE("/ampcode/upstream-url", h.DeleteAmpUpstreamURL)
|
||||||
|
mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey)
|
||||||
|
mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey)
|
||||||
|
mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey)
|
||||||
|
mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost)
|
||||||
|
mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost)
|
||||||
|
mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings)
|
||||||
|
mgmt.PUT("/ampcode/model-mappings", h.PutAmpModelMappings)
|
||||||
|
mgmt.PATCH("/ampcode/model-mappings", h.PatchAmpModelMappings)
|
||||||
|
mgmt.DELETE("/ampcode/model-mappings", h.DeleteAmpModelMappings)
|
||||||
|
mgmt.GET("/ampcode/force-model-mappings", h.GetAmpForceModelMappings)
|
||||||
|
mgmt.PUT("/ampcode/force-model-mappings", h.PutAmpForceModelMappings)
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetAmpCode verifies GET /v0/management/ampcode returns full ampcode config.
|
||||||
|
func TestGetAmpCode(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]config.AmpCode
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ampcode := resp["ampcode"]
|
||||||
|
if ampcode.UpstreamURL != "https://example.com" {
|
||||||
|
t.Errorf("expected upstream-url %q, got %q", "https://example.com", ampcode.UpstreamURL)
|
||||||
|
}
|
||||||
|
if len(ampcode.ModelMappings) != 1 {
|
||||||
|
t.Errorf("expected 1 model mapping, got %d", len(ampcode.ModelMappings))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetAmpUpstreamURL verifies GET /v0/management/ampcode/upstream-url returns the upstream URL.
|
||||||
|
func TestGetAmpUpstreamURL(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]string
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp["upstream-url"] != "https://example.com" {
|
||||||
|
t.Errorf("expected %q, got %q", "https://example.com", resp["upstream-url"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPutAmpUpstreamURL verifies PUT /v0/management/ampcode/upstream-url updates the upstream URL.
|
||||||
|
func TestPutAmpUpstreamURL(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value": "https://new-upstream.com"}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteAmpUpstreamURL verifies DELETE /v0/management/ampcode/upstream-url clears the upstream URL.
|
||||||
|
func TestDeleteAmpUpstreamURL(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetAmpUpstreamAPIKey verifies GET /v0/management/ampcode/upstream-api-key returns the API key.
|
||||||
|
func TestGetAmpUpstreamAPIKey(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
key := resp["upstream-api-key"].(string)
|
||||||
|
if key != "test-api-key-12345" {
|
||||||
|
t.Errorf("expected key %q, got %q", "test-api-key-12345", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPutAmpUpstreamAPIKey verifies PUT /v0/management/ampcode/upstream-api-key updates the API key.
|
||||||
|
func TestPutAmpUpstreamAPIKey(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value": "new-secret-key"}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key.
|
||||||
|
func TestDeleteAmpUpstreamAPIKey(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetAmpRestrictManagementToLocalhost verifies GET returns the localhost restriction setting.
|
||||||
|
func TestGetAmpRestrictManagementToLocalhost(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]bool
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp["restrict-management-to-localhost"] != true {
|
||||||
|
t.Error("expected restrict-management-to-localhost to be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPutAmpRestrictManagementToLocalhost verifies PUT updates the localhost restriction setting.
|
||||||
|
func TestPutAmpRestrictManagementToLocalhost(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value": false}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetAmpModelMappings verifies GET /v0/management/ampcode/model-mappings returns all mappings.
|
||||||
|
func TestGetAmpModelMappings(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string][]config.AmpModelMapping
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mappings := resp["model-mappings"]
|
||||||
|
if len(mappings) != 1 {
|
||||||
|
t.Fatalf("expected 1 mapping, got %d", len(mappings))
|
||||||
|
}
|
||||||
|
if mappings[0].From != "gpt-4" || mappings[0].To != "gemini-pro" {
|
||||||
|
t.Errorf("unexpected mapping: %+v", mappings[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPutAmpModelMappings verifies PUT /v0/management/ampcode/model-mappings replaces all mappings.
|
||||||
|
func TestPutAmpModelMappings(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value": [{"from": "claude-3", "to": "gpt-4o"}, {"from": "gemini", "to": "claude"}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPatchAmpModelMappings verifies PATCH updates existing mappings and adds new ones.
|
||||||
|
func TestPatchAmpModelMappings(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value": [{"from": "gpt-4", "to": "updated-model"}, {"from": "new-model", "to": "target"}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteAmpModelMappings_Specific verifies DELETE removes specified mappings by "from" field.
|
||||||
|
func TestDeleteAmpModelMappings_Specific(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value": ["gpt-4"]}`
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteAmpModelMappings_All verifies DELETE with empty body removes all mappings.
|
||||||
|
func TestDeleteAmpModelMappings_All(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetAmpForceModelMappings verifies GET returns the force-model-mappings setting.
|
||||||
|
func TestGetAmpForceModelMappings(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]bool
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp["force-model-mappings"] != false {
|
||||||
|
t.Error("expected force-model-mappings to be false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPutAmpForceModelMappings verifies PUT updates the force-model-mappings setting.
|
||||||
|
func TestPutAmpForceModelMappings(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value": true}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPutAmpModelMappings_VerifyState verifies PUT replaces mappings and state is persisted.
|
||||||
|
func TestPutAmpModelMappings_VerifyState(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value": [{"from": "model-a", "to": "model-b"}, {"from": "model-c", "to": "model-d"}, {"from": "model-e", "to": "model-f"}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("PUT failed: status %d, body: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string][]config.AmpModelMapping
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mappings := resp["model-mappings"]
|
||||||
|
if len(mappings) != 3 {
|
||||||
|
t.Fatalf("expected 3 mappings, got %d", len(mappings))
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := map[string]string{"model-a": "model-b", "model-c": "model-d", "model-e": "model-f"}
|
||||||
|
for _, m := range mappings {
|
||||||
|
if expected[m.From] != m.To {
|
||||||
|
t.Errorf("mapping %q -> expected %q, got %q", m.From, expected[m.From], m.To)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPatchAmpModelMappings_VerifyState verifies PATCH merges mappings correctly.
|
||||||
|
func TestPatchAmpModelMappings_VerifyState(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value": [{"from": "gpt-4", "to": "updated-target"}, {"from": "new-model", "to": "new-target"}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("PATCH failed: status %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string][]config.AmpModelMapping
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mappings := resp["model-mappings"]
|
||||||
|
if len(mappings) != 2 {
|
||||||
|
t.Fatalf("expected 2 mappings (1 updated + 1 new), got %d", len(mappings))
|
||||||
|
}
|
||||||
|
|
||||||
|
found := make(map[string]string)
|
||||||
|
for _, m := range mappings {
|
||||||
|
found[m.From] = m.To
|
||||||
|
}
|
||||||
|
|
||||||
|
if found["gpt-4"] != "updated-target" {
|
||||||
|
t.Errorf("gpt-4 should map to updated-target, got %q", found["gpt-4"])
|
||||||
|
}
|
||||||
|
if found["new-model"] != "new-target" {
|
||||||
|
t.Errorf("new-model should map to new-target, got %q", found["new-model"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteAmpModelMappings_VerifyState verifies DELETE removes specific mappings and keeps others.
|
||||||
|
func TestDeleteAmpModelMappings_VerifyState(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
putBody := `{"value": [{"from": "a", "to": "1"}, {"from": "b", "to": "2"}, {"from": "c", "to": "3"}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
delBody := `{"value": ["a", "c"]}`
|
||||||
|
req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("DELETE failed: status %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string][]config.AmpModelMapping
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mappings := resp["model-mappings"]
|
||||||
|
if len(mappings) != 1 {
|
||||||
|
t.Fatalf("expected 1 mapping remaining, got %d", len(mappings))
|
||||||
|
}
|
||||||
|
if mappings[0].From != "b" || mappings[0].To != "2" {
|
||||||
|
t.Errorf("expected b->2, got %s->%s", mappings[0].From, mappings[0].To)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteAmpModelMappings_NonExistent verifies DELETE with non-existent mapping doesn't affect existing ones.
|
||||||
|
func TestDeleteAmpModelMappings_NonExistent(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
delBody := `{"value": ["non-existent-model"]}`
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string][]config.AmpModelMapping
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp["model-mappings"]) != 1 {
|
||||||
|
t.Errorf("original mapping should remain, got %d mappings", len(resp["model-mappings"]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPutAmpModelMappings_Empty verifies PUT with empty array clears all mappings.
|
||||||
|
func TestPutAmpModelMappings_Empty(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value": []}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string][]config.AmpModelMapping
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp["model-mappings"]) != 0 {
|
||||||
|
t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPutAmpUpstreamURL_VerifyState verifies PUT updates upstream URL and persists state.
|
||||||
|
func TestPutAmpUpstreamURL_VerifyState(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value": "https://new-api.example.com"}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("PUT failed: status %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string]string
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp["upstream-url"] != "https://new-api.example.com" {
|
||||||
|
t.Errorf("expected %q, got %q", "https://new-api.example.com", resp["upstream-url"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteAmpUpstreamURL_VerifyState verifies DELETE clears upstream URL.
|
||||||
|
func TestDeleteAmpUpstreamURL_VerifyState(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("DELETE failed: status %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string]string
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp["upstream-url"] != "" {
|
||||||
|
t.Errorf("expected empty string, got %q", resp["upstream-url"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPutAmpUpstreamAPIKey_VerifyState verifies PUT updates API key and persists state.
|
||||||
|
func TestPutAmpUpstreamAPIKey_VerifyState(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value": "new-secret-api-key-xyz"}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("PUT failed: status %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string]string
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp["upstream-api-key"] != "new-secret-api-key-xyz" {
|
||||||
|
t.Errorf("expected %q, got %q", "new-secret-api-key-xyz", resp["upstream-api-key"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteAmpUpstreamAPIKey_VerifyState verifies DELETE clears API key.
|
||||||
|
func TestDeleteAmpUpstreamAPIKey_VerifyState(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("DELETE failed: status %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string]string
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp["upstream-api-key"] != "" {
|
||||||
|
t.Errorf("expected empty string, got %q", resp["upstream-api-key"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPutAmpRestrictManagementToLocalhost_VerifyState verifies PUT updates localhost restriction.
|
||||||
|
func TestPutAmpRestrictManagementToLocalhost_VerifyState(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value": false}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("PUT failed: status %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string]bool
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp["restrict-management-to-localhost"] != false {
|
||||||
|
t.Error("expected false after update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPutAmpForceModelMappings_VerifyState verifies PUT updates force-model-mappings setting.
|
||||||
|
func TestPutAmpForceModelMappings_VerifyState(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value": true}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("PUT failed: status %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string]bool
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp["force-model-mappings"] != true {
|
||||||
|
t.Error("expected true after update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPutBoolField_EmptyObject verifies PUT with empty object returns 400.
|
||||||
|
func TestPutBoolField_EmptyObject(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status %d for empty object, got %d", http.StatusBadRequest, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestComplexMappingsWorkflow tests a full workflow: PUT, PATCH, DELETE, and GET.
|
||||||
|
func TestComplexMappingsWorkflow(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
putBody := `{"value": [{"from": "m1", "to": "t1"}, {"from": "m2", "to": "t2"}, {"from": "m3", "to": "t3"}, {"from": "m4", "to": "t4"}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
patchBody := `{"value": [{"from": "m2", "to": "t2-updated"}, {"from": "m5", "to": "t5"}]}`
|
||||||
|
req = httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(patchBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
delBody := `{"value": ["m1", "m3"]}`
|
||||||
|
req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
var resp map[string][]config.AmpModelMapping
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mappings := resp["model-mappings"]
|
||||||
|
if len(mappings) != 3 {
|
||||||
|
t.Fatalf("expected 3 mappings (m2, m4, m5), got %d", len(mappings))
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := map[string]string{"m2": "t2-updated", "m4": "t4", "m5": "t5"}
|
||||||
|
found := make(map[string]string)
|
||||||
|
for _, m := range mappings {
|
||||||
|
found[m.From] = m.To
|
||||||
|
}
|
||||||
|
|
||||||
|
for from, to := range expected {
|
||||||
|
if found[from] != to {
|
||||||
|
t.Errorf("mapping %s: expected %q, got %q", from, to, found[from])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNilHandlerGetAmpCode verifies handler works with empty config.
|
||||||
|
func TestNilHandlerGetAmpCode(t *testing.T) {
|
||||||
|
cfg := &config.Config{}
|
||||||
|
h := management.NewHandler(cfg, "", nil)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEmptyConfigGetAmpModelMappings verifies GET returns empty array for fresh config.
|
||||||
|
func TestEmptyConfigGetAmpModelMappings(t *testing.T) {
|
||||||
|
cfg := &config.Config{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||||
|
if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := management.NewHandler(cfg, configPath, nil)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string][]config.AmpModelMapping
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp["model-mappings"]) != 0 {
|
||||||
|
t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"]))
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user